What does scaling a gradient do?

In the MuZero paper pseudocode, they have the following line of code:

hidden_state = tf.scale_gradient(hidden_state, 0.5)

What does this do? Why is it there?

I've searched for tf.scale_gradient and it doesn't exist in TensorFlow. And, unlike scalar_loss, they don't seem to have defined it in their own code.

For context, here's the entire function:

def update_weights(optimizer: tf.train.Optimizer, network: Network, batch,
                   weight_decay: float):
  loss = 0
  for image, actions, targets in batch:
    # Initial step, from the real observation.
    value, reward, policy_logits, hidden_state = network.initial_inference(
        image)
    predictions = [(1.0, value, reward, policy_logits)]

    # Recurrent steps, from action and previous hidden state.
    for action in actions:
      value, reward, policy_logits, hidden_state = network.recurrent_inference(
          hidden_state, action)
      predictions.append((1.0 / len(actions), value, reward, policy_logits))

      # THIS LINE HERE
      hidden_state = tf.scale_gradient(hidden_state, 0.5)

    for prediction, target in zip(predictions, targets):
      gradient_scale, value, reward, policy_logits = prediction
      target_value, target_reward, target_policy = target

      l = (
          scalar_loss(value, target_value) +
          scalar_loss(reward, target_reward) +
          tf.nn.softmax_cross_entropy_with_logits(
              logits=policy_logits, labels=target_policy))

      # AND AGAIN HERE
      loss += tf.scale_gradient(l, gradient_scale)

  for weights in network.get_weights():
    loss += weight_decay * tf.nn.l2_loss(weights)

  optimizer.minimize(loss)

What does scaling the gradient do, and why are they doing it there?

Tags: deepmind, ai, machine-learning-model, machine-learning


Author of the paper here - I missed that this is apparently not a standard TensorFlow function. It's equivalent to Sonnet's scale_gradient, or to the following function:

def scale_gradient(tensor, scale):
  """Scales the gradient for the backward pass."""
  # The forward value is unchanged (scale + (1 - scale) == 1); in the backward
  # pass only the first term receives gradient, so the gradient is scaled.
  return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)
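
To see the effect, here is a minimal sketch (my own example, assuming TensorFlow 2.x eager execution, using the scale_gradient defined above): the forward value passes through unchanged, while the gradient flowing back is multiplied by the scale factor.

import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape() as tape:
  y = scale_gradient(x * x, 0.5)  # scale_gradient as defined above

print(y.numpy())                    # 9.0  -- forward pass unchanged
print(tape.gradient(y, x).numpy())  # 3.0  -- d(x^2)/dx = 6.0, halved by scale=0.5

In update_weights this is used in two places: the hidden state's gradient is halved at every unroll step, and each recurrent step's loss is scaled by 1.0 / len(actions). The idea is to keep the total gradient applied to the recurrent (dynamics) part of the network roughly constant, regardless of how many steps it is unrolled for.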

Given that it's pseudocode (since it's not in TF 2.0), I would go with gradient clipping or batch normalisation ('scaling of activation functions').
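
For reference, gradient clipping in TensorFlow is normally applied to the computed gradients rather than inside the forward graph, which makes it a different mechanism from the scale_gradient above. A minimal TF 2.x sketch (the model and data here are made-up placeholders):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))

# Clip the global norm of all gradients to 1.0 before applying them.
grads = tape.gradient(loss, model.trainable_variables)
grads, _ = tf.clip_by_global_norm(grads, 1.0)
optimizer.apply_gradients(zip(grads, model.trainable_variables))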
