Large jumps in loss in simple transformer model?
As an exercise, I created a very simple transformer model that just sees the same simple batch of dummy data repeatedly and (one would assume) should quickly learn to fit it perfectly.
And indeed, training quickly reaches a loss of zero. However, I noticed that the loss does not stay at zero, or even close to it: there are occasional large jumps in the loss. The script below counts every time the loss jumps by 10 or more between consecutive training steps, over 100,000 training steps. Here is the output from one run:
Steps 0-10000: 24
Steps 10000-20000: 14
Steps 20000-30000: 34
Steps 30000-40000: 16
Steps 40000-50000: 6
Steps 50000-60000: 8
Steps 60000-70000: 7
Steps 70000-80000: 8
Steps 80000-90000: 6
Steps 90000-100000: 7
Zero first reached at step 597
And here's a plot of the loss over the first 10,000 steps. As you can see, it sometimes shoots up above 400!
My question is: why does this happen? Is it expected behavior? Does it have any practical implications? (I can imagine that the answer is no since, in practice, we train on much more complex data and loss will never actually go to zero.)
A couple of notes:
- If I remove the transformer from the model (going directly from the embedding to the output layer), the loss goes to zero and stays there (see the sketch just after these notes).
- The code below uses Hugging Face transformers. I also implemented a version of the model using the transformer class from PyTorch itself; the jumps also occurred, although somewhat less often.
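For reference, the "no transformer" variant from the first note looks roughly like this (a sketch rather than my exact code; the class name and inline constants are just for illustration and mirror the script below):

import torch.nn as nn

VOCAB_SIZE, N_HIDDEN = 2, 32  # same constants as in the script below

class NetNoTransformer(nn.Module):
    def __init__(self, n_inputs=VOCAB_SIZE, n_hidden=N_HIDDEN, n_outputs=VOCAB_SIZE):
        super().__init__()
        self.embedding = nn.Embedding(n_inputs, n_hidden)
        self.output = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        # Go straight from embeddings to logits: no BertModel in between
        return self.output(self.embedding(x))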
Here is the script I have been using, in case it is helpful:
import random
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim.lr_scheduler import LambdaLR
from transformers import BertConfig, BertModel
N_ITERS = 100001
JUMP_INTERVAL = 10000
VOCAB_SIZE = 2
N_HIDDEN = 32
BATCH_SIZE = 2
SEQ_LEN = 2
MAX_NORM = 1
class Net(nn.Module):
def __init__(
self, n_inputs=VOCAB_SIZE, n_hidden=N_HIDDEN, n_outputs=VOCAB_SIZE
):
super().__init__()
self.embedding = nn.Embedding(n_inputs, n_hidden)
encoder_config = BertConfig(
vocab_size=1,
hidden_size=n_hidden,
num_hidden_layers=1,
num_attention_heads=2,
intermediate_size=n_hidden * 2,
is_decoder=True,
)
self.transformer = BertModel(encoder_config)
self.output = nn.Linear(n_hidden, n_outputs)
def forward(self, x):
embedded = self.embedding(x)
h = self.transformer(inputs_embeds=embedded).last_hidden_state
out = self.output(h)
return out
def rate(step, model_size, factor, warmup):
    """After http://nlp.seas.harvard.edu/annotated-transformer/#optimizer
    we have to default the step to 1 for LambdaLR function
    to avoid zero raising to negative power.
    """
if step == 0:
step = 1
return factor * (
model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
)
def main():
    random.seed(42)
    torch.manual_seed(42)
inp = torch.tensor([[1, 0], [0, 1]])
assert inp.shape == (BATCH_SIZE, SEQ_LEN)
    assert inp.min() == 0 and inp.max() == VOCAB_SIZE - 1
tar = inp.max() - inp
net = Net()
    loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.Adam(
net.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
)
lr_scheduler = LambdaLR(
optimizer=optimizer,
lr_lambda=lambda step: rate(step, N_HIDDEN, 1, 3000),
)
    prev_loss, loss_jump_counts, loss_jump_count = float("inf"), [], 0
zero_first_reached_at = None
for i in tqdm(range(N_ITERS)):
preds = net(inp)
loss = loss_fn(preds.reshape(-1, preds.shape[-1]), tar.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), MAX_NORM)
optimizer.step()
lr_scheduler.step()
if zero_first_reached_at is None and loss == 0:
zero_first_reached_at = i
        if loss.item() - prev_loss >= 10:
loss_jump_count += 1
if i % JUMP_INTERVAL == 0 and i != 0:
loss_jump_counts.append((i, loss_jump_count))
loss_jump_count = 0
prev_loss = loss.item()
for i, count in loss_jump_counts:
        print(f"Steps {i - JUMP_INTERVAL}-{i}: {count}")
    print(f"Zero first reached at step {zero_first_reached_at}")
if __name__ == "__main__":
main()
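In case the learning-rate schedule matters here: the base lr of 1.0 is scaled by LambdaLR using rate(step, 32, 1, 3000), which warms up to roughly 3.2e-3 around step 3000 and then decays like 1/sqrt(step). A quick sketch to print a few of those factors (reusing rate() and N_HIDDEN from the script above):

for step in (1, 1000, 3000, 10000, 100000):
    # effective learning rate = base lr (1.0) * the LambdaLR factor below
    print(step, rate(step, model_size=N_HIDDEN, factor=1, warmup=3000))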
Topic cross-entropy huggingface transformer loss-function deep-learning
Category Data Science