Adding a group-specific penalty to binary cross-entropy

I want to implement a custom Keras loss function that consists of plain binary cross-entropy plus a penalty. The penalty increases the loss for false negatives from one group (each observation belongs to one of two groups, privileged or unprivileged) and decreases the loss for true negatives from that same group.
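
Here is the loss that the code below computes for a batch of N observations, written as a formula. It assumes y_true and y_pred have shape (N, 1). In it, y_i is the true label, ŷ_i the prediction, p_i the priv value (1 for privileged, 0 for unprivileged), τ the threshold tau, and g and ℓ the gain and loss weights:

$$
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}\mathrm{BCE}(y_i,\hat{y}_i) + \frac{1}{2N}\sum_{i=1}^{N}\mathbb{1}[\hat{y}_i \le \tau]\,(1-p_i)\,\bigl[(1-y_i)\,g + y_i\,\ell\bigr]
$$

The factor 1/(2N) appears because the gains and losses tensors are concatenated, giving 2N entries, before the mean is taken. Since the g term is added, g has to be negative for true negatives to lower the loss.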

My implementation so far is shown below. Unfortunately, it does not work yet. I simply add the penalty to the binary cross-entropy, but the penalty depends on y_pred only through a hard threshold (relu followed by sign). With respect to y_pred it therefore behaves like an added constant: its derivative is zero, so the penalty does not affect the gradients. Do you have any idea how I can fix this without changing the general idea of the penalty?
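
The following minimal pure-Python sketch (not Keras) makes the gradient problem concrete by estimating derivatives with central finite differences. TAU = 0.5 and the helper names are hypothetical. hard_step mimics the relu/sign threshold in the code below.

```python
import math

TAU = 0.5  # hypothetical threshold, standing in for tau


def hard_step(y_pred, tau=TAU):
    """1.0 if y_pred <= tau else 0.0, like the relu/sign construction."""
    return 1.0 if y_pred <= tau else 0.0


def bce(y_true, y_pred):
    """Binary cross-entropy for a single observation."""
    return -(y_true * math.log(y_pred) + (1.0 - y_true) * math.log(1.0 - y_pred))


def numeric_grad(f, x, h=1e-6):
    """Central finite-difference estimate of df/dx at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


y_true = 1.0
for y_pred in (0.2, 0.4, 0.7):
    grad_bce = numeric_grad(lambda p: bce(y_true, p), y_pred)
    grad_penalty = numeric_grad(hard_step, y_pred)
    print(f"y_pred={y_pred}: dBCE/dp={grad_bce:.3f}, dpenalty/dp={grad_penalty:.3f}")
```

The BCE derivative is nonzero. The step's derivative is exactly zero everywhere except at tau itself, where the step is not differentiable, so gradient descent gets no signal from the penalty.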

priv is an additional tensor encoding which group each observation belongs to (1 for privileged, 0 for unprivileged).

Any help is appreciated, and you might literally save my master's thesis by solving this.

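import tensorflow as tf
from tensorflow.keras import backend as K

def customLoss2(priv, tau, gain, loss):
  # priv: float tensor, 1.0 for privileged and 0.0 for unprivileged observations
  # tau: decision threshold; gain and loss: weights of the two penalty terms

  def binary_crossentropy_adjusted_groupspecific(y_true, y_pred):
    # Keras may pass integer labels; cast them so the arithmetic below works
    y_true = tf.cast(y_true, y_pred.dtype)

    # Binary tensor that is 1 for predictions smaller than or equal to tau
    # (relu followed by sign is a hard step, so its gradient w.r.t. y_pred is zero)
    temp = tf.subtract(y_pred, tau)
    temp = K.relu(temp)
    less_than_tau = tf.multiply(tf.subtract(K.sign(temp), 1.0), -1.0)

    # Inversion of the privileged tensor (1 for unprivileged observations)
    temp = tf.subtract(priv, 1.0)
    unpriv = tf.multiply(temp, -1.0)

    # Inversion of the true label tensor
    temp = tf.subtract(y_true, 1.0)
    inverted_y_true = tf.multiply(temp, -1.0)

    # Tensor with the gain for all unprivileged true negatives
    # (the gains are added, so gain has to be negative to lower the loss)
    gains = tf.multiply(tf.multiply(less_than_tau, unpriv), tf.multiply(inverted_y_true, gain))

    # Tensor with the loss for all unprivileged false negatives
    losses = tf.multiply(tf.multiply(less_than_tau, unpriv), tf.multiply(y_true, loss))

    # Concatenate the two tensors to take their joint mean
    bce = K.mean(K.binary_crossentropy(y_true, y_pred))
    conc = K.mean(K.concatenate([gains, losses], axis=0))

    # Total loss: binary cross-entropy plus the group-specific penalty
    total = bce + conc

    # Return the result
    return total

  # Return the loss function to Keras
  return binary_crossentropy_adjusted_groupspecific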
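
One common way to keep the idea of the penalty while making it differentiable is to replace the hard step with a smooth surrogate such as a sigmoid. This is a swapped-in technique, not part of the code above. The pure-Python sketch below (not Keras code) uses a hypothetical temperature hyperparameter and hypothetical example values for tau, gain, loss and the batch. group_penalty mirrors the gains/losses/concatenate/mean logic of customLoss2, with less_than_tau replaced by soft_step.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def soft_step(y_pred, tau, temperature):
    """Smooth stand-in for 1[y_pred <= tau]; approaches the hard step as temperature -> 0."""
    return sigmoid((tau - y_pred) / temperature)


def group_penalty(y_true, y_pred, priv, tau, gain, loss, temperature):
    """Mean of the concatenated gains and losses, as in customLoss2, but with a soft step."""
    gains, losses = [], []
    for t, p, is_priv in zip(y_true, y_pred, priv):
        # Weight is close to 1 for unprivileged observations predicted below tau
        weight = soft_step(p, tau, temperature) * (1.0 - is_priv)
        gains.append(weight * (1.0 - t) * gain)  # (soft) true negatives
        losses.append(weight * t * loss)  # (soft) false negatives
    conc = gains + losses
    return sum(conc) / len(conc)


# Hypothetical example batch and hyperparameters
y_true = [1.0, 0.0, 1.0]
y_pred = [0.45, 0.3, 0.8]
priv = [0.0, 0.0, 1.0]
params = dict(tau=0.5, gain=-0.1, loss=0.1, temperature=0.05)


def penalty_of_first(p):
    """Penalty as a function of the first prediction only."""
    return group_penalty(y_true, [p] + y_pred[1:], priv, **params)


h = 1e-6
grad = (penalty_of_first(0.45 + h) - penalty_of_first(0.45 - h)) / (2 * h)
print(f"penalty={penalty_of_first(0.45):.4f}, d(penalty)/dp_0={grad:.4f}")
```

The gradient with respect to the first prediction (an unprivileged positive just below tau) is now negative. Raising that prediction lowers the false-negative penalty, which is the signal the hard step could not provide. In the Keras loss, the analogous change would be to compute less_than_tau as tf.sigmoid((tau - y_pred) / temperature) instead of via relu and sign. A smaller temperature follows the hard threshold more closely but makes the gradient vanish for predictions far from tau, so it is a tuning trade-off.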


About

Geeks Mental is a community that publishes articles and tutorials about Web, Android, Data Science, new techniques and Linux security.