A single batch of training on a Hugging Face BERT model "ruins" the model
For some reason, I need to do further (second-stage) pre-training on a Hugging Face BERT model, and the training outcome is very bad.
After debugging for hours, I was surprised to find that training on even a single batch after loading the base model causes it to predict very poorly when I ask it to unmask some test sentences. I have boiled my code down to this minimal reproducible version:
import torch
from transformers import AdamW, BertTokenizer
from transformers import BertForPreTraining
MSK_CODE = 103
CE_IGN_IDX = -100 # CrossEntropyLoss ignore index value
def sanity_check(tokenizer, inputs):
    print(tokenizer.decode(inputs['input_ids'][0]))
    print(tokenizer.convert_ids_to_tokens(
        inputs['labels'][0]
    ))
    print('Label:', inputs['next_sentence_label'][0])
def test(tokenizer, model, topk=3):
    test_data = 'She needs to [MASK] that [MASK] has only ten minutes.'
    print('\n \033[92m', test_data, '\033[0m')
    test_inputs = tokenizer([test_data],
        padding=True, truncation=True, return_tensors='pt')

    def classifier_hook(module, inputs, outputs):
        unmask_scores, seq_rel_scores = outputs
        token_ids = test_inputs['input_ids'][0]
        masked_idx = (
            token_ids == torch.tensor([MSK_CODE])
        )
        scores = unmask_scores[0][masked_idx]
        cands = torch.argsort(scores, dim=1, descending=True)
        for i, mask_cands in enumerate(cands):
            top_cands = mask_cands[:topk].detach().cpu()
            print(f'MASK[{i}] top candidates:', end=' ')
            print(tokenizer.convert_ids_to_tokens(top_cands))

    classifier = model.cls
    hook = classifier.register_forward_hook(classifier_hook)
    model.eval()
    model(**test_inputs)
    hook.remove()
    print()
# load model
model = BertForPreTraining.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# first test
test(tokenizer, model)
# our single-iteration inputs
# [CLS] 1 2 3 4 5 6 [SEP] 8 9 10 11 12 [SEP]
pair = [['the man went to the store', 'penguins are flightless birds']]
relation_label = 1
# construct inputs
inputs = tokenizer(pair, padding=True, truncation=True, return_tensors='pt')
inputs['next_sentence_label'] = torch.tensor([relation_label])
mask_labels = torch.full(inputs['input_ids'].shape, fill_value=CE_IGN_IDX)
inputs['labels'] = mask_labels
# mask two words
inputs['input_ids'][0][4] = MSK_CODE
inputs['input_ids'][0][9] = MSK_CODE
mask_labels[0][4] = tokenizer.convert_tokens_to_ids('to')
mask_labels[0][9] = tokenizer.convert_tokens_to_ids('are')
# train for one single iteration
sanity_check(tokenizer, inputs)
model.train()
optimizer.zero_grad()
outputs = model(**inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
# second test
test(tokenizer, model)
Output:
Some weights of BertForPreTraining were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
She needs to [MASK] that [MASK] has only ten minutes.
MASK[0] top candidates: ['know', 'understand', 'remember']
MASK[1] top candidates: ['she', 'he', 'it']
[CLS] the man went [MASK] the store [SEP] penguins [MASK] flightless birds [SEP]
['[UNK]', '[UNK]', '[UNK]', '[UNK]', 'to', '[UNK]', '[UNK]', '[UNK]', '[UNK]', 'are', '[UNK]', '[UNK]', '[UNK]', '[UNK]']
Label: tensor(1)
She needs to [MASK] that [MASK] has only ten minutes.
MASK[0] top candidates: ['are', 'know', 'be']
MASK[1] top candidates: ['are', 'is', 'she']
Basically, I use "She needs to [MASK] that [MASK] has only ten minutes." as a test sentence to test the unmasking.
As you can see, the base model works perfectly when I test it at the beginning.
However, after I feed the model a single training batch,
[CLS] the man went [MASK] the store [SEP] penguins [MASK] flightless birds [SEP]
the updated model no longer makes sense: it unmasks "She needs to [MASK] that [MASK] has only ten minutes."
into "She needs to [are] that [are] has only ten minutes."
I can think of two possibilities for why this happens (a small diagnostic sketch follows the list):
- the BERT model is extremely sensitive to the training batch size, and a tiny batch causes an unacceptable bias;
- there is a bug in my code.
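To help tell these apart, here is a small diagnostic sketch that snapshots the model before the update and measures how far the single optimizer step moves the word-embedding matrix (the weight_drift helper is something I wrote just for illustration, not part of the repro above):

import copy
import torch

def weight_drift(model_before, model_after):
    # L2 distance between the word-embedding matrices of two
    # BertForPreTraining models, i.e. how far one optimizer step
    # moved the embeddings.
    w0 = model_before.bert.embeddings.word_embeddings.weight
    w1 = model_after.bert.embeddings.word_embeddings.weight
    return torch.norm(w1 - w0).item()

# usage: snapshot right before the single training iteration
#   model_ref = copy.deepcopy(model)
#   ... run the one optimizer.step() above ...
#   print('embedding drift:', weight_drift(model_ref, model))

If the drift turns out to be very large, that would point to the optimizer settings rather than to a bug in the input construction.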
So, any ideas?
Tags: pretraining, transformer, deep-learning
Category: Data Science