Logic behind pre-trained weights and transfer learning

I am not sure about the logic behind pre-trained weights: how do they actually make sense and transfer to a new problem?

To be more specific: in an object detection network, for example, how would a model's weights that were trained on, say, the COCO dataset with 80 categories translate to my new problem, which has only 2 categories (classes)? How does this make sense? What kind of meaningful features could even be transferred from the pre-trained model to my new problem, given that the number of categories (classes) has changed and I am trying to detect completely different types of objects than the previous model? Why do we do transfer learning at all?

Topic pretraining object-detection transfer-learning neural-network classification

Category Data Science


In the case of a CNN, you are correct in the sense that you cannot reuse the final layer's weights if the number of categories is different. But you CAN reuse the weights in the initial layers. These recognise the lower-level features in the image (edges, textures, simple shapes). There is no need to train everything from scratch: you only have to train the upper layers specific to the categorisation you want to do. So transfer learning helps with the lower layers in the case of a CNN.

Let me give one more example. Say there is a model trained to recognise 100 categories of objects including animals, plants, buildings, vehicles, etc., and you want to categorise only vehicles. You could use the lower layers of that model as-is and then add a couple of layers at the top which learn the exact vehicle categories, as in the sketch below.
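Here is a minimal sketch of that idea in PyTorch/torchvision. The ResNet-18 backbone pre-trained on ImageNet and the 5 vehicle classes are just illustrative choices, not something from the question:

```python
import torch
import torch.nn as nn
from torchvision import models

NUM_VEHICLE_CLASSES = 5  # hypothetical: car, truck, bus, motorbike, bicycle

# Load a backbone whose lower layers already recognise edges, textures, shapes, ...
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze all pre-trained layers so their weights are reused as-is
for param in model.parameters():
    param.requires_grad = False

# Replace the final (1000-class) layer with a new head for our categories;
# its weights are freshly initialised and are the only ones that will train
model.fc = nn.Linear(model.fc.in_features, NUM_VEHICLE_CLASSES)

# Only the new head's parameters are passed to the optimiser
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
```

Only the small new head needs to be trained on your vehicle images; everything below it is transferred unchanged.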

In the case of word embeddings, the concept of transfer learning is more intuitive and may help you understand the idea better. You can train a model on the whole of Wikipedia and have it learn embeddings, i.e. representations of words. These representations can then be used directly in other models, e.g. for tweet sentiment classification. In fact, BERT is one such pre-trained model, and it is a classic case of transfer learning: you no longer have to train on terabytes of text to learn the language representations; you can use them off the shelf.
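For example, with the Hugging Face transformers library you can pull pre-trained BERT representations and feed them to whatever classifier you like. This is only a sketch; the bert-base-uncased checkpoint and the example tweet are illustrative:

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Download a BERT model that was pre-trained on large text corpora
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")

# Encode a tweet and get its contextual embeddings without any training
inputs = tokenizer("I love this phone, best purchase ever!", return_tensors="pt")
with torch.no_grad():
    outputs = bert(**inputs)

# One vector per token; pool them (e.g. mean) to get a sentence representation
sentence_embedding = outputs.last_hidden_state.mean(dim=1)  # shape: (1, 768)

# This vector can now be fed to a small classifier trained on your sentiment labels
```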


If we classify new objects using transfer learning:

  1. Delete the top Dense layer of the pre-trained neural network.
  2. Suppose you have to classify 5 classes; your new final Dense layer will then contain 5 nodes.
  3. You may also add a few Dense layers before this new 5-node Dense layer, so that the model can adapt to your new data.
  4. Freeze all the layers prior to the Dense layers you just added.
  5. Train this new model on your dataset (a code sketch follows the explanation below).

The logic behind these steps is:

  1. A CNN learns features hierarchically: the initial layers learn basic features of an image, i.e. lines, corners, simple shapes, etc. So the frozen layers of your model have already learned these basic features, which are common to your dataset and to the dataset the model was previously trained on.
  2. The few Dense layers you added will learn the high-level features specific to your dataset when you train the model on your own data.
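Put together, the steps above look roughly like this in Keras. This is a sketch, assuming a MobileNetV2 base pre-trained on ImageNet, 5 hypothetical target classes, and illustrative layer sizes:

```python
import tensorflow as tf

NUM_CLASSES = 5  # hypothetical number of new classes

# Step 1: load a pre-trained network without its top (classification) layer
base = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet", pooling="avg"
)

# Step 4: freeze everything that came before the layers we are about to add
base.trainable = False

# Steps 2-3: add a few Dense layers and a final 5-node output layer
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])

# Step 5: train only the new layers on your own dataset
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_images, train_labels, epochs=5)  # your data goes here
```

Because the base is frozen, only the small new head is updated during training, which is why transfer learning works even with a modest amount of new data.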
