While doing inference with a Transformer-Decoder in batches, how can I stop each sequence separately (if possible)?
My decoder is a transformer decoder, and during training I have no issue: the full target sequence is available from the beginning and correctly masked.
However, at inference time I have to generate one new token at a time, append it to the target, and stop only once the latest output token is eos. In a batch I find this difficult because each sequence ends at a different point, so I would have to keep going with some sequences while stopping others. What is the usual approach?
When generating a new token I do it for the entire batch, something like:
import torch

tgt = torch.zeros(batch_size, max_len, dtype=torch.long)  # token ids, so long dtype
tgt[:, 0] = sos
mask = torch.ones(batch_size, max_len).bool()  # True = ignore (key padding mask)
mask[:, 0] = False  # False = attend
# start at 1 because sos is already in place; stop at max_len - 1 because
# eos will be forcefully set at the end, so at most max_len - 1 tokens are generated
for i in range(1, max_len - 1):
    decoder_output = decoder(tgt, ...)  # [batch, seq_len, vocab] logits
    # the prediction for position i comes from the output at the latest real
    # position i - 1 (the padded last position is not the right one to read)
    tgt[:, i] = decoder_output[:, i - 1].argmax(-1)  # greedy pick
    mask[:, i] = False
But when checking for eos I would have to do it per sequence, and most importantly: how would I ensure I don't extend the already finished sequences?
Also, if some sequences have already ended but others haven't, wouldn't it be a waste to keep feeding ALL of them through the decoder when I only care about some of the outputs?
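For concreteness, here is a rough sketch of the kind of per-sequence stopping I have in mind: a boolean finished flag per sequence, with finished rows forced to a pad token (pad is a hypothetical padding id here, alongside sos and eos, and I am assuming the decoder returns per-position logits with greedy decoding):

import torch

finished = torch.zeros(batch_size, dtype=torch.bool)  # one flag per sequence
tgt = torch.zeros(batch_size, max_len, dtype=torch.long)
tgt[:, 0] = sos
mask = torch.ones(batch_size, max_len).bool()  # True = ignore
mask[:, 0] = False

for i in range(1, max_len - 1):
    decoder_output = decoder(tgt, ...)                # [batch, seq, vocab] logits
    next_token = decoder_output[:, i - 1].argmax(-1)  # greedy pick, shape [batch]
    next_token[finished] = pad                        # don't extend finished rows
    tgt[:, i] = next_token
    mask[:, i] = finished                             # attend only on live rows
    finished = finished | (next_token == eos)         # mark rows that just ended
    if finished.all():                                # every sequence produced eos,
        break                                         # so stop the whole batch early

But this still runs the already finished rows through the decoder at every step, which is exactly the waste I mean, so I am not sure this is the standard way to do it.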
Topic transformer pytorch inference machine-learning
Category Data Science