Save and Load Simple Transformer Model

I have trained a text classifier using Simple Transformers (https://simpletransformers.ai/), and I am struggling to save and load the model in a Docker container. Please let me know how I can save the trained model and then load it into a different environment smoothly. I trained the text model with these commands:

    model = ClassificationModel(
        'xlmroberta',
        'xlm-roberta-base',
        use_cuda=cuda_available,
        num_labels=78,
        args={
            'learning_rate': 1e-5,
            'num_train_epochs': 1,
            'train_batch_size': 256,
            'eval_batch_size': 1048,
            'n_gpu': 4,
            'reprocess_input_data': True,
            'overwrite_output_dir': True,
        },
    )

    model.train_model(train_df)

I am saving the trained model using the PyTorch function:

    torch.save(model, 'classifier')

But it shows an error about missing files when I try to load the model on a different virtual machine. So, I am looking for the best way to save and load a Simple Transformers model.

Topic: transformer, torch, python

Category: Data Science


In PyTorch, the learnable parameters (i.e. weights and biases) of a torch.nn.Module are contained in the model's parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensors. So, you can save a model in a couple of ways:

  1. The first way is to save only the state_dict, as in torch.save(model.state_dict(), PATH). To load that model on a different machine (or anywhere else), first create an instance of the model class:

    model = TheModelClass(*args, **kwargs)

After creating the instance of the class, load the saved parameters like this:

    model.load_state_dict(torch.load(PATH))

This is the recommended way of saving and loading a model: save only the state_dict, and on the target machine create an instance of the model class first, then load the weights into it before running inference. A minimal end-to-end sketch follows below.
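For concreteness, here is a small self-contained sketch of that workflow. The SmallClassifier class and the file name classifier_state.pt are just illustrative placeholders, not part of your setup:

    import torch
    import torch.nn as nn

    # Placeholder architecture; in practice this is whatever model class you trained.
    class SmallClassifier(nn.Module):
        def __init__(self, in_features=16, num_labels=78):
            super().__init__()
            self.linear = nn.Linear(in_features, num_labels)

        def forward(self, x):
            return self.linear(x)

    # On the training machine: save only the weights.
    model = SmallClassifier()
    torch.save(model.state_dict(), 'classifier_state.pt')

    # On the target machine: rebuild the architecture, then load the weights.
    model = SmallClassifier()
    model.load_state_dict(torch.load('classifier_state.pt', map_location='cpu'))
    model.eval()  # switch to inference mode before predicting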

  2. The other way is to save model.state_dict() with pickle, e.g. pickle.dump(model.state_dict(), open(filename, 'wb')), and load it back with pickle.load(open(filename, 'rb')). However, this isn't the standard way to save and load a model; the first approach is the recommended one (a short sketch follows).
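A short sketch of the pickle variant, reusing the SmallClassifier placeholder from the sketch above (again, only the state_dict is pickled, not the whole model object):

    import pickle

    # Dump the state_dict with pickle (works, but not the standard approach).
    model = SmallClassifier()
    with open('classifier_state.pkl', 'wb') as f:
        pickle.dump(model.state_dict(), f)

    # Load it back and restore the weights into a fresh instance.
    with open('classifier_state.pkl', 'rb') as f:
        state_dict = pickle.load(f)

    model = SmallClassifier()
    model.load_state_dict(state_dict)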

Besides this, if your error still isn't resolved, you can ask in the comment section.


Your model is already stored in the outputs directory with the name pytorch_model.bin. Go and check that out.
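In other words, simpletransformers writes the trained weights plus the config and tokenizer files into that output directory, and a missing-file error on another machine usually means only part of the directory was copied over. A minimal sketch of reloading it in the new environment, assuming the whole outputs directory was copied to /app/outputs inside the container (that path is just an example):

    from simpletransformers.classification import ClassificationModel

    # Point ClassificationModel at the directory that contains pytorch_model.bin
    # and the accompanying config/tokenizer files.
    model = ClassificationModel(
        'xlmroberta',
        '/app/outputs',   # assumed path inside the container; adjust as needed
        use_cuda=False,   # set True if the target environment has a GPU
        num_labels=78,
    )

    predictions, raw_outputs = model.predict(['example text to classify'])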
