Transfer learning with many small datasets
Context
I am working on an NLP model that classifies documents into one of N categories. I have document data from a number of different customers. The document topics are similar across customers, but each customer groups them into different categories.
For simplicity, assume that the documents can contain six different topics: A, B, ..., F. Each customer groups these topics into classes differently, so the number of classes N mentioned above is customer-specific and the mix of topics per class is (somewhat) different: customer 1 has three classes (ABC, D, EF), customer 2 has four (A, BC, DE, F), customer 3 has four (AB, CD, E, F), and so on.
The goal is to use the data from the current customers to develop a model that can be transferred to, and retrained on, a new customer (with a different mix of topics A, B, ..., F and potentially only a small amount of data).
What I have done
I have been searching here and in papers for strategies for transfer learning with multiple smaller datasets that have different output categories, but I have not been able to find anything. I have only found papers that deal with larger datasets.
What I am considering
My current thought is to use a deep learning model (Keras/TensorFlow) with this layout: document (from customer X) -> preprocessing (shared among customers) -> feature extraction (shared among customers) -> classification (specific to customer X)
With this approach, I would swap in the customer-specific classification layer during training, depending on which customer the current batch comes from. This would train the preprocessing and feature extraction layers on data from all customers, but train each classification layer only on that customer's data.
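To make the layout concrete, here is a minimal Keras sketch of what I have in mind. Everything specific in it is an assumption for illustration only: the vocabulary size, sequence length, layer sizes, the customer names, and the random stand-in data. The documents are assumed to be already tokenized into integer IDs.

```python
import numpy as np
from tensorflow import keras

# Hypothetical sizes and customers, chosen only for illustration.
VOCAB_SIZE = 1000
SEQ_LEN = 50
CLASSES_PER_CUSTOMER = {"customer_1": 3, "customer_2": 4, "customer_3": 4}

# Shared part: embedding + pooling + dense feature extractor.
tokens = keras.Input(shape=(SEQ_LEN,), dtype="int32", name="tokens")
x = keras.layers.Embedding(VOCAB_SIZE, 32, name="embedding")(tokens)
x = keras.layers.GlobalAveragePooling1D(name="pooling")(x)
features = keras.layers.Dense(64, activation="relu", name="features")(x)

# One model per customer. All of them reuse the same shared layers,
# and each adds its own classification head with its own N.
models = {}
for name, n_classes in CLASSES_PER_CUSTOMER.items():
    logits = keras.layers.Dense(n_classes, name=f"head_{name}")(features)
    model = keras.Model(tokens, logits, name=f"model_{name}")
    model.compile(
        optimizer="adam",
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    models[name] = model

# "Switching" heads: alternate batches between customers.
# Random data stands in for real tokenized documents and labels.
rng = np.random.default_rng(0)
for step in range(3):
    for name, n_classes in CLASSES_PER_CUSTOMER.items():
        x_batch = rng.integers(0, VOCAB_SIZE, size=(8, SEQ_LEN))
        y_batch = rng.integers(0, n_classes, size=(8,))
        models[name].train_on_batch(x_batch, y_batch)
```

Each `train_on_batch` call updates the shared layers and only that customer's head. For a new customer, I imagine attaching a fresh `Dense` head to `features` in the same way and training it on the small dataset, possibly with the shared layers frozen at first.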
Questions
Do you have any references to papers or any other material that describes a similar situation?
Is this the best strategy for training this model? More specifically, is it possible to lay out the model so that all customers' classification layers are present at the same time, with the switching between layers handled by a cleverly engineered loss function?
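As a rough sketch of what I mean by handling the switching in the loss: suppose all heads are concatenated into one output vector, and each example carries a customer ID. The loss could then mask out every logit that does not belong to that customer's head before applying softmax cross-entropy. The head sizes, the `masked_loss` name, and the large negative fill value are my own assumptions; I do not know whether this is a standard or recommended approach.

```python
import tensorflow as tf

# Hypothetical head sizes for customers 0, 1, 2 (3 + 4 + 4 = 11 outputs).
HEAD_SIZES = [3, 4, 4]

# Start column of each customer's head in the concatenated output: [0, 3, 7].
OFFSETS = tf.constant([sum(HEAD_SIZES[:i]) for i in range(len(HEAD_SIZES))])

# For every output column, the index of the customer that owns it.
HEAD_OF_COLUMN = tf.constant(
    [head for head, size in enumerate(HEAD_SIZES) for _ in range(size)]
)


def masked_loss(logits, customer_ids, local_labels):
    """Per-example cross-entropy restricted to each example's own head.

    logits:       float tensor, shape (batch, sum(HEAD_SIZES))
    customer_ids: int32 tensor, shape (batch,)
    local_labels: int32 tensor, shape (batch,), label index within the head
    """
    # True where a column belongs to the example's customer.
    mask = tf.equal(HEAD_OF_COLUMN[tf.newaxis, :], customer_ids[:, tf.newaxis])
    # Logits of other customers' heads are replaced by a large negative
    # constant, so they get ~0 probability and receive no gradient.
    masked = tf.where(mask, logits, tf.constant(-1e9, dtype=logits.dtype))
    # Shift the local label to its position in the concatenated output.
    global_labels = tf.gather(OFFSETS, customer_ids) + local_labels
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=global_labels, logits=masked
    )


# Usage with all-zero logits: the loss reduces to log(number of classes).
example_loss = masked_loss(
    tf.zeros((2, 11)),
    tf.constant([0, 1], dtype=tf.int32),
    tf.constant([1, 2], dtype=tf.int32),
)
```

With this formulation a single model holds every head, and mixed-customer batches are possible because the mask is computed per example.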
Topic transfer-learning tensorflow deep-learning
Category Data Science