Do the multiple heads in multi-head attention actually lead to more parameters or different outputs?

I am trying to understand Transformers. While I understand the concept of the encoder-decoder structure and the idea behind self-attention, what I am stuck at is the multi-head part of the MultiheadAttention layer.

Looking at this explanation https://jalammar.github.io/illustrated-transformer/, which I generally found very good, it appears that multiple weight matrices (one set of weight matrices per head) are used to transform the original input into the query, key and value, which are then used to calculate the attention scores and the actual output of the MultiheadAttention layer. I also understand the idea behind multiple heads: the individual attention heads can focus on different parts of the input (as depicted in the link).

However, this seems to contradict other observations I have made:

  1. In the original paper https://arxiv.org/abs/1706.03762, it is stated that the input is split into parts of equal size per attention head.
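For reference, this is how the paper defines multi-head attention, with $h$ heads and per-head dimensions $d_k = d_v = d_{\text{model}}/h$ (the paper itself uses $d_{\text{model}} = 512$, $h = 8$, so $d_k = d_v = 64$). With $d_{\text{model}} = 512$ and $h = 2$, each head would work in 256 dimensions:

```latex
\begin{aligned}
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O \\
\mathrm{head}_i &= \mathrm{Attention}\!\left(Q W_i^Q,\; K W_i^K,\; V W_i^V\right) \\
W_i^Q, W_i^K &\in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad
W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}, \quad
W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}} \\
d_k &= d_v = d_{\text{model}} / h
\end{aligned}
```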

So, for example, I have:

batch_size = 1
sequence_length = 12
embed_dim = 512 (I assume that the dimensions of ```query```, ```key``` and ```value``` are equal)
Then the shape of my query, key and value would each be [1, 12, 512].
We assume we have two heads, so num_heads = 2.
This results in a dimension per head of 512/2 = 256. According to my understanding, this should result in the shape [1, 12, 256] for each attention head.
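The split described above can be sketched in NumPy, used here only as a stand-in for PyTorch; the random tensor plays the role of an already-projected query and all variable names are illustrative. Reshaping [1, 12, 512] into two heads gives [1, 2, 12, 256], and under this reshape convention each head sees one contiguous 256-wide slice of the features:

```python
import numpy as np

batch_size, seq_len, embed_dim, num_heads = 1, 12, 512, 2
head_dim = embed_dim // num_heads  # 512 / 2 = 256

# Stand-in for the projected query, shape [1, 12, 512]
q = np.random.default_rng(0).standard_normal((batch_size, seq_len, embed_dim))

# [1, 12, 512] -> [1, 12, 2, 256] -> [1, 2, 12, 256]
q_heads = q.reshape(batch_size, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

print(q_heads.shape)        # (1, 2, 12, 256)
print(q_heads[:, 0].shape)  # (1, 12, 256): the part head 0 works on

# Head 0 sees features 0..255, head 1 sees features 256..511
print(np.array_equal(q_heads[:, 1], q[..., head_dim:]))  # True
```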

So, am I correct in assuming that this depiction https://jalammar.github.io/illustrated-transformer/ just does not display this split appropriately?

  2. Does the splitting of the input into different heads actually lead to different calculations in the layer, or is it just done to make computations faster?

I have looked at the implementation of torch.nn.MultiheadAttention in PyTorch and printed out the shapes at various stages of the forward pass through the layer. To me, it appears that the operations are conducted in the following order:

  1. Use the in_projection weight matrices to get the query, key and value from the original inputs. After this, the shape of query, key and value is [1, 12, 512]. From my understanding, the weights in this step are parameters that are actually learned in the layer during training.
  2. Then the shape is modified for the multiple heads into [2, 12, 256].
  3. After this, the dot product between query and key is calculated, etc. The output of this operation has the shape [2, 12, 256].
  4. Then the outputs of the heads are concatenated, which results in the shape [12, 512].
  5. The attention_output is multiplied by the output projection weight matrix and we get [12, 1, 512] (the batch size and the sequence_length are sometimes switched around). Again, here we have weights that are being trained inside the matrix.
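The five steps above can be sketched in NumPy as follows. This is a minimal sketch, assuming self-attention (the same input for query, key and value), a single sequence with the batch dimension dropped, and weights laid out like PyTorch's packed `in_proj_weight` [1536, 512] and `out_proj.weight` [512, 512]. The function and variable names are hypothetical, and this is not PyTorch's actual code, which also handles masks, dropout and the [seq, batch, embed] layout:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, w_in, b_in, w_out, b_out, num_heads):
    """Self-attention sketch; x has shape [seq_len, embed_dim] (batch of 1 dropped)."""
    seq_len, embed_dim = x.shape
    head_dim = embed_dim // num_heads

    # Step 1: one packed input projection -> q, k, v, each [12, 512]
    qkv = x @ w_in.T + b_in                    # [12, 1536]
    q, k, v = np.split(qkv, 3, axis=-1)

    # Step 2: reshape for the heads -> [num_heads, 12, head_dim]
    def to_heads(t):
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)
    q, k, v = to_heads(q), to_heads(k), to_heads(v)

    # Step 3: scaled dot-product attention, computed independently per head
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # [num_heads, 12, 12]
    heads = softmax(scores, axis=-1) @ v                    # [num_heads, 12, head_dim]

    # Step 4: concatenate the heads back -> [12, 512]
    concat = heads.transpose(1, 0, 2).reshape(seq_len, embed_dim)

    # Step 5: output projection -> [12, 512]
    return concat @ w_out.T + b_out

rng = np.random.default_rng(0)
E = 512
x = rng.standard_normal((12, E))
w_in = rng.standard_normal((3 * E, E)) / np.sqrt(E)   # like in_proj_weight
b_in = np.zeros(3 * E)                                # like in_proj_bias
w_out = rng.standard_normal((E, E)) / np.sqrt(E)      # like out_proj.weight
b_out = np.zeros(E)                                   # like out_proj.bias

out = multi_head_attention(x, w_in, b_in, w_out, b_out, num_heads=2)
print(out.shape)  # (12, 512)
```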

I printed the shapes of the parameters in the layer for different num_heads, and the number of parameters does not change:

  1. First parameter: [1536,512] (The input projection weight matrix, I assume, 1536=3*512)
  2. Second parameter: [1536] (The input projection bias, I assume)
  3. Third parameter: [512,512] (The output projection weight matrix, I assume)
  4. Fourth parameter: [512] (The output projection bias, I assume)
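The four shapes add up to a total in which num_heads never appears. A small hypothetical helper that reproduces the arithmetic, assuming query, key and value all have embed_dim features and biases are enabled (as in the shapes listed above):

```python
def mha_param_count(embed_dim, num_heads):
    """Parameter count for packed q/k/v projections plus an output projection."""
    # num_heads only has to divide embed_dim; it never enters the count
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    in_proj_weight = 3 * embed_dim * embed_dim   # [1536, 512]
    in_proj_bias = 3 * embed_dim                 # [1536]
    out_proj_weight = embed_dim * embed_dim      # [512, 512]
    out_proj_bias = embed_dim                    # [512]
    return in_proj_weight + in_proj_bias + out_proj_weight + out_proj_bias

for h in (1, 2, 4, 8):
    print(h, mha_param_count(512, h))  # 1050624 for every h
```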

On this website https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853, it is stated that this is only a logical split. This seems to fit my own observations using the pytorch implementation.
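One way to check that the "separate weight matrices per head" picture and the "logical split" picture can describe the same projection: slicing the output of one big [512, 512] projection gives the same result as multiplying by the corresponding block of rows of that matrix. A NumPy sketch with random weights; the per-head matrices here are just slices of the big matrix, an illustration rather than code from either source:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, embed_dim, num_heads = 12, 512, 2
head_dim = embed_dim // num_heads

x = rng.standard_normal((seq_len, embed_dim))
w_q = rng.standard_normal((embed_dim, embed_dim))  # one big query projection

# "Logical split": project once with the big matrix, then cut into heads
q_full = x @ w_q.T
q_split = [q_full[:, i * head_dim:(i + 1) * head_dim] for i in range(num_heads)]

# "Separate matrices per head": head i uses its own [256, 512] matrix,
# which here is simply a block of rows of w_q
w_q_per_head = [w_q[i * head_dim:(i + 1) * head_dim] for i in range(num_heads)]
q_separate = [x @ w_i.T for w_i in w_q_per_head]

print(all(np.allclose(a, b) for a, b in zip(q_split, q_separate)))  # True
```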

So does the number of attention heads actually change the values that are output by the layer and the weights learned by the model? The way I see it, the weights are not influenced by the number of heads. Then how can multiple heads focus on different parts (similar to the filters in convolutional layers)?

I also initialized a MultiheadAttention layer with weights that are all equal to one, and the number of heads did not influence the result.
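A NumPy sketch of that experiment, extended with random weights for comparison. It uses a simplified, bias-free self-attention, and all names are hypothetical. With all-ones weights, every feature of q, k and v at a given position is the same number (the sum of that position's input), so every head receives identical q and k slices and computes the same attention pattern; in that setup the heads cannot differ from each other. With random weights, the per-head patterns and the outputs for different num_heads can be compared directly:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=-1, keepdims=True)

def mha(x, w_in, w_out, num_heads):
    """Bias-free multi-head self-attention; x is [seq_len, embed_dim].
    Returns the layer output and the per-head attention weights."""
    seq_len, embed_dim = x.shape
    d = embed_dim // num_heads
    q, k, v = np.split(x @ w_in.T, 3, axis=-1)
    # [seq_len, embed_dim] -> [num_heads, seq_len, d]
    q, k, v = (t.reshape(seq_len, num_heads, d).transpose(1, 0, 2) for t in (q, k, v))
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))  # [num_heads, seq_len, seq_len]
    out = (attn @ v).transpose(1, 0, 2).reshape(seq_len, embed_dim)
    return out @ w_out.T, attn

rng = np.random.default_rng(0)
E = 8                                  # small embed_dim to keep it readable
x = rng.standard_normal((5, E))        # 5 positions

# Experiment 1: all weights equal to one
_, attn_ones = mha(x, np.ones((3 * E, E)), np.ones((E, E)), num_heads=2)
print(np.allclose(attn_ones[0], attn_ones[1]))  # True: the two heads are identical

# Experiment 2: random weights shared by both runs; only num_heads changes
w_in = rng.standard_normal((3 * E, E)) / np.sqrt(E)
w_out = rng.standard_normal((E, E)) / np.sqrt(E)
out_1h, _ = mha(x, w_in, w_out, num_heads=1)
out_2h, attn_rand = mha(x, w_in, w_out, num_heads=2)
print(np.allclose(attn_rand[0], attn_rand[1]))  # compare the two heads' patterns
print(np.allclose(out_1h, out_2h))              # compare outputs for 1 vs 2 heads
```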

Thanks in advance for any advice!
