A single batch of training on a Hugging Face BERT model "ruins" the model

For some reason, I need to do further (second-stage) pre-training on a Hugging Face BERT model, and I find that my training outcome is very bad.

After debugging for hours, I surprisingly find that training on even one single batch after loading the base model causes the model to predict very poor choices when I ask it to unmask some test sentences. I have boiled my code down to this minimal reproducible version:

import torch
from transformers import AdamW, BertTokenizer
from transformers import BertForPreTraining

MSK_CODE = 103 # [MASK] token id in the bert-base-uncased vocabulary
CE_IGN_IDX = -100 # CrossEntropyLoss ignore index value

def sanity_check(tokenizer, inputs):
    print(tokenizer.decode(inputs['input_ids'][0]))
    print(tokenizer.convert_ids_to_tokens(
        inputs['labels'][0]
    ))
    print('Label:', inputs['next_sentence_label'][0])

def test(tokenizer, model, topk=3):
    test_data = 'She needs to [MASK] that [MASK] has only ten minutes.'
    print('\n \033[92m', test_data, '\033[0m')
    test_inputs = tokenizer([test_data],
                       padding=True, truncation=True, return_tensors='pt')
    def classifier_hook(module, inputs, outputs):
        unmask_scores, seq_rel_scores = outputs
        token_ids = test_inputs['input_ids'][0]
        masked_idx = (
            token_ids == torch.tensor([MSK_CODE])
        )
        scores = unmask_scores[0][masked_idx]
        cands = torch.argsort(scores, dim=1, descending=True)
        for i, mask_cands in enumerate(cands):
            top_cands = mask_cands[:topk].detach().cpu()
            print(f'MASK[{i}] top candidates:', end=' ')
            print(tokenizer.convert_ids_to_tokens(top_cands))
    classifier = model.cls
    hook = classifier.register_forward_hook(classifier_hook)
    model.eval()
    model(**test_inputs)
    hook.remove()
    print()

# load model
model = BertForPreTraining.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# first test
test(tokenizer, model)

# our single-iteration inputs
#   [CLS]  1   2   3   4   5   6   [SEP]  8      9    10  11    12  [SEP]
pair = [['the man went to the store', 'penguins are flightless birds']]
relation_label = 1 # 1 = sentence B is a random sentence, not the actual next sentence

# construct inputs
inputs = tokenizer(pair, padding=True, truncation=True, return_tensors='pt')
inputs['next_sentence_label'] = torch.tensor([relation_label])
mask_labels = torch.full(inputs['input_ids'].shape, fill_value=CE_IGN_IDX)
inputs['labels'] = mask_labels

# mask two words
inputs['input_ids'][0][4] = MSK_CODE
inputs['input_ids'][0][9] = MSK_CODE
mask_labels[0][4] = tokenizer.convert_tokens_to_ids('to')
mask_labels[0][9] = tokenizer.convert_tokens_to_ids('are')

# train for one single iteration
sanity_check(tokenizer, inputs)
model.train()
optimizer.zero_grad()
outputs = model(**inputs)
loss = outputs.loss
loss.backward(loss)
optimizer.step()

# second test
test(tokenizer, model)

output:

Some weights of BertForPreTraining were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

  She needs to [MASK] that [MASK] has only ten minutes. 
MASK[0] top candidates: ['know', 'understand', 'remember']
MASK[1] top candidates: ['she', 'he', 'it']

[CLS] the man went [MASK] the store [SEP] penguins [MASK] flightless birds [SEP]
['[UNK]', '[UNK]', '[UNK]', '[UNK]', 'to', '[UNK]', '[UNK]', '[UNK]', '[UNK]', 'are', '[UNK]', '[UNK]', '[UNK]', '[UNK]']
Label: tensor(1)

  She needs to [MASK] that [MASK] has only ten minutes. 
MASK[0] top candidates: ['are', 'know', 'be']
MASK[1] top candidates: ['are', 'is', 'she']

Basically, I use "She needs to [MASK] that [MASK] has only ten minutes." as a test sentence for unmasking. As you can see, when I first test the base model, it works perfectly. However, after I feed the pre-training model a single pair as a training batch:

[CLS] the man went [MASK] the store [SEP] penguins [MASK] flightless birds [SEP]

The updated model no longer makes sense: it unmasks "She needs to [MASK] that [MASK] has only ten minutes." into "She needs to [are] that [are] has only ten minutes."

I can think of two possibilities for why this happens (a rough check to help tell them apart is sketched after the list):

  1. The BERT model is extremely sensitive to training batch size, and a single small batch introduces unacceptable bias.
  2. There is a bug in my code.
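
To help tell these two apart, here is a rough diagnostic sketch (it assumes the exact same setup and hyperparameters as the script above and is not part of my original run): repeat the single-batch step and report which parameter tensors move the most, together with their original norms. If some layers shift by an amount comparable to their own norm after one step, the update itself looks like the culprit rather than any batch-size bias.

import torch
from transformers import AdamW, BertForPreTraining, BertTokenizer

MSK_CODE = 103    # [MASK] token id in the bert-base-uncased vocabulary
CE_IGN_IDX = -100 # CrossEntropyLoss ignore index value

model = BertForPreTraining.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# same single-pair batch as in the script above
pair = [['the man went to the store', 'penguins are flightless birds']]
inputs = tokenizer(pair, padding=True, truncation=True, return_tensors='pt')
inputs['next_sentence_label'] = torch.tensor([1])
mask_labels = torch.full(inputs['input_ids'].shape, fill_value=CE_IGN_IDX)
inputs['input_ids'][0][4] = MSK_CODE
inputs['input_ids'][0][9] = MSK_CODE
mask_labels[0][4] = tokenizer.convert_tokens_to_ids('to')
mask_labels[0][9] = tokenizer.convert_tokens_to_ids('are')
inputs['labels'] = mask_labels

# snapshot all weights before the single training step
before = {name: p.detach().clone() for name, p in model.named_parameters()}

model.train()
optimizer.zero_grad()
loss = model(**inputs).loss
loss.backward(loss)   # same backward call as in the script above
optimizer.step()

# report the five parameter tensors that moved the most in this one step
deltas = {
    name: (p.detach() - before[name]).norm().item()
    for name, p in model.named_parameters()
}
for name, delta in sorted(deltas.items(), key=lambda kv: -kv[1])[:5]:
    print(f'{name}: ||delta|| = {delta:.4f} '
          f'(||w_before|| = {before[name].norm().item():.4f})')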

So, any ideas?

Topic: pretraining, transformer, deep-learning

Category: Data Science
