How to balance sensitivity (Sn) and specificity (Sp) of an artificial neural network model?

I have been working on a binary classification problem on protein sequences, using a feed-forward neural network with two hidden layers. The training and validation accuracy/loss curves show that the model trained well, without overfitting or underfitting.

When testing on an independent dataset, I get the following results:

Accuracy: 0.7672583826429981
MCC: 0.5401163645598229
Sensitivity: 0.8379446640316206
Specificity: 0.6968503937007874
Confusion matrix:
[[177  77]
 [ 41 212]]
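As a quick check, the reported metrics follow directly from the confusion matrix. This sketch assumes the common scikit-learn layout (rows are true classes, columns are predicted classes, negative class first), i.e. TN = 177, FP = 77, FN = 41, TP = 212:

```python
import math

# Confusion matrix from the question, assuming the [[TN, FP], [FN, TP]] layout
tn, fp = 177, 77
fn, tp = 41, 212

sensitivity = tp / (tp + fn)   # true positive rate (recall of the positive class)
specificity = tn / (tn + fp)   # true negative rate
accuracy = (tp + tn) / (tp + tn + fp + fn)
# Matthews correlation coefficient
mcc = (tp * tn - fp * fn) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)

print(f"Sn={sensitivity:.4f} Sp={specificity:.4f} Acc={accuracy:.4f} MCC={mcc:.4f}")
```

These reproduce the values listed above, which confirms that the gap comes from 77 false positives versus 41 false negatives.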

These results are already good for this particular problem in terms of accuracy and MCC, but there is a large gap between Sn and Sp. The threshold I used to separate the positive and negative classes is 0.5.

Can I change the threshold to balance Sn and Sp? If so, how can I find the best threshold before touching the independent test set? If not, what other ways are there to improve the balance between Sn and Sp?

(Please ask if additional information (network architecture, dataset, etc.) is needed to answer the question.)

Edit: My training dataset is balanced.

My suggestion would be to modify your loss function to penalize false positives and false negatives with different magnitudes. The standard binary cross-entropy loss is: $$-\frac{1}{N}\sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]$$


  • $y_i \log(\hat{y}_i)$ penalizes false negatives $F_N$ (it is nonzero only for positive samples, $y_i = 1$)
  • $(1 - y_i) \log(1 - \hat{y}_i)$ penalizes false positives $F_P$ (it is nonzero only for negative samples, $y_i = 0$).
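A per-sample version of the two terms in plain Python makes it clear that each sample activates only one of them. This is a minimal sketch with no framework assumed; the helper names are hypothetical, the terms are written as positive penalties (sign already flipped), and eps is an assumed small constant that avoids log(0):

```python
import math

def bce_terms(y, y_hat, eps=1e-12):
    """Return the two per-sample cross-entropy terms as positive penalties."""
    fn_term = -y * math.log(y_hat + eps)            # nonzero only when y = 1
    fp_term = -(1 - y) * math.log(1 - y_hat + eps)  # nonzero only when y = 0
    return fn_term, fp_term

def bce(ys, y_hats):
    """Mean binary cross-entropy over a batch of labels and probabilities."""
    return sum(sum(bce_terms(y, p)) for y, p in zip(ys, y_hats)) / len(ys)

# A positive predicted at 0.1 (a likely false negative) is penalized much more
# than a positive predicted at 0.9.
print(bce_terms(1, 0.1))
print(bce_terms(1, 0.9))
```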

You could track the current sensitivity $S_N=\frac{T_P}{T_P+F_N}$ and specificity $S_P=\frac{T_N}{T_N+F_P}$ (both initialized to 0.5) and change your loss function to: $$-\frac{1}{N}\sum_{i=1}^{N} \left[ \left(\frac{2 S_P}{S_N+S_P}\right) y_i \log(\hat{y}_i) + \left(\frac{2 S_N}{S_N+S_P}\right)(1 - y_i) \log(1 - \hat{y}_i) \right]$$

This way, if sensitivity is too low, false negatives are penalized more, and vice versa. When $S_N = S_P$, both weights equal 1 and the loss reduces to ordinary cross-entropy.

This is just one way to modify a loss function. In my experience, custom loss functions can be tricky and need a bit more debugging. My immediate suggestion would be to update $S_P$ and $S_N$ only once per epoch, smoothing them with a moving average over roughly the last 100 values.
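The code in this answer smooths $S_N$ and $S_P$ with a Butterworth low-pass filter. A simpler stand-in, swapped in here only to illustrate per-epoch smoothing, is an exponential moving average. SmoothedRates is a hypothetical helper, and the mapping alpha = 2 / (window + 1) only roughly mimics a plain moving average over the last window values:

```python
class SmoothedRates:
    """Exponential moving average of per-epoch sensitivity and specificity."""

    def __init__(self, window=100, initial=0.5):
        # alpha = 2 / (window + 1) is a common heuristic for "about window samples"
        self.alpha = 2.0 / (window + 1)
        self.sn = initial  # start at the neutral value 0.5
        self.sp = initial

    def update(self, sn, sp):
        """Blend this epoch's measured rates into the running averages."""
        self.sn += self.alpha * (sn - self.sn)
        self.sp += self.alpha * (sp - self.sp)
        return self.sn, self.sp


rates = SmoothedRates(window=100)
# Call once per epoch with the rates measured in that epoch
smoothed_sn, smoothed_sp = rates.update(0.84, 0.70)
```

Because each update moves the averages only a small step, a single noisy epoch cannot swing the loss weights sharply.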

A modified loss function is likely to degrade your network's accuracy and ability to generalize, so monitor the loss, specificity and sensitivity on the validation set. Consider making your neural network larger if the validation loss stays too high.

Here's a simple implementation using PyTorch and SciPy's Butterworth filter (scipy.signal.butter):

import torch as pt
from scipy import signal
import matplotlib.pyplot as pp

class ModifiedCrossEntropLoss(pt.nn.Module):

    def __init__(self):
        super(ModifiedCrossEntropLoss, self).__init__()
        # filter_strength is the normalized cut-off frequency of the low-pass filter
        # and must lie between 0 and 1. Smaller values mean that the moving average
        # changes more slowly. This makes the training process more stable, but more
        # epochs may be required to get good results.
        filter_strength = 0.01
        self.b, self.a = signal.butter(1, filter_strength)  # filter coefficients
        self.SN = []
        self.SP = []
        self.max_filter_window = 1000

    def forward(self, input: pt.Tensor, target: pt.Tensor) -> pt.Tensor:
        # assume target is a float tensor with 1s and 0s
        # assume 1 indicates a positive result

        assert list(input.shape) == list(target.shape), 'input shapes do not match'
        assert (input >= 0.0).all(), 'input has negative values'
        assert (target >= 0.0).all(), 'target has negative values'
        assert (input <= 1.0).all(), 'input has values G.T. 1.0'
        assert (target <= 1.0).all(), 'target has values G.T. 1.0'

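        eps = pt.finfo(input.dtype).tiny * 100   # for numerical stability, cannot take log(0)
        FN_penalty = target * pt.log(input + eps)            # nonzero only for positives
        FP_penalty = (1 - target) * pt.log(1 - input + eps)  # nonzero only for negatives
        SN, SP = self.filter_SP_SN_history()
        SNP = SN + SP
        # Low sensitivity raises the weight on the positive (false-negative) term;
        # low specificity raises the weight on the negative (false-positive) term.
        FN_weight = float(2 * SP / SNP)
        FP_weight = float(2 * SN / SNP)
        return -pt.mean(FN_weight * FN_penalty + FP_weight * FP_penalty)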

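    def filter_SP_SN_history(self):
        # Before any history is recorded (first epoch), use the neutral value 0.5
        if not self.SN:
            return 0.5, 0.5
        # low-pass filter the recent history, then add back the 0.5 removed in add_to_history
        SN = signal.lfilter(self.b, self.a, self.SN[-self.max_filter_window:])[-1] + 0.5
        SP = signal.lfilter(self.b, self.a, self.SP[-self.max_filter_window:])[-1] + 0.5
        return SN, SP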

    def add_to_history(self, SN, SP):
        # subtract 0.5 so that lfilter's zero initial state corresponds to the neutral value 0.5
        self.SN += [SN - 0.5]
        self.SP += [SP - 0.5]

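    def plot(self, figure=None):
        """Visualize the SN/SP history and its low-pass-filtered version.

        Example:
            criterion = ModifiedCrossEntropLoss()
            for i in range(100):
                criterion.add_to_history(0.25, 0.75)
            for i in range(100):
                criterion.add_to_history(0.4, 0.6)
            criterion.plot()
            pp.show()
        """
        if figure is None:
            figure = pp.figure()
        ax = figure.gca()
        # The history is stored with 0.5 subtracted, so add it back for display
        ax.plot([v + 0.5 for v in self.SN], label='SN')
        ax.plot([v + 0.5 for v in self.SP], label='SP')
        # Dotted black lines: the filtered (smoothed) values
        ax.plot(signal.lfilter(self.b, self.a, self.SN) + 0.5, c='#000000', alpha=0.5, linestyle=':')
        ax.plot(signal.lfilter(self.b, self.a, self.SP) + 0.5, c='#000000', alpha=0.5, linestyle=':')
        ax.legend()
        return figure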

Your training loop would look something like this:

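criterion = ModifiedCrossEntropLoss()
optimizer = pt.optim.Adam(model.parameters())  # any optimizer works here
for e in range(epochs):
    for input, target in train_loader:  # assumes a DataLoader yielding (features, 0/1 float labels)
        optimizer.zero_grad()
        y = model(input)  # must output probabilities (e.g. end with a sigmoid), same shape as target
        loss = criterion(y, target)
        loss.backward()
        optimizer.step()
    SN, SP = calculate_SN_SP()  # placeholder: compute this epoch's sensitivity and specificity
    criterion.add_to_history(SN, SP)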
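The calculate_SN_SP call in the loop is a placeholder. One possible version is sketched below; it is a hypothetical helper that, unlike the call in the loop, takes predicted probabilities and 0/1 labels as arguments (for example, collected over the epoch or from the validation set), so the call would become SN, SP = calculate_SN_SP(probs, targets):

```python
def calculate_SN_SP(probs, targets, threshold=0.5):
    """Sensitivity and specificity of thresholded probabilities.

    probs: predicted probabilities for the positive class
    targets: true labels (1 = positive, 0 = negative)
    """
    tp = fn = tn = fp = 0
    for p, t in zip(probs, targets):
        predicted_positive = p >= threshold
        if t == 1:
            if predicted_positive:
                tp += 1
            else:
                fn += 1
        else:
            if predicted_positive:
                fp += 1
            else:
                tn += 1
    # If a class is missing from the data, fall back to the neutral value 0.5
    sn = tp / (tp + fn) if tp + fn else 0.5
    sp = tn / (tn + fp) if tn + fp else 0.5
    return sn, sp
```

Returning 0.5 for an empty class keeps the loss weights neutral instead of raising a division error.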

Sure, you can use the predicted probabilities to find the threshold that gives roughly equal sensitivity and specificity. You can tune this using cross-validation (where an ideal threshold that varies wildly across folds would be a red flag to me).
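A minimal sketch of that threshold search, assuming you have predicted probabilities and 0/1 labels for a validation set or for the out-of-fold predictions of a cross-validation, never the independent test set. The helper names are hypothetical, and minimizing |Sn − Sp| is just one criterion; maximizing Youden's J = Sn + Sp − 1 is a common alternative:

```python
def sn_sp_at(probs, targets, threshold):
    """Sensitivity and specificity when predicting positive for p >= threshold.

    Assumes both classes are present in targets.
    """
    pairs = list(zip(probs, targets))
    tp = sum(1 for p, t in pairs if t == 1 and p >= threshold)
    fn = sum(1 for p, t in pairs if t == 1 and p < threshold)
    tn = sum(1 for p, t in pairs if t == 0 and p < threshold)
    fp = sum(1 for p, t in pairs if t == 0 and p >= threshold)
    return tp / (tp + fn), tn / (tn + fp)


def balanced_threshold(probs, targets):
    """Observed probability value that minimizes |Sn - Sp| when used as threshold."""
    best_t, best_gap = 0.5, float("inf")
    for t in sorted(set(probs)):
        sn, sp = sn_sp_at(probs, targets, t)
        gap = abs(sn - sp)
        if gap < best_gap:
            best_t, best_gap = t, gap
    return best_t
```

Running balanced_threshold separately on each fold and comparing the results is a cheap way to see whether the ideal threshold varies wildly.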

That you're working with the probability values and thinking about different thresholds, however, suggests that incorrect decisions carry some cost. If that is the case, you might want to take that cost into account when favoring sensitivity or specificity. In the extreme, this could lead to classifying every observation into one class, regardless of its features. Better yet, ditch hard classifications and assess the raw probability values. I'll close with the links that I often post about evaluating the probability values.
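If the two error costs can be quantified, and assuming the predicted probabilities are calibrated, predicting positive has expected cost (1 − p)·C_FP and predicting negative has expected cost p·C_FN, so the cheaper decision is "positive" whenever p > C_FP / (C_FP + C_FN). The sketch below computes that threshold and, for assessing the raw probabilities themselves, the Brier score, one common proper scoring rule chosen here for illustration; the cost values are hypothetical inputs:

```python
def cost_threshold(cost_fp, cost_fn):
    """Probability above which predicting 'positive' has the lower expected cost.

    Assumes calibrated probabilities: predicting positive costs (1 - p) * cost_fp
    in expectation, predicting negative costs p * cost_fn.
    """
    return cost_fp / (cost_fp + cost_fn)


def brier_score(probs, targets):
    """Mean squared error between probabilities and 0/1 labels (lower is better)."""
    return sum((p - t) ** 2 for p, t in zip(probs, targets)) / len(probs)
```

For example, if a false negative is three times as costly as a false positive, cost_threshold(1, 3) gives 0.25, so more sequences are labeled positive and sensitivity rises at the expense of specificity. With equal costs the threshold is the usual 0.5.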


