Weighting the loss function based on previously seen true positive rates
Similar to class imbalance, there is always something I would call learnability imbalance in multi-class classification. By that I mean: even when the classes are evenly distributed in the dataset, some classes are classified more easily by the model than others. For example, in a CNN that classifies dogs, cats and cars, dog and cat will most likely have a lower true positive rate than car, because cats and dogs look more similar to each other. However, I would like the recall/precision/F1-scores/true positive rates to be evenly distributed across all classes.
Here is the solution I tried:
- train the model normally
- evaluate the model and compute the true positive rate of each class
- use a function of these true positive rates to create class weights for CrossEntropyLoss
- train again using the weighted loss function
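Step 2 can be sketched with NumPy, assuming the true labels and the model's predicted labels for the evaluation set are available as integer class indices (in PyTorch, for example, the `argmax` of the logits). `per_class_true_positive_rates` is a hypothetical helper, not part of any library; the true positive rate of a class is the same as its recall.

```python
import numpy as np


def per_class_true_positive_rates(y_true, y_pred, num_classes):
    """True positive rate (recall) of each class: correct predictions of class c
    divided by the number of samples whose true label is c."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    # confusion[t, p] counts samples with true label t that were predicted as p
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(confusion, (y_true, y_pred), 1)
    support = confusion.sum(axis=1)
    # guard against division by zero for classes absent from y_true
    return np.diag(confusion) / np.maximum(support, 1)


# toy example with 3 classes (0 = car, 1 = dog, 2 = cat)
y_true = [0, 0, 1, 1, 1, 2, 2, 2, 2, 0]
y_pred = [0, 0, 2, 1, 1, 1, 2, 2, 2, 0]
print(per_class_true_positive_rates(y_true, y_pred, 3))  # car: 3/3, dog: 2/3, cat: 3/4
```

If scikit-learn is available, `sklearn.metrics.recall_score(y_true, y_pred, average=None)` returns the same per-class values.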
I implemented this for a CIFAR-10 classifier. To calculate the weights I tried many variants, for example:
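```python
import numpy as np

# true positive rates of the first model for: car, dog, cat
true_positive_rates = [0.9560878243512974, 0.7365384615384616, 0.7807933194154488]
average = np.mean(true_positive_rates)
# classes below the average rate get a weight > 1, classes above it a weight < 1
loss_weights = [pow(1 - ((score - average) / score), 4) for score in true_positive_rates]
print(loss_weights)
# [0.743631046434835, 1.2530321887900169, 1.1150155156144919]
# note: these printed values equal (average / score) ** 2, i.e. they correspond to an exponent of 2, not 4
```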
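The weighting formula in the code above can be simplified, which makes it clearer what the exponent does. Writing $s_c$ for the true positive rate of class $c$, $\bar{s}$ for the mean over all classes and $p$ for the exponent (4 in the code):

```latex
w_c = \left(1 - \frac{s_c - \bar{s}}{s_c}\right)^{p}
    = \left(\frac{s_c - (s_c - \bar{s})}{s_c}\right)^{p}
    = \left(\frac{\bar{s}}{s_c}\right)^{p}
```

So each weight is the ratio of the mean true positive rate to the class's own rate, raised to the power $p$: classes below the mean get $w_c > 1$, classes above it get $w_c < 1$, and a larger $p$ only makes these ratios more extreme.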
I tried different exponents to create more or less extreme weights (in this example 4). Since I'm using PyTorch, the weighting looks like this:
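```python
import torch
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(loss_weights, dtype=torch.float32))  # weight must be a float tensor, not a list
```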
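To see what the `weight` argument actually changes, here is a NumPy sketch of the class-weighted cross-entropy with mean reduction (the default), following the formula given in the PyTorch documentation for `nn.CrossEntropyLoss`. The function name `weighted_cross_entropy` and the toy logits are made up for illustration; this is not a PyTorch API. Each sample's negative log-likelihood is multiplied by the weight of its true class, and the sum is then divided by the sum of those weights rather than by the batch size.

```python
import numpy as np


def weighted_cross_entropy(logits, targets, class_weights):
    """Class-weighted cross-entropy with mean reduction, following the formula in the
    PyTorch docs for nn.CrossEntropyLoss(weight=..., reduction='mean')."""
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)
    class_weights = np.asarray(class_weights, dtype=np.float64)
    # numerically stable log-softmax over the class dimension
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # negative log-likelihood of each sample's true class
    nll = -log_probs[np.arange(len(targets)), targets]
    # each sample is weighted by the weight of its true class
    sample_weights = class_weights[targets]
    # 'mean' reduction divides by the sum of the sample weights, not by the batch size
    return float((sample_weights * nll).sum() / sample_weights.sum())


# toy batch: 4 samples, 3 classes (0 = car, 1 = dog, 2 = cat)
logits = [[2.0, 0.5, 0.1],
          [0.2, 1.5, 1.4],
          [0.1, 1.2, 1.3],
          [1.0, 0.9, 0.2]]
targets = [0, 1, 2, 1]
uniform = weighted_cross_entropy(logits, targets, [1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, targets, [0.74, 1.25, 1.12])
scaled = weighted_cross_entropy(logits, targets, [7.4, 12.5, 11.2])  # same ratios, 10x larger
print(uniform, weighted, scaled)
```

With uniform weights the result is the ordinary mean cross-entropy. Because the weighted sum is divided by the sum of the weights, `weighted` and `scaled` come out the same up to floating-point rounding: only the ratios between the class weights matter, not their absolute size.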
My result: even though some classes were, as I expected, more difficult to classify than others, giving them higher weights made almost no difference.
Here is a figure (the blue bars are the true positive rates of the first, normally trained model; the orange bars are those of the second model, which was trained with the loss weighted by the first model's true positive rates):
As you can see, there is no improvement (the desired result would have been bars of equal height for all classes). Is there an obvious reason why this doesn't work?