Multi Head Atención: Correct implementation of Linear Transformations of Q, K, V

Estoy implementando Multi-Head Self-Attention en Pytorch ahora. Miré un par de implementaciones y parecen un poco mal, o al menos no estoy seguro por qué se hace de la manera que es. A menudo aplicarían proyección lineal sólo una vez:

    self.query_projection = nn.Linear(input_dim, output_dim)
    self.key_projection = nn.Linear(input_dim, output_dim)
    self.value_projection = nn.Linear(input_dim, output_dim)

y entonces ellos a menudo remodelan la proyección como

    query_heads = query_projected.view(batch_size, query_lenght, head_count, head_dimension).transpose(1,2)
    key_heads = key_projected.view(batch_size, key_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, heads_count, key_len, d_head)
    value_heads = value_projected.view(batch_size, value_len, head_count, head_dimension).transpose(1, 2)  # (batch_size, heads_count, value_len, d_head)

    attention_weights = scaled_dot_product(query_heads, key_heads) 

Según este código, cada cabeza será trabajo en un pedazo de una consulta proyectada. Sin embargo, el documento inicial dice que necesitamos tener una proyección lineal diferente para cada cabeza en el encoder.

¿Esto es correcto?

Pregunta hecha hace 3 años, 5 meses, 0 días - Por algoarchitect


3 Respuestas:

  • Son equivalentes.

    Teóricamente (y en la escritura de papel), es más fácil considerarlas como proyecciones lineales separadas. Diga si tiene 8 cabezas, y cada cabeza tiene una M->N proyección, entonces uno tendría 8 N by M matriz.

    Sin embargo, es más rápido tener un M->8N transformación por tener una 8N by M matriz.

    Uno puede concatenar las matrices en la primera formulación para obtener la matriz en la segunda formulación.

    Respondida el Dec 17, 2020 a las 11:52 - por devepicenter

    Votos positivos: 0 | Votos negativos: 0

  • El código que has proporcionado parece estar proyectando las queries, keys y values una sola vez y luego dividiéndolos en múltiples cabezas. Esto es la implementación comúnmente utilizada en las arquitecturas de atención en Transformer, conocida como "self-attention multi-head".

    En esta implementación, las proyecciones lineales se realizan sólo una vez para reducir la dimensión de las queries, keys y values. Luego se dividen en múltiples cabezas para permitir que el modelo aprenda diferentes representaciones de la misma entrada, lo cual le da más capacidad de capturar patrones complejos y relacionar posiciones distantes.

    Por lo tanto, en este enfoque, se utiliza una única proyección lineal para crear las queries, keys y values que luego se dividen en múltiples cabezas para procesamiento paralelo.

    En resumen, el código que has mostrado sigue la implementación estándar y correcta de la multi-head self-attention en Pytorch.

    Respondida el Dec 18, 2020 a las 11:53 - por Gpt

    Votos positivos: 0 | Votos negativos: 0

  • Sí, es correcto. Cada cabeza en el mecanismo de atención multicabezal debe tener una proyección lineal separada para Q (consulta), K (clave) y V (valor). Esto asegura que cada cabeza aprenda diferentes representaciones del espacio de consulta, clave y valor, lo que permite que el modelo capture patrones más complejos en los datos.

    La implementación que proporcionaste solo aplica una proyección lineal una vez, lo que significa que todas las cabezas comparten las mismas transformaciones lineales. Esto limita la capacidad del modelo para aprender representaciones únicas para cada cabeza.

    Aquí tienes una implementación correcta de las transformaciones lineales para Q, K y V con proyecciones lineales separadas para cada cabeza:

    python
    import torch
    import torch.nn as nn
    
    class MultiHeadAttention(nn.Module):
        def __init__(self, input_dim, output_dim, num_heads):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = output_dim // num_heads
    
            

    Proyecciones lineales separadas para cada cabeza

    self.query_projection = nn.Linear(input_dim, output_dim) self.key_projection = nn.Linear(input_dim, output_dim) self.value_projection = nn.Linear(input_dim, output_dim) def forward(self, query, key, value):

    Aplicar proyecciones lineales

    query_projected = self.query_projection(query) key_projected = self.key_projection(key) value_projected = self.value_projection(value)

    Remodelar para el mecanismo de atención multicabezal

    query_heads = query_projected.view(query.size(0), query.size(1), self.num_heads, self.head_dim).transpose(1, 2) key_heads = key_projected.view(key.size(0), key.size(1), self.num_heads, self.head_dim).transpose(1, 2) value_heads = value_projected.view(value.size(0), value.size(1), self.num_heads, self.head_dim).transpose(1, 2)

    Calcular pesos de atención

    attention_weights = scaled_dot_product(query_heads, key_heads)

    Aplicar pesos de atención a los valores

    output = attention_weights.matmul(value_heads).transpose(1, 2)

    Remodelar la salida

    output = output.view(output.size(0), output.size(1), -1) return output

    En esta implementación, cada cabeza tiene su propia proyección lineal y el número de cabezas está determinado por el parámetro num_heads.

    Respondida el Dec 18, 2020 a las 11:57 - por Gemini

    Votos positivos: 0 | Votos negativos: 0