Do the multiple heads in multi-head attention actually lead to more parameters or different outputs?
I am trying to understand Transformers. While I understand the concept of the encoder-decoder structure and the idea behind self-attention, what I am stuck on is the multi-head part of the MultiheadAttention layer.
Looking at this explanation https://jalammar.github.io/illustrated-transformer/, which I generally found very good, it appears that multiple weight matrices (one set of weight matrices per head) are used to transform the original input into the `query`, `key` and `value`, which are then used to calculate the attention scores and the actual output of the MultiheadAttention layer. I also understand the idea of multiple heads, namely that the individual attention heads can focus on different parts (as depicted in the link).
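For reference, this is roughly how I picture that per-head view from the blog post (a minimal sketch with toy dimensions, not the actual implementation):

```python
import torch

# toy dimensions, just for illustration
embed_dim, head_dim, seq_len = 512, 64, 12
x = torch.randn(seq_len, embed_dim)   # one input sequence

# one separate set of projection matrices per head, as depicted in the blog post
W_q = torch.randn(embed_dim, head_dim)
W_k = torch.randn(embed_dim, head_dim)
W_v = torch.randn(embed_dim, head_dim)

q, k, v = x @ W_q, x @ W_k, x @ W_v      # each [12, 64]
scores = (q @ k.T) / head_dim ** 0.5     # [12, 12]
attn = torch.softmax(scores, dim=-1)
head_output = attn @ v                   # [12, 64], the output of one head
```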
However, this seems to contradict other observations I have made:
- In the original paper https://arxiv.org/abs/1706.03762, it is stated that the input is split into parts of equal size per attention head.
So, for example I have:
batch_size = 1
sequence_length = 12
embed_dim = 512 (I assume that the dimensions for `query`, `key` and `value` are equal)
Then the shape of my `query`, `key` and `value` would each be [1, 12, 512].
We assume we have two heads, so num_heads = 2
This results in a dimension per head of 512/2 = 256. According to my understanding, this should result in the shape [1, 12, 256] for each attention head (see the shape sketch after this list).
So, am I correct in assuming that this depiction https://jalammar.github.io/illustrated-transformer/ simply does not show this split explicitly?
- Does the splitting of the input into different heads actually lead to different calculations in the layer or is it just done to make computations faster?
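Here is the shape bookkeeping I have in mind for that split (a minimal sketch, assuming the last dimension is simply divided among the heads):

```python
import torch

batch_size, seq_len, embed_dim, num_heads = 1, 12, 512, 2
head_dim = embed_dim // num_heads   # 512 / 2 = 256

q = torch.randn(batch_size, seq_len, embed_dim)   # [1, 12, 512]

# split the last dimension across the heads -- no new parameters involved
q_heads = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
print(q_heads.shape)   # torch.Size([1, 2, 12, 256])

# PyTorch folds the head dimension into the batch dimension, which is
# presumably where the [2, 12, 256] shape I printed comes from
print(q_heads.reshape(batch_size * num_heads, seq_len, head_dim).shape)   # torch.Size([2, 12, 256])
```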
I have looked at the implementation of torch.nn.MultiheadAttention in PyTorch and printed out the shapes at various stages during the forward pass through the layer. To me it appears that the operations are conducted in the following order (a rough sketch of these steps follows the list):
- Use the in_projection weight matrices to get the `query`, `key` and `value` from the original inputs. After this, the shape of `query`, `key` and `value` is each [1, 12, 512]. From my understanding, the weights in this step are the parameters that are actually learned in the layer during training.
- Then the shape is modified for the multiple heads into [2, 12, 256].
- After this, the dot product between `query` and `key` is calculated, etc. The output of this operation has the shape [2, 12, 256].
- Then the output of the heads is concatenated, which results in the shape [12, 512].
- The attention output is multiplied by the output projection weight matrix and we get [12, 1, 512] (the batch size and the sequence length are sometimes switched around). Again, the matrices in this step contain weights that are learned during training.
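To make that order of operations concrete, here is a rough re-implementation of these steps (a simplified sketch without masking or dropout, so not the exact PyTorch source; the parameter names mirror the shapes I printed below):

```python
import torch
import torch.nn.functional as F

seq_len, batch_size, embed_dim, num_heads = 12, 1, 512, 2
head_dim = embed_dim // num_heads

x = torch.randn(seq_len, batch_size, embed_dim)   # [12, 1, 512], sequence-first like nn.MultiheadAttention

# learned parameters (shapes match the printout below)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)   # [1536, 512]
in_proj_bias = torch.randn(3 * embed_dim)                # [1536]
out_proj_weight = torch.randn(embed_dim, embed_dim)      # [512, 512]
out_proj_bias = torch.randn(embed_dim)                   # [512]

# 1) input projection: one big matmul, then chunk into q, k, v -> each [12, 1, 512]
q, k, v = F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

# 2) reshape for the heads -> each [batch_size * num_heads, seq_len, head_dim] = [2, 12, 256]
def split_heads(t):
    return t.contiguous().view(seq_len, batch_size * num_heads, head_dim).transpose(0, 1)

q, k, v = split_heads(q), split_heads(k), split_heads(v)

# 3) scaled dot-product attention per head -> [2, 12, 256]
attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
out = attn @ v

# 4) concatenate the heads again -> [12, 1, 512]
out = out.transpose(0, 1).contiguous().view(seq_len, batch_size, embed_dim)

# 5) output projection -> [12, 1, 512]
out = F.linear(out, out_proj_weight, out_proj_bias)
print(out.shape)
```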
I printed the shapes of the parameters in the layer for different values of `num_heads`, and the number of parameters does not change (see the snippet after this list):
- First parameter: [1536, 512] (the input projection weight matrix, I assume; 1536 = 3 * 512)
- Second parameter: [1536] (the input projection bias, I assume)
- Third parameter: [512, 512] (the output projection weight matrix, I assume)
- Fourth parameter: [512] (the output projection bias, I assume)
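This is roughly how I checked the parameter shapes (a small sketch; the count stays the same for every `num_heads` I tried):

```python
import torch.nn as nn

for num_heads in (1, 2, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=512, num_heads=num_heads)
    shapes = [tuple(p.shape) for p in mha.parameters()]
    n_params = sum(p.numel() for p in mha.parameters())
    print(num_heads, shapes, n_params)
# the shapes (1536, 512), (1536,), (512, 512), (512,) and the total count
# come out identical for every value of num_heads
```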
On this website https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853, it is stated that this is only a logical split. This seems to fit my own observations using the PyTorch implementation.
So does the number of attention heads actually change the values output by the layer and the weights learned by the model? The way I see it, the weights are not influenced by the number of heads. But then how can multiple heads focus on different parts (similar to the filters in convolutional layers)?
I also initialized a MultiheadAttention layer with all weights equal to one, and the number of heads did not influence the result (a sketch of this experiment is below).
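For completeness, this is roughly the experiment I ran (a sketch from memory, so the exact setup may differ slightly):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(12, 1, 512)   # [seq_len, batch_size, embed_dim]

outputs = []
for num_heads in (1, 2):
    mha = nn.MultiheadAttention(embed_dim=512, num_heads=num_heads)
    with torch.no_grad():
        for p in mha.parameters():
            p.fill_(1.0)          # set every weight and bias to one
        out, _ = mha(x, x, x)
    outputs.append(out)

# in my test the outputs matched regardless of the number of heads
print(torch.allclose(outputs[0], outputs[1]))
```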
Thanks in advance for any advice!