My suggestion would be to modify your loss function so that it penalizes false positives and false negatives with different magnitudes. Plain cross-entropy loss looks like:
$$\frac{-1}{N}\sum_{i=1}^{N} y_i \cdot log (\hat{y_i}) + (1 - y_i) \cdot log (1 - \hat{y_i})$$
Where:
- $y_i \cdot log (\hat{y_i})$ is only active for positive examples ($y_i = 1$), so it penalizes false negatives $F_N$
- $(1 - y_i) \cdot log (1 - \hat{y_i})$ is only active for negative examples ($y_i = 0$), so it penalizes false positives $F_P$.
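For reference, this unweighted form is exactly what `torch.nn.BCELoss` computes; a minimal sanity check, assuming the model outputs probabilities:

```python
import torch as pt

y_hat = pt.tensor([0.9, 0.2, 0.7])  # predicted probabilities
y = pt.tensor([1.0, 0.0, 1.0])      # ground-truth labels

# manual cross entropy, matching the formula above
manual = -pt.mean(y * pt.log(y_hat) + (1 - y) * pt.log(1 - y_hat))
assert pt.allclose(manual, pt.nn.BCELoss()(y_hat, y))
```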
You could introduce two new parameters, the sensitivity $S_N=\frac{T_P}{T_P+F_N}$ and the specificity $S_P=\frac{T_N}{T_N+F_P}$, both initialized to $0.5$, and change your loss function to:
$$\frac{-1}{N}\sum_{i=1}^{N} \left(\frac{2\cdot S_N}{S_N+S_P}\right) y_i \cdot log (\hat{y_i}) + \left(\frac{2\cdot S_P}{S_N+S_P}\right)(1 - y_i) \cdot log (1 - \hat{y_i})$$
This way, whenever sensitivity falls below specificity, false positives are penalized more heavily, and vice-versa.
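For example (illustrative numbers only): with $S_N = 0.3$ and $S_P = 0.7$, the false-negative term gets weight $\frac{2 \cdot 0.3}{1.0} = 0.6$ and the false-positive term gets weight $\frac{2 \cdot 0.7}{1.0} = 1.4$. The two weights always average to 1, so the overall scale of the loss stays comparable to plain cross entropy.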
This is just one way to modify a loss function. In my experience, custom loss functions can be tricky and require a little more debugging. My immediate suggestion would be to update $S_N$ and $S_P$ only once per epoch, using a moving average over the last 100 epochs or so.
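The implementation below does this smoothing with a low-pass filter; as a minimal sketch of the idea on its own, here is a plain exponential moving average instead (the names `update_smoothed`, `SN_epoch`, `SP_epoch` are hypothetical):

```python
# exponential moving average of per-epoch sensitivity/specificity
# alpha = 2 / (window + 1); a ~100-epoch window gives alpha ~ 0.02
alpha = 0.02
SN_smooth, SP_smooth = 0.5, 0.5  # start balanced, i.e. plain cross entropy

def update_smoothed(SN_epoch, SP_epoch):
    global SN_smooth, SP_smooth
    SN_smooth = alpha * SN_epoch + (1 - alpha) * SN_smooth
    SP_smooth = alpha * SP_epoch + (1 - alpha) * SP_smooth
    return SN_smooth, SP_smooth
```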
It's very likely that a modified loss function will cause your network's accuracy and ability to generalize to degrade, so check the loss, specificity, and sensitivity on the validation set. Consider making your neural network larger if validation performance drops too much.
Here's a simple implementation using PyTorch and SciPy's Butterworth filter (`scipy.signal.butter`):
```python
import torch as pt
from scipy import signal
import matplotlib.pyplot as pp


class ModifiedCrossEntropyLoss(pt.nn.Module):
    def __init__(self):
        super(ModifiedCrossEntropyLoss, self).__init__()
        # filter_strength is the cut-off frequency of the low-pass filter.
        # This value is always less than 1. Smaller values mean the moving
        # average moves more slowly, which makes training more stable,
        # but more epochs may be required to get good results.
        filter_strength = 0.01
        self.b, self.a = signal.butter(1, filter_strength)  # filter coefficients
        self.SN = []
        self.SP = []
        self.max_filter_window = 1000

    def forward(self, input: pt.Tensor, target: pt.Tensor) -> pt.Tensor:
        # assume target is a float tensor of 1s and 0s,
        # where 1 indicates a positive result
        assert list(input.shape) == list(target.shape), 'input shapes do not match'
        assert (input >= 0.0).all(), 'input has negative values'
        assert (target >= 0.0).all(), 'target has negative values'
        assert (input <= 1.0).all(), 'input has values G.T. 1.0'
        assert (target <= 1.0).all(), 'target has values G.T. 1.0'
        eps = pt.finfo(input.dtype).tiny * 100  # for numerical stability, cannot take log(0)
        FN_penalty = target * pt.log(input + eps)
        FP_penalty = (1 - target) * pt.log(1 - input + eps)
        SN, SP = self.filter_SP_SN_history()
        SNP = SN + SP
        # the factor of 2 multiplies both terms, matching the formula above
        return -1 * pt.mean(2 * (SN / SNP * FN_penalty + SP / SNP * FP_penalty))

    def filter_SP_SN_history(self):
        if not self.SN:
            # no history yet (first epoch): fall back to equal weights,
            # which reduces to plain cross entropy
            return 0.5, 0.5
        # low-pass filter the zero-centred history, then add the 0.5 offset back
        SN = signal.lfilter(self.b, self.a, self.SN[-self.max_filter_window:])[-1] + 0.5
        SP = signal.lfilter(self.b, self.a, self.SP[-self.max_filter_window:])[-1] + 0.5
        return SN, SP

    def add_to_history(self, SN, SP):
        # subtract 0.5 so that signal.lfilter's zero initial state corresponds to 0.5
        self.SN += [SN - 0.5]
        self.SP += [SP - 0.5]

    def plot(self, figure=None):
        '''
        Use this to visualize SN/SP, example code:
            criterion = ModifiedCrossEntropyLoss()
            for i in range(100):
                criterion.add_to_history(0.25, 0.75)
            for i in range(100):
                criterion.add_to_history(0.4, 0.6)
            criterion.plot(figure=1)
        '''
        if figure is None:
            pp.figure()
        else:
            pp.figure(figure).clf()
        pp.plot(self.SN, label='SN')
        pp.plot(self.SP, label='SP')
        pp.plot(signal.lfilter(self.b, self.a, self.SN), c='#000000', alpha=0.5, linestyle=':')
        pp.plot(signal.lfilter(self.b, self.a, self.SP), c='#000000', alpha=0.5, linestyle=':')
        pp.legend()
        pp.xlabel('Epochs')
```
Your training loop would look something like this:
```python
criterion = ModifiedCrossEntropyLoss()
for e in range(epochs):
    for b in range(num_batches):
        input, target = get_training_data()
        y = model(input)
        loss = criterion(y, target)
        optim.zero_grad()
        loss.backward()
        optim.step()
    # update the smoothed sensitivity/specificity once per epoch
    SN, SP = calculate_SN_SP()
    criterion.add_to_history(SN, SP)
criterion.plot()
```
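`calculate_SN_SP()` is left to you; a possible sketch (a hypothetical helper with a different signature, assuming validation probabilities `val_pred` and 0/1 float labels `val_target` as tensors, and ignoring the divide-by-zero case when a class is absent from the batch):

```python
def calculate_SN_SP(val_pred, val_target, threshold=0.5):
    # val_pred: predicted probabilities, val_target: 0/1 float labels
    pred = (val_pred >= threshold).float()
    TP = ((pred == 1) & (val_target == 1)).sum().float()
    FN = ((pred == 0) & (val_target == 1)).sum().float()
    TN = ((pred == 0) & (val_target == 0)).sum().float()
    FP = ((pred == 1) & (val_target == 0)).sum().float()
    SN = TP / (TP + FN)  # sensitivity: fraction of positives correctly detected
    SP = TN / (TN + FP)  # specificity: fraction of negatives correctly detected
    return SN.item(), SP.item()
```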