How to fine-tune GPT-J with a small dataset

Firstly, thank you so much for looking at this post. I could really use some help.

I have followed this guide as closely as possible: https://github.com/kingoflolz/mesh-transformer-jax

I'm trying to fine-tune GPT-J with a small dataset of ~500 lines:

You are important to me. <|endoftext|>
I love spending time with you. <|endoftext|>
You make me smile. <|endoftext|>
I feel so lucky to be your friend. <|endoftext|>
You can always talk to me, even if it’s about something that makes you nervous or scared or sad. <|endoftext|>
etc...
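In case the formatting is the culprit, here is roughly how I assemble the raw training file (a minimal sketch; the file name and the placement of the delimiter are my reading of the guide, not something it spells out):

# Build the plain-text input for create_finetune_tfrecords.py.
# "lines" stands in for my ~500 sentences; appending <|endoftext|>
# after each line is my reading of the guide.
lines = [
    "You are important to me.",
    "I love spending time with you.",
    # ... ~500 more lines ...
]
with open("james_bond_1.txt", "w", encoding="utf-8") as f:
    for line in lines:
        f.write(line + " <|endoftext|>\n")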

Running the create_finetune_tfrecords.py script (from the repo mentioned above) on my data outputs a file with "2" in its name. I understand that to mean my data was packed into 2 sequences of 2048 tokens each.
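For reference, I converted the text file with something like the following (positional arguments as I understood the script; check its --help for the exact usage):

python3 create_finetune_tfrecords.py james_bond_1.txt james_bond_1

The output was a .tfrecords file with the sequence count in its name (something like james_bond_1_2.tfrecords), and my james_bond_1.train.index is a plain-text file listing the bucket path of that .tfrecords file.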

I could really use some advice on the .json config file. What values do you recommend for a dataset this small? This is the best I came up with while following the guide:

{
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
  "gradient_accumulation_steps": 2,

  "warmup_steps": 1,
  "anneal_steps": 9,
  "lr": 1.2e-4,
  "end_lr": 1.2e-5,
  "weight_decay": 0.1,
  "total_steps": 10,

  "tpu_size": 8,

  "bucket": "chat-app-tpu-bucket-europe",
  "model_dir": "finetune_dir",

  "train_set": "james_bond_1.train.index",
  "val_set": {},

  "eval_harness_tasks": [
  ],

  "val_batches": 2,
  "val_every": 400000,
  "ckpt_every": 1,
  "keep_every": 1,

  "name": "GPT3_6B_pile_rotary",
  "wandb_project": "mesh-transformer-jax",
  "comment": ""
}
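If my arithmetic is right (my own reckoning, not from the guide): tpu_size 8 with cores_per_replica 8 gives 1 replica, so each optimizer step consumes per_replica_batch × replicas × gradient_accumulation_steps = 1 × 1 × 2 = 2 sequences. That is my entire 2-sequence dataset, so total_steps: 10 would be roughly 10 epochs over it.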

Problem is... when I test the fine-tuned model, I get responses that make no sense.

Very much looking forward to hearing from you! :)
