Calculate confidence score of a neural network prediction

I am using a deep neural network model to make predictions. My problem is a binary classification problem. I want to calculate a confidence score for each prediction. At the moment I use scikit-learn's predict_proba to get the probability of a class and treat it as the model's confidence in that prediction. Is this the right approach? Is there a better way to calculate the confidence score? Please cite relevant literature to support your approach.


One way to estimate how confident we can be in an ANN prediction is to use dropout perturbations. The idea was proposed in the paper "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning". The core idea is to use dropout as a perturbation method: keep dropout active at prediction time, run the same input through the network many times (optionally with varying levels of dropout), and check how the predictions change. Once you have sampled a sufficient number of "distorted predictions", you can estimate something analogous to a confidence interval around the initial model prediction. This technique works for both classifiers and regressors.

You can read an explanation of this approach here.
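
As a rough illustration of the sampling idea described above, here is a minimal NumPy sketch. The tiny network, its random weights (W1, b1, W2, b2), the 0.5 dropout rate and the helper names are all made up for the example; they stand in for a trained model and are not taken from the paper. Each forward pass draws a fresh dropout mask, and the spread of the resulting probabilities is summarised with a percentile interval.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical weights standing in for a trained network: 3 inputs -> 8 hidden -> 1 output.
W1 = rng.normal(size=(3, 8))
b1 = np.zeros(8)
W2 = rng.normal(size=(8, 1))
b2 = np.zeros(1)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def stochastic_forward(x, rate, rng):
    """One forward pass with a fresh random dropout mask on the hidden layer."""
    hidden = np.maximum(0.0, x @ W1 + b1)        # ReLU hidden layer
    keep = rng.random(hidden.shape) >= rate      # True where a unit is kept
    hidden = hidden * keep / (1.0 - rate)        # inverted dropout scaling
    return sigmoid(hidden @ W2 + b2).ravel()     # probability of class 1


def mc_dropout_predict(x, rate=0.5, n_samples=200, rng=rng):
    """Collect many 'distorted predictions' and summarise their spread."""
    samples = np.stack([stochastic_forward(x, rate, rng) for _ in range(n_samples)])
    mean = samples.mean(axis=0)
    lower, upper = np.percentile(samples, [2.5, 97.5], axis=0)
    return mean, lower, upper


x = np.array([[0.2, -1.0, 0.5], [1.5, 0.3, -0.7]])
mean, lower, upper = mc_dropout_predict(x)
for m, lo, hi in zip(mean, lower, upper):
    print(f"p(class 1) ~ {m:.2f}, 95% interval [{lo:.2f}, {hi:.2f}]")
```

A wide interval suggests the prediction is sensitive to the perturbation, which is the signal this approach treats as low confidence.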


EDIT:

To be more precise:

To implement this technique, use Dropout() layers; they can be used in the prediction phase too, not just during training. Note that Keras disables Dropout() at prediction time by default, so the layer has to be called with training=True to stay active. You can train your neural network, then transfer its weights into another ANN with the same architecture plus dropout layers. Something like:

new_model.set_weights(original_model.get_weights())

Once the new ANN with Dropout() has its weights, run it repeatedly on the same inputs, vary its dropout hyperparameters if you wish, collect the predictions and calculate the CIs. I know it's time-consuming, so do it only if you think it's really worth it. Another possible implementation uses a custom layer that applies dropout unconditionally (K here is the Keras backend module), like:

model.add(Lambda(lambda x: K.dropout(x, level=0.5)))
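
Putting the weight-transfer steps above together, a hedged end-to-end sketch could look like the following. It assumes TensorFlow's bundled Keras; the build_model helper, the synthetic data, the layer sizes and the 0.5 dropout rate are invented for the illustration. Keep in mind that the paper's derivation is for networks trained with dropout, so transferring weights from a dropout-free model is the shortcut suggested in this answer rather than the paper's exact setup.

```python
import numpy as np
from tensorflow import keras


def build_model(dropout_rate=0.0):
    """Hypothetical architecture; the dropout layer is added only when requested."""
    layers = [keras.Input(shape=(4,)), keras.layers.Dense(16, activation="relu")]
    if dropout_rate > 0:
        layers.append(keras.layers.Dropout(dropout_rate))
    layers.append(keras.layers.Dense(1, activation="sigmoid"))
    return keras.Sequential(layers)


# Synthetic binary-classification data, for illustration only.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4)).astype("float32")
y = (X[:, 0] + X[:, 1] > 0).astype("float32")

# Train the original network without dropout.
original_model = build_model()
original_model.compile(optimizer="adam", loss="binary_crossentropy")
original_model.fit(X, y, epochs=2, verbose=0)

# Same architecture plus dropout. Dropout layers hold no weights,
# so the weight lists of the two models line up one to one.
new_model = build_model(dropout_rate=0.5)
new_model.set_weights(original_model.get_weights())

# training=True keeps dropout active, so every call uses a new random mask.
samples = np.stack(
    [np.asarray(new_model(X[:3], training=True)) for _ in range(100)]
)  # shape: (100 passes, 3 inputs, 1 output)

mean = samples.mean(axis=0).ravel()
lower, upper = np.percentile(samples, [2.5, 97.5], axis=0)
for m, lo, hi in zip(mean, lower.ravel(), upper.ravel()):
    print(f"p(class 1) ~ {m:.2f}, 95% interval [{lo:.2f}, {hi:.2f}]")
```

Calling new_model.predict(...) instead would run with dropout switched off and return the same value on every pass, which is why the sketch calls the model directly with training=True.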
