Calibrating probability thresholds for multiclass classification

I have built a network to classify three classes. The network consists of a CNN followed by two fully connected layers. The CNN consists of convolutional layers followed by batch normalization, a ReLU activation, max pooling and dropout. The three classes are imbalanced (as can be seen in the confusion matrix below). I have optimized the parameters of the network to maximize AUC.

I'm calculating the AUC using macro- and micro-averaging. As can be seen in the ROC plot, the AUC is not that bad. The confusion matrix, on the other hand, looks pretty bad: the first (low) class in particular is poorly predicted, and the network tends to predict the majority class. The network outputs a probability for each class, and to build the confusion matrix I simply take the class with the maximum probability.
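To make the two evaluation steps above explicit, here is a minimal numpy-only sketch (scikit-learn's roc_auc_score and confusion_matrix are the usual tools; this version only spells out the definitions). It assumes integer labels 0..k-1 and an (n_samples, k) matrix of predicted probabilities; the function names are hypothetical. Macro-AUC averages the one-vs-rest AUC of each class, micro-AUC pools every (sample, class) pair into one binary problem, and the confusion matrix uses the argmax label.

```python
import numpy as np

def binary_auc(y_true, score):
    """ROC AUC via the Mann-Whitney statistic: the probability that a random
    positive scores higher than a random negative (ties count 0.5)."""
    y_true = np.asarray(y_true, dtype=bool)
    score = np.asarray(score, dtype=float)
    pos, neg = score[y_true], score[~y_true]
    # Compare every positive with every negative: O(n_pos * n_neg) memory.
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

def macro_micro_auc(y, proba):
    """One-vs-rest AUCs for integer labels y and an (n_samples, n_classes) matrix."""
    n_classes = proba.shape[1]
    onehot = np.eye(n_classes, dtype=bool)[y]  # binarized labels
    macro = np.mean([binary_auc(onehot[:, c], proba[:, c]) for c in range(n_classes)])
    micro = binary_auc(onehot.ravel(), proba.ravel())  # pool all (sample, class) pairs
    return float(macro), micro

def argmax_confusion(y, proba):
    """Confusion matrix (rows = true class, columns = predicted) for argmax labels."""
    n_classes = proba.shape[1]
    pred = proba.argmax(axis=1)
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y, pred), 1)  # count each (true, predicted) pair
    return cm

# Made-up example: four samples, three classes.
y = np.array([0, 1, 2, 2])
proba = np.array([[0.5, 0.3, 0.2], [0.2, 0.5, 0.3], [0.1, 0.3, 0.6], [0.3, 0.4, 0.3]])
print(macro_micro_auc(y, proba))
print(argmax_confusion(y, proba))
```

A high AUC and a poor argmax confusion matrix are not contradictory: AUC only measures ranking within each one-vs-rest problem, while the confusion matrix depends on the single decision rule applied on top.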

I have tried using balanced class weights while training the network (in the fit method of Keras). This made the network predict the minority class(es) more often, but the AUC decreased.
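For reference, "balanced" weights are usually computed as n_samples / (n_classes * n_c), which is the formula scikit-learn uses for class_weight="balanced". A minimal numpy sketch, assuming integer labels 0..k-1 (the helper name and the labels are hypothetical):

```python
import numpy as np

def balanced_class_weights(y):
    """Return {class_index: weight} with weight_c = n_samples / (n_classes * n_c).

    Rarer classes get proportionally larger weights; a class that does not
    occur in y gets no entry.
    """
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

# Hypothetical labels where class 0 ("low") is the minority.
y_train = np.array([0, 1, 1, 2, 2, 2])
class_weight = balanced_class_weights(y_train)
print(class_weight)  # {0: 2.0, 1: 1.0, 2: 0.6666666666666666}
# In Keras the dict is passed as model.fit(x, y, class_weight=class_weight).
```

Reweighting changes the loss the network minimizes, so it shifts the predicted probabilities themselves rather than only the decision rule.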

Is there a way to infer probability thresholds from the ROC plot? For two classes, I think the optimal probability threshold can be read off the ROC plot by maximizing TPR - FPR, but here I have three classes. Or is there another method?
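In the two-class case, the criterion max(TPR - FPR) is known as Youden's J statistic. A minimal numpy sketch that tries every distinct score as a threshold (the function name is hypothetical; with scikit-learn the same thing is roughly fpr, tpr, thr = roc_curve(y, score) followed by thr[np.argmax(tpr - fpr)]):

```python
import numpy as np

def youden_threshold(y_true, score):
    """Return (threshold, J) maximizing J = TPR - FPR for a binary problem,
    using the rule 'predict positive if score >= threshold'."""
    y_true = np.asarray(y_true, dtype=bool)
    score = np.asarray(score, dtype=float)
    thresholds = np.unique(score)                    # candidate thresholds, sorted
    pred = score[None, :] >= thresholds[:, None]     # one row of decisions per threshold
    tpr = (pred & y_true).sum(axis=1) / y_true.sum()
    fpr = (pred & ~y_true).sum(axis=1) / (~y_true).sum()
    j = tpr - fpr
    best = int(np.argmax(j))                         # first maximum on ties
    return float(thresholds[best]), float(j[best])

print(youden_threshold([0, 0, 1, 1], [0.1, 0.4, 0.6, 0.8]))  # (0.6, 1.0)
```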



Disclaimer: This answer describes my thoughts about this problem; I don't offer any guarantee about their validity, so use them at your own risk ;)

There are two distinct parts in this problem:

  • Finding an optimal threshold over three classes
  • Improving performance with respect to the minority class

First part: I might be wrong, but as far as I know there's no way to select a single probability threshold when there are three classes. And if there were a way, it probably wouldn't come from the ROC plot: the curves for the different classes are independent, so picking a point on each curve would correspond to a different threshold for each class, and I don't see how these could be combined consistently. The only approach I know is the one you used: assign the class with the maximum probability.
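A small illustration of why per-class thresholds are hard to combine: if each class c gets its own one-vs-rest threshold (for instance one point picked on each ROC curve), a sample can end up above the threshold for no class or for several classes, and some extra rule (such as falling back to the maximum probability) is still needed. The thresholds and probabilities below are made up for illustration, and the helper name is hypothetical.

```python
import numpy as np

def positives_per_sample(proba, thresholds):
    """Count, for each sample, how many classes it is 'positive' for when
    class c is called positive whenever proba[:, c] >= thresholds[c]."""
    proba = np.asarray(proba, dtype=float)
    return (proba >= np.asarray(thresholds, dtype=float)[None, :]).sum(axis=1)

# Hypothetical per-class thresholds, one per ROC curve.
thresholds = np.array([0.35, 0.40, 0.70])
proba = np.array([
    [0.50, 0.30, 0.20],  # positive for class 0 only
    [0.20, 0.20, 0.60],  # positive for no class
    [0.40, 0.45, 0.15],  # positive for classes 0 and 1
])
print(positives_per_sample(proba, thresholds))  # [1 0 2]
```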

Second part: if you want to force the model to take all the classes into account, you could also try optimizing for macro-AUC, but that would probably lead to the same problem of decreasing micro-AUC, since there would be more errors where true neutral or high instances are predicted as low.

The way I see it, the three-way model doesn't work well. The imbalance of the minority class is not that severe, as it's only 2 or 3 times smaller than the other classes, so the fact that the model almost completely dismisses this class is a bit strange. I also see that instances of the true class low are predicted as high almost as often as neutral, even though I would expect the vast majority of errors on class low to be predictions of neutral.
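Both observations can be read more directly from a row-normalized confusion matrix (each row divided by the number of true instances of that class) and from the row totals. A minimal sketch; the matrix below is hypothetical and is NOT the one from the question, and the helper names are made up:

```python
import numpy as np

def row_normalize(cm):
    """Divide each row (true class) by its total, so row i shows how the
    instances of true class i are distributed over the predicted classes."""
    cm = np.asarray(cm, dtype=float)
    return cm / cm.sum(axis=1, keepdims=True)

def imbalance_ratios(cm):
    """Size of the largest class divided by the size of each class."""
    counts = np.asarray(cm).sum(axis=1)  # rows = true class
    return counts.max() / counts

# Hypothetical matrix for classes (low, neutral, high).
cm = np.array([
    [2, 4, 4],
    [1, 8, 1],
    [0, 5, 15],
])
print(row_normalize(cm)[0])  # where true "low" instances went: [0.2 0.4 0.4]
print(imbalance_ratios(cm))  # [2. 2. 1.]
```

In a row like the first one, errors on "low" split evenly between "neutral" and "high", which is the pattern that suggests the model is not exploiting the ordering of the classes.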

So my vague intuition is that the system could perhaps be designed in a way that spares the model the three-way problem. I can think of two options along these lines:

  • A two-step system where the first model classifies low vs. neutral+high and the second one neutral vs. high (or first low+neutral vs. high, then low vs. neutral). This way each model performs binary classification, so you have more control over the threshold at each step. This is not normally recommended in classification, but here I assume that the classes are not truly categorical.
  • Pushing the same idea further: there seems to be an order between the classes, low < neutral < high, so it might be possible to treat the task as a regression problem. This might help the model avoid these "big" errors between low and high. Importantly, you would also have flexibility with the predicted values: there would be two thresholds to determine, and these could be tuned to optimize any appropriate evaluation measure.
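For the first option, combining the two binary models at prediction time could look like the sketch below. It assumes model 1 outputs P(low) (low vs. neutral+high) and model 2, trained only on neutral and high, outputs P(high); the names, the class encoding and the probabilities are hypothetical. Each threshold can then be chosen separately, for example from that step's own ROC curve.

```python
import numpy as np

# Hypothetical class encoding, chosen for illustration.
LOW, NEUTRAL, HIGH = 0, 1, 2

def two_step_predict(p_low, p_high, t_low=0.5, t_high=0.5):
    """Three-class decision from two binary models applied in sequence.

    p_low:  P(low) from model 1, trained on low vs. neutral+high.
    p_high: P(high) from model 2, trained on neutral vs. high only; it is
            consulted only for samples that model 1 did not call low.
    """
    p_low = np.asarray(p_low, dtype=float)
    p_high = np.asarray(p_high, dtype=float)
    return np.where(p_low >= t_low, LOW, np.where(p_high >= t_high, HIGH, NEUTRAL))

# Made-up probabilities for three samples.
print(two_step_predict([0.7, 0.2, 0.1], [0.9, 0.3, 0.8], t_low=0.4, t_high=0.5))  # [0 1 2]
```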
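For the regression option, a sketch of the threshold search, assuming the network has been retrained with a single continuous output and the encoding low = 0, neutral = 1, high = 2 (both assumptions made here for illustration). The two thresholds are searched exhaustively on a validation set; balanced accuracy (mean per-class recall) is used only as an example metric, and any function with the same signature can be passed instead. All names are hypothetical.

```python
import numpy as np

def to_classes(pred, t1, t2):
    """Map continuous predictions to 0 (low), 1 (neutral), 2 (high), assuming t1 < t2:
    pred < t1 -> 0, t1 <= pred < t2 -> 1, pred >= t2 -> 2."""
    return np.digitize(pred, [t1, t2])

def balanced_accuracy(y, y_pred, n_classes=3):
    """Mean per-class recall; assumes every class occurs in y."""
    return float(np.mean([np.mean(y_pred[y == c] == c) for c in range(n_classes)]))

def fit_thresholds(y, pred, metric=balanced_accuracy):
    """Try every pair t1 < t2 of midpoints between sorted unique predictions and
    return (best_score, t1, t2). This is O(m^2) metric calls for m unique values,
    so a coarser grid may be needed on a large validation set."""
    y = np.asarray(y)
    pred = np.asarray(pred, dtype=float)
    values = np.unique(pred)
    candidates = (values[:-1] + values[1:]) / 2
    best = (-np.inf, None, None)
    for i, t1 in enumerate(candidates):
        for t2 in candidates[i + 1:]:
            score = metric(y, to_classes(pred, t1, t2))
            if score > best[0]:
                best = (score, float(t1), float(t2))
    return best

# Made-up validation labels and regression outputs.
y_val = np.array([0, 0, 1, 1, 2, 2])
pred_val = np.array([0.1, 0.2, 0.9, 1.1, 1.9, 2.2])
print(fit_thresholds(y_val, pred_val))  # score 1.0, t1 about 0.55, t2 about 1.5
```

Because the metric is evaluated directly on the three-class decisions, the thresholds can be pushed toward whichever trade-off between the minority class and the others matters most.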
