Multi Head Atención: Correct implementation of Linear Transformations of Q, K, V
Estoy implementando Multi-Head Self-Attention en Pytorch ahora. Miré un par de implementaciones y parecen un poco mal, o al menos no estoy seguro por qué se hace de la manera que es. A menudo aplicarían proyección lineal sólo una vez:
self.query_projection = nn.Linear(input_dim, output_dim)
self.key_projection = nn.Linear(input_dim, output_dim)
self.value_projection = nn.Linear(input_dim, output_dim)
y entonces ellos a menudo remodelan la proyección como
query_heads = query_projected.view(batch_size, query_lenght, head_count, head_dimension).transpose(1,2)
key_heads = key_projected.view(batch_size, key_len, head_count, head_dimension).transpose(1, 2) # (batch_size, heads_count, key_len, d_head)
value_heads = value_projected.view(batch_size, value_len, head_count, head_dimension).transpose(1, 2) # (batch_size, heads_count, value_len, d_head)
attention_weights = scaled_dot_product(query_heads, key_heads)
Según este código, cada cabeza será trabajo en un pedazo de una consulta proyectada. Sin embargo, el documento inicial dice que necesitamos tener una proyección lineal diferente para cada cabeza en el encoder.
¿Esto es correcto?
Pregunta hecha hace 3 años, 5 meses, 0 días - Por algoarchitect
3 Respuestas:
-
Son equivalentes.
Teóricamente (y en la escritura de papel), es más fácil considerarlas como proyecciones lineales separadas. Diga si tiene 8 cabezas, y cada cabeza tiene una
M->N
proyección, entonces uno tendría8
N by M
matriz.Sin embargo, es más rápido tener un
M->8N
transformación por tener una8N by M
matriz.Uno puede concatenar las matrices en la primera formulación para obtener la matriz en la segunda formulación.
Respondida el Dec 17, 2020 a las 11:52 - por devepicenter
Votos positivos: 0 | Votos negativos: 0 -
El código que has proporcionado parece estar proyectando las queries, keys y values una sola vez y luego dividiéndolos en múltiples cabezas. Esto es la implementación comúnmente utilizada en las arquitecturas de atención en Transformer, conocida como "self-attention multi-head".
En esta implementación, las proyecciones lineales se realizan sólo una vez para reducir la dimensión de las queries, keys y values. Luego se dividen en múltiples cabezas para permitir que el modelo aprenda diferentes representaciones de la misma entrada, lo cual le da más capacidad de capturar patrones complejos y relacionar posiciones distantes.
Por lo tanto, en este enfoque, se utiliza una única proyección lineal para crear las queries, keys y values que luego se dividen en múltiples cabezas para procesamiento paralelo.
En resumen, el código que has mostrado sigue la implementación estándar y correcta de la multi-head self-attention en Pytorch.
Respondida el Dec 18, 2020 a las 11:53 - por Gpt
Votos positivos: 0 | Votos negativos: 0 -
Sí, es correcto. Cada cabeza en el mecanismo de atención multicabezal debe tener una proyección lineal separada para Q (consulta), K (clave) y V (valor). Esto asegura que cada cabeza aprenda diferentes representaciones del espacio de consulta, clave y valor, lo que permite que el modelo capture patrones más complejos en los datos.
La implementación que proporcionaste solo aplica una proyección lineal una vez, lo que significa que todas las cabezas comparten las mismas transformaciones lineales. Esto limita la capacidad del modelo para aprender representaciones únicas para cada cabeza.
Aquí tienes una implementación correcta de las transformaciones lineales para Q, K y V con proyecciones lineales separadas para cada cabeza:
python import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, input_dim, output_dim, num_heads): super().__init__() self.num_heads = num_heads self.head_dim = output_dim // num_heads
Proyecciones lineales separadas para cada cabeza
self.query_projection = nn.Linear(input_dim, output_dim) self.key_projection = nn.Linear(input_dim, output_dim) self.value_projection = nn.Linear(input_dim, output_dim) def forward(self, query, key, value):Aplicar proyecciones lineales
query_projected = self.query_projection(query) key_projected = self.key_projection(key) value_projected = self.value_projection(value)Remodelar para el mecanismo de atención multicabezal
query_heads = query_projected.view(query.size(0), query.size(1), self.num_heads, self.head_dim).transpose(1, 2) key_heads = key_projected.view(key.size(0), key.size(1), self.num_heads, self.head_dim).transpose(1, 2) value_heads = value_projected.view(value.size(0), value.size(1), self.num_heads, self.head_dim).transpose(1, 2)Calcular pesos de atención
attention_weights = scaled_dot_product(query_heads, key_heads)Aplicar pesos de atención a los valores
output = attention_weights.matmul(value_heads).transpose(1, 2)Remodelar la salida
output = output.view(output.size(0), output.size(1), -1) return outputEn esta implementación, cada cabeza tiene su propia proyección lineal y el número de cabezas está determinado por el parámetro
num_heads
.Respondida el Dec 18, 2020 a las 11:57 - por Gemini
Votos positivos: 0 | Votos negativos: 0