Class token in ViT and BERT

I'm trying to understand the architecture of the ViT paper, and noticed they use a CLASS token like in BERT.

To the best of my understanding, this token is used to gather knowledge about the entire image, and is then solely used to predict its class. My question is: why does this token exist as input to all the transformer blocks, and why is it treated the same as the word/patch tokens?

Treating the class token like the rest of the tokens means other tokens can attend to it. I'd expect that the class token would be able to attend to other tokens while they could not attend to it.

Also, specifically in ViT, why does the class token receive positional encodings? It represents the entire class and thus doesn't have any specific location.

Thanks!

Tags: attention-mechanism, computer-vision, deep-learning, nlp, machine-learning

Category: Data Science


My question is — why does this token exist as input in all the transformer blocks and is treated the same as the word / patches tokens?

Transformers are, by default, sequence-to-sequence networks. Since there is no decoder in ViT, the length of the output sequence equals the length of the input sequence (the number of patches). So if the goal is classification, there are two choices:

  • Either apply a fully connected layer on top of the whole output sequence (which is not a good idea, because then we would have to fix the number of patches, which in turn fixes the input image resolution).
  • Or apply the classification layer to one item of the output sequence. But which one? The best answer is none of them: we don't want to be biased toward any particular patch. So the solution is to add a dummy input, call it the class token, and apply the classification layer to its corresponding output item.
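The second option above can be sketched in a few lines of NumPy. This is a toy, illustrative example (random vectors instead of learned weights, a single unprojected attention step, made-up sizes); the point is only the shapes: a class token is prepended, attention mixes it with the patches, and the classification head reads just the class token's output row, so the head never depends on the number of patches.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, dim, num_classes = 16, 8, 10        # toy sizes, not the paper's

patch_tokens = rng.normal(size=(num_patches, dim))  # stand-in patch embeddings
cls_token = rng.normal(size=(1, dim))               # stand-in learnable [class] embedding

# Prepend the class token: the sequence length becomes num_patches + 1.
x = np.concatenate([cls_token, patch_tokens], axis=0)

# One self-attention step (no learned Q/K/V projections, for brevity).
scores = x @ x.T / np.sqrt(dim)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)      # softmax over each row
out = weights @ x                                   # same length as the input

# The classification head reads only the class token's output row (index 0),
# so its weight shape is (dim, num_classes) regardless of num_patches.
w_head = rng.normal(size=(dim, num_classes))
logits = out[0] @ w_head
print(logits.shape)                                 # (10,)
```

Because only row 0 reaches the head, changing `num_patches` (i.e. the input resolution) changes nothing about the classifier itself.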

Treating the class token like the rest of the tokens means other tokens can attend to it. I'd expect that the class token would be able to attend to other tokens while they could not attend to it.

I'm not sure, but I think that if the other tokens can attend to the class token, they can use intermediate information about the image class computed in the lower layers. This is just a guess, and it would be worth testing different scenarios.

Also, specifically in ViT, why does the class token receive positional encodings? It represents the entire class and thus doesn't have any specific location.

I think the main reason is that this way the network can distinguish the class embedding from the patch embeddings and treat them differently.


My question is — why does this token exist as input in all the transformer blocks and is treated the same as the word / patches tokens?

The CLASS token is a learnable embedding, prepended to the input patch embeddings; all of these together are fed into the first transformer layer. The CLASS token gathers information from all the patches via Multi-head Self-Attention (MSA). It is basically treated the same as the patch tokens, but at the end, when doing classification, only the hidden output from the CLASS token is used as input to the classification layer.

Treating the class token like the rest of the tokens means other tokens can attend to it. I'd expect that the class token would be able to attend to other tokens while they could not attend to it.

I think that helps train all the weights during back-propagation, including the weights of the embedding layer for the input patches.

Also, specifically in ViT, why does the class token receive positional encodings? It represents the entire class and thus doesn't have any specific location.

The positional encoding tells the model that this is the first element of the sequence. That would help in a bidirectional transformer model like BERT.
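Mechanically, there is nothing special about how the class token gets its position code: ViT's learned positional embedding table simply has one extra row, and row 0 is added to the class token. A minimal sketch with random stand-ins for the learned parameters (toy sizes, not the paper's):

```python
import numpy as np

rng = np.random.default_rng(1)
num_patches, dim = 16, 8                             # toy sizes

tokens = rng.normal(size=(1 + num_patches, dim))     # [class] token + patch embeddings
pos_embed = rng.normal(size=(1 + num_patches, dim))  # learned table: one row per position

# Row 0 of pos_embed is what the class token receives; since the table is
# learned, the network is free to use that row purely as a "this is the
# class token" marker rather than as a spatial location.
x = tokens + pos_embed
print(x.shape)                                       # (17, 8)
```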
