No gradients provided for any variable

I have written a custom loss function (kl_loss):

def tensor_pValue(pnls,pnl):
    vec=tf.contrib.framework.sort(pnls,axis=-1,direction='ASCENDING')
    rank_p=tf.divide(tf.range(0,264.5,1),264.0)
    return tf.gather(rank_p, tf.searchsorted(vec,pnl,side='left'))

def kl_divergence(p, q): 
    epsilon = 0.00001
    p=p+epsilon
    q=q+epsilon
    return tf.reduce_sum(p * tf.log(p/q))

def kl_loss(predicted_pnL,actual_pnl_tensor):
    p_dist=tf.squeeze(tf.map_fn(lambda inp:tensor_pValue(inp[0],inp[1]),(predicted_pnL,actual_pnl_tensor),dtype=tf.float32))
    u_dist=tf.random.uniform([264],0,1,dtype=tf.float32)
    return kl_divergence(p_dist,u_dist)

Then I built a simple network using Keras:

optimizer = tf.train.AdamOptimizer(0.001)
input_dim = X_train.shape[1]
model = keras.Sequential([
    keras.layers.Dense(UNITS, activation=tf.nn.relu,
                       input_dim=input_dim),
    keras.layers.Dense(UNITS, activation=tf.nn.relu),
    keras.layers.Dense(264)
])
model.compile(loss=lambda y, f: kl_loss(f,y), optimizer=optimizer)
model.fit(X_train, train_y, epochs=EPOCHS, batch_size=BATCH_SIZE,verbose=0)

And I got the following error:

ValueError: No gradients provided for any variable: ["", "", "", "", "", ""].

Can anyone help me figure out what might be wrong here? Thank you very much!



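It seems that you are using sorting and ranking operations to calculate p_dist (sort, then tf.searchsorted, then tf.gather with the resulting indices), and this kind of computation does not provide a gradient: tf.searchsorted returns integer indices, and tf.gather only passes gradients to the values being gathered (here the constant rank_p), not to the indices. Because u_dist is just random noise that does not depend on the model either, the loss has no differentiable path back to the network's weights, which is exactly what the error says. So the problem is not in the KL function itself. Hope that helps.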
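A quick way to see this without TensorFlow is a plain-Python sketch of what tensor_pValue computes for a single row. This is an illustration under assumptions, not the original code: bisect.bisect_left stands in for tf.searchsorted(..., side='left'), and the helper names p_value and finite_diff are hypothetical. Nudging any prediction by a small amount leaves the output unchanged, so the numerical derivative is zero; the value only jumps when a prediction crosses pnl.

```python
import bisect

N = 264  # number of predicted PnL values per row, as in the question


def p_value(pnls, pnl):
    """Hypothetical plain-Python analogue of tensor_pValue for one row."""
    vec = sorted(pnls)                           # like tf.sort(..., direction='ASCENDING')
    idx = bisect.bisect_left(vec, pnl)           # like tf.searchsorted(vec, pnl, side='left')
    rank_p = [k / 264.0 for k in range(N + 1)]   # like tf.range(0, 264.5, 1) / 264.0
    return rank_p[idx]                           # like tf.gather(rank_p, idx)


def finite_diff(pnls, pnl, i, h=1e-6):
    """Numerical derivative of p_value with respect to the i-th prediction."""
    bumped = list(pnls)
    bumped[i] += h
    return (p_value(bumped, pnl) - p_value(pnls, pnl)) / h


preds = [float(k) for k in range(N)]  # toy predictions 0.0, 1.0, ..., 263.0
target = 100.5
print(p_value(preds, target))                                # 101/264
print(all(finite_diff(preds, target, i) == 0.0 for i in range(N)))  # True
```

The output is piecewise constant in the predictions, so its derivative is zero almost everywhere and undefined at the jumps. TensorFlow reflects this by simply not having a gradient path through the integer indices.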


If you just need a standard KL divergence loss, I recommend using the built-in one: tf.keras.losses.KLDivergence.

If you cannot use Keras from TF, implement it the way they do: https://github.com/tensorflow/tensorflow/blob/3699977134badccb7032fa6921d70e01ba8fdf7d/tensorflow/python/keras/losses.py#L978
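As far as I can tell from the linked source, the Keras function clips both inputs into [epsilon, 1] and then sums y_true * log(y_true / y_pred) over the last axis. Below is a minimal framework-free sketch of that formula for one pair of vectors; it assumes epsilon = 1e-7 (Keras' default backend epsilon), and the helper names are hypothetical. Note that it expects y_true and y_pred to be probability distributions (non-negative, summing to 1).

```python
import math

EPSILON = 1e-7  # assumption: Keras' default backend epsilon


def clip(x, lo, hi):
    """Clamp x into the interval [lo, hi]."""
    return max(lo, min(x, hi))


def kl_divergence(y_true, y_pred, eps=EPSILON):
    """KL(y_true || y_pred) for one pair of probability vectors.

    Both inputs are clipped into [eps, 1] before the log, mirroring the
    clipping in the Keras source instead of adding an epsilon.
    """
    total = 0.0
    for t, p in zip(y_true, y_pred):
        t = clip(t, eps, 1.0)
        p = clip(p, eps, 1.0)
        total += t * math.log(t / p)
    return total


print(kl_divergence([0.5, 0.5], [0.5, 0.5]))  # 0.0, identical distributions
print(kl_divergence([1.0, 0.0], [0.5, 0.5]))  # close to log(2)
```

When used as a Keras loss, these per-sample values are then averaged over the batch by default. Unlike the question's kl_divergence, which adds epsilon and sums over the whole batch, this clips and works per sample.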
