PyTorch: understanding the purpose of each argument in the forward function of nn.TransformerDecoder

According to PyTorch's documentation, the forward function of nn.TransformerDecoder takes the following arguments:

  • tgt – the sequence to the decoder (required).
  • memory – the sequence from the last layer of the encoder (required).
  • tgt_mask – the mask for the tgt sequence (optional).
  • memory_mask – the mask for the memory sequence (optional).
  • tgt_key_padding_mask – the mask for the tgt keys per batch (optional).
  • memory_key_padding_mask – the mask for the memory keys per batch (optional).
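To make the expected shapes concrete, here is a minimal sketch that passes all six arguments in a single call. It assumes the default sequence-first layout (batch_first=False), arbitrary sizes, and boolean masks in which True means "do not attend". With batch_first=True, tgt and memory would instead be (N, T, E) and (N, S, E), while the mask shapes stay the same.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes: T target positions, S memory (source) positions,
# N sequences in the batch, E model dimension.
T, S, N, E = 5, 7, 3, 16

layer = nn.TransformerDecoderLayer(d_model=E, nhead=4, dropout=0.0)
decoder = nn.TransformerDecoder(layer, num_layers=2)

tgt = torch.randn(T, N, E)     # (T, N, E): embedded target sequence
memory = torch.randn(S, N, E)  # (S, N, E): output of the encoder's last layer

# Boolean masks: True marks a position that must NOT be attended to.
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # (T, T) causal
memory_mask = torch.zeros(T, S, dtype=torch.bool)              # (T, S) blocks nothing
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)     # (N, T) no padding
memory_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # (N, S) no padding

out = decoder(
    tgt,
    memory,
    tgt_mask=tgt_mask,
    memory_mask=memory_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)
print(out.shape)  # torch.Size([5, 3, 16]), same shape as tgt
```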

Unfortunately, PyTorch's official documentation on this function isn't very thorough at this point (April 2021) regarding the expected dimensions of each tensor and when it does or doesn't make sense to use each of the optional arguments.

For example, in previous conversations it was explained to me that tgt_mask is usually a square matrix used for self-attention masking, to prevent future tokens from leaking into the predictions at earlier positions. Similarly, tgt_key_padding_mask is used for masking padding tokens (which appear when you pad a batch of sequences of different lengths so that they fit into a single tensor). In light of this, it makes total sense to use tgt_mask in the decoder, but I wouldn't be so sure about tgt_key_padding_mask. What would be the point of masking target padding tokens? Isn't it enough to simply ignore the predictions associated with padding tokens during training (say, with something like nn.CrossEntropyLoss(ignore_index=PADDING_INDEX)) and be done with it?
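For right padding, this intuition can be checked with a small sketch. Everything in it is illustrative: the token ids, PADDING_INDEX = 0, the random memory tensor and the output projection proj are assumptions, not part of the question. With only the causal tgt_mask, replacing the embeddings at the padding positions leaves the outputs at the real tokens unchanged, and nn.CrossEntropyLoss(ignore_index=PADDING_INDEX) drops the padding positions from the loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, S, N, E, VOCAB = 6, 4, 2, 16, 10
PADDING_INDEX = 0  # hypothetical padding token id

# Two right-padded target sequences; the second has 2 padding positions.
tgt_tokens = torch.tensor([[5, 3, 7, 2, 8, 4],
                           [6, 1, 9, 2, PADDING_INDEX, PADDING_INDEX]]).T  # (T, N)

embed = nn.Embedding(VOCAB, E)
layer = nn.TransformerDecoderLayer(d_model=E, nhead=4, dropout=0.0)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()
proj = nn.Linear(E, VOCAB)  # hypothetical output projection
memory = torch.randn(S, N, E)

# Causal mask: position i may only attend to positions <= i.
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

tgt = embed(tgt_tokens)                        # (T, N, E)
out = decoder(tgt, memory, tgt_mask=tgt_mask)  # no tgt_key_padding_mask

# Replace only the embeddings at the padding positions and run again.
tgt_perturbed = tgt.clone()
tgt_perturbed[4:, 1] = torch.randn(2, E)
out_perturbed = decoder(tgt_perturbed, memory, tgt_mask=tgt_mask)

# The real tokens of the padded sequence (positions 0..3) are unaffected.
print(torch.allclose(out[:4, 1], out_perturbed[:4, 1], atol=1e-5))  # True

# Targets would normally be shifted by one position; the unshifted tokens are
# reused here only to keep the sketch short.
logits = proj(out)  # (T, N, VOCAB)
loss_fn = nn.CrossEntropyLoss(ignore_index=PADDING_INDEX)
loss = loss_fn(logits.reshape(-1, VOCAB), tgt_tokens.reshape(-1))
```

This works because every right-padding position lies after all real tokens, so the causal mask already hides it; the outputs computed at the padding positions are meaningless, but they never reach the loss.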

More generally, and considering that the current documentation is not as thorough as one would like, I would like to know the purpose of each argument of nn.TransformerDecoder's forward function, when it makes sense to use each of the optional arguments, and whether there are nuances to keep in mind when switching between training and inference modes.
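On the training/inference point, one nuance can be shown in a sketch: because of the causal tgt_mask, a single teacher-forced pass over the whole target gives, at every position, the same output as re-running the decoder step by step on a growing prefix, which is how greedy inference typically proceeds. The sketch assumes random tensors in place of real embeddings and feeds the same tgt to both paths so they can be compared; in actual inference the prefix would consist of previously predicted tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, S, N, E = 5, 4, 1, 16

layer = nn.TransformerDecoderLayer(d_model=E, nhead=4, dropout=0.0)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()  # eval(): no dropout
memory = torch.randn(S, N, E)
tgt = torch.randn(T, N, E)  # stands in for the embedded (shifted) target sequence


def causal_mask(size):
    # True above the diagonal: position i cannot attend to positions j > i.
    return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)


with torch.no_grad():
    # Training style: one pass over the whole target sequence (teacher forcing).
    full = decoder(tgt, memory, tgt_mask=causal_mask(T))

    # Inference style: re-run the decoder on a growing prefix and keep only
    # the output for the newest position at each step.
    steps = [decoder(tgt[: t + 1], memory, tgt_mask=causal_mask(t + 1))[-1]
             for t in range(T)]
    incremental = torch.stack(steps)  # (T, N, E)

print(torch.allclose(full, incremental, atol=1e-5))  # True
```

The causal mask is still passed at inference time: with more than one layer, the representations of earlier prefix positions feed the keys and values of later layers, and they only match the training-time ones if those positions cannot see later tokens.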



About the need for tgt_key_padding_mask

While padding is usually applied after the normal tokens (i.e. right padding), it is perfectly fine to apply it before them (i.e. left padding). For instance, fairseq supports the parameter left_pad to specify precisely this.

For left padding to be handled correctly, you must mask the padding tokens with tgt_key_padding_mask: since the padding positions come before the normal tokens, the causal self-attention mask (tgt_mask) does not prevent the hidden states at the normal token positions from depending on the padding positions.
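A sketch of this effect, assuming a single left-padded sequence with two padding positions and random embeddings: with only the causal tgt_mask, changing the content of the padding changes the outputs at the real tokens, and adding tgt_key_padding_mask removes that dependence. Once both masks are applied, the padding positions themselves have no visible key left, so their own outputs may come out as NaN in some PyTorch versions; the sketch keeps to a single layer and compares only the real-token positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, S, N, E = 6, 4, 1, 16
PAD = 2  # left padding: positions 0 and 1 are padding, positions 2..5 are real

layer = nn.TransformerDecoderLayer(d_model=E, nhead=4, dropout=0.0)
decoder = nn.TransformerDecoder(layer, num_layers=1).eval()
memory = torch.randn(S, N, E)
tgt = torch.randn(T, N, E)

tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # causal
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
tgt_key_padding_mask[:, :PAD] = True  # (N, T): True marks padding keys

# Same real tokens, different (arbitrary) content at the padding positions.
tgt_other = tgt.clone()
tgt_other[:PAD] = torch.randn(PAD, N, E)

with torch.no_grad():
    # Causal mask only: real tokens can attend to the padding preceding them.
    a = decoder(tgt, memory, tgt_mask=tgt_mask)
    b = decoder(tgt_other, memory, tgt_mask=tgt_mask)
    # Causal mask plus key padding mask: padding keys are hidden from all queries.
    c = decoder(tgt, memory, tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask)
    d = decoder(tgt_other, memory, tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask)

# Compare only the real-token positions (PAD onwards).
print(torch.allclose(a[PAD:], b[PAD:], atol=1e-5))  # False: outputs depend on padding
print(torch.allclose(c[PAD:], d[PAD:], atol=1e-5))  # True: padding is ignored
```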

About the meaning of each argument

I think the best documentation is in MultiheadAttention.forward:

  • key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored.
  • attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.
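Both masks can be seen at work directly in the attention weights returned by nn.MultiheadAttention. This is a minimal sketch with arbitrary sizes and boolean masks (True means "ignore"); it uses a 2D attn_mask of shape (L, S), whereas a 3D mask would have shape (N * num_heads, L, S).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, S, N, E, HEADS = 3, 4, 2, 8, 2

mha = nn.MultiheadAttention(embed_dim=E, num_heads=HEADS)  # (seq, batch, embed) inputs
query = torch.randn(L, N, E)
key = value = torch.randn(S, N, E)

# key_padding_mask (N, S): the last key of the second batch entry is padding.
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[1, -1] = True

# 2D attn_mask (L, S), broadcast over the batch: no query may attend to key 0.
attn_mask = torch.zeros(L, S, dtype=torch.bool)
attn_mask[:, 0] = True

out, weights = mha(query, key, value,
                   key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print(out.shape)      # torch.Size([3, 2, 8])  (L, N, E)
print(weights.shape)  # torch.Size([2, 3, 4])  (N, L, S), averaged over heads
print(weights[:, :, 0].abs().max().item())   # 0.0: key 0 is blocked for every query
print(weights[1, :, -1].abs().max().item())  # 0.0: padded key of batch entry 1
```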

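With that information, and knowing where the keys, values and queries come from in each multi-head attention block, the purpose of each parameter in nn.TransformerDecoder.forward should be clear. In each decoder layer, the self-attention block takes its queries, keys and values from the target side (tgt, or the previous layer's output), so tgt_mask and tgt_key_padding_mask apply there; the cross-attention block takes its queries from the target side and its keys and values from memory, so memory_mask and memory_key_padding_mask apply there. The documentation of MultiheadAttention.forward also describes the expected shapes.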
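The two attention blocks of each decoder layer are exposed as the self_attn and multihead_attn attributes of nn.TransformerDecoderLayer. The sketch below, with arbitrary sizes and a hypothetical padded second source sequence, checks the cross-attention side: with memory_key_padding_mask set, changing the encoder output at the padded positions leaves the decoder output unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, S, N, E = 4, 5, 2, 16

layer = nn.TransformerDecoderLayer(d_model=E, nhead=4, dropout=0.0)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()

# self_attn is used with tgt_mask / tgt_key_padding_mask,
# multihead_attn with memory_mask / memory_key_padding_mask.
first = decoder.layers[0]
print(type(first.self_attn).__name__, type(first.multihead_attn).__name__)
# MultiheadAttention MultiheadAttention

tgt = torch.randn(T, N, E)
memory = torch.randn(S, N, E)
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# The second source sequence is padded: its last 2 memory positions are not real.
memory_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
memory_key_padding_mask[1, -2:] = True

memory_other = memory.clone()
memory_other[-2:, 1] = torch.randn(2, E)  # change only the padded positions

with torch.no_grad():
    out = decoder(tgt, memory, tgt_mask=tgt_mask,
                  memory_key_padding_mask=memory_key_padding_mask)
    out_other = decoder(tgt, memory_other, tgt_mask=tgt_mask,
                        memory_key_padding_mask=memory_key_padding_mask)

print(torch.allclose(out, out_other, atol=1e-5))  # True
```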

