TensorFlow: How to replace or modify gradient?
Solution 1:
For TensorFlow 1.7 and TensorFlow 2.0, see the edit below.
First define your custom gradient:
@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad
Since you want nothing to happen in the forward pass, override the gradient of an identity operation with your new gradient:
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")
Here is a working example (TensorFlow 1.x graph mode) of a layer that clips gradients in the backward pass and does nothing in the forward pass, using the same method:
import tensorflow as tf

@tf.RegisterGradient("CustomClipGrad")
def _clip_grad(unused_op, grad):
    return tf.clip_by_value(grad, -0.1, 0.1)

input = tf.Variable([3.0], dtype=tf.float32)

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomClipGrad"}):
    output_clip = tf.identity(input, name="Identity")
grad_clip = tf.gradients(output_clip, input)

# output without gradient clipping in the backward pass for comparison:
output = tf.identity(input)
grad = tf.gradients(output, input)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("with clipping:", sess.run(grad_clip)[0])
    print("without clipping:", sess.run(grad)[0])
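If you need the gradient_override_map approach itself under TensorFlow 2.x (for example, to keep existing graph code working), a minimal sketch is to go through the tf.compat.v1 API with eager execution disabled. This assumes a fresh TF 2.x process (disable_eager_execution must run before any ops are created) and uses a constant named x instead of a variable so no initializer is needed:

```python
import tensorflow as tf

# Assumption: running under TF 2.x, so switch back to TF1-style graph mode first.
tf.compat.v1.disable_eager_execution()

@tf.RegisterGradient("CustomClipGrad")
def _clip_grad(unused_op, grad):
    # Backward pass: clip the incoming gradient to [-0.1, 0.1].
    return tf.clip_by_value(grad, -0.1, 0.1)

x = tf.constant([3.0])

g = tf.compat.v1.get_default_graph()
with g.gradient_override_map({"Identity": "CustomClipGrad"}):
    # Identity ops created in this block use the overridden gradient.
    output_clip = tf.identity(x)
grad_clip = tf.compat.v1.gradients(output_clip, x)

# Created outside the override block, so it keeps the normal Identity gradient.
grad = tf.compat.v1.gradients(tf.identity(x), x)

with tf.compat.v1.Session() as sess:
    clip_val = sess.run(grad_clip)[0]
    plain_val = sess.run(grad)[0]

print("with clipping:", clip_val)      # [0.1]
print("without clipping:", plain_val)  # [1.]
```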
Edit for TensorFlow 1.7 and TensorFlow 2.0
Since TensorFlow 1.7 there is a new, shorter way to redefine the gradient, the @tf.custom_gradient decorator, which also works in TensorFlow 2.0. It also allows redefining the gradient of several operations at the same time. Here are the examples from above, rewritten for TensorFlow 1.7 and TensorFlow 2.0:
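As a sketch of overriding the gradient of several operations at once, the function below (modeled on the log1pexp example in the tf.custom_gradient documentation) wraps exp, add and log and supplies a single hand-written derivative. Plain autodiff would compute exp(x) / (1 + exp(x)), which becomes nan for x = 100 in float32 because exp(100) overflows to inf; the custom gradient returns 1.0 instead. The input value is an illustrative choice:

```python
import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
    # Forward pass: exp, add and log are treated as one unit.
    e = tf.exp(x)

    def grad(upstream):
        # Numerically stable derivative of log(1 + exp(x)).
        return upstream * (1 - 1 / (1 + e))

    return tf.math.log(1 + e), grad

x = tf.constant(100.0)
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so watch it explicitly
    y = log1pexp(x)
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 1.0
```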
Layer that scales gradients in the backward pass:
@tf.custom_gradient
def scale_grad_layer(x):
    def grad(dy):
        return 5.0 * dy
    return tf.identity(x), grad
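A quick way to check the scaling layer in TensorFlow 2.x eager mode is to record it on a tf.GradientTape: the forward value passes through unchanged, while the gradient reaching x is multiplied by 5. The input [3.0] is just an illustrative value:

```python
import tensorflow as tf

@tf.custom_gradient
def scale_grad_layer(x):
    def grad(dy):
        # Backward pass: scale the incoming gradient by 5.
        return 5.0 * dy
    # Forward pass: return x unchanged.
    return tf.identity(x), grad

x = tf.constant([3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = scale_grad_layer(x)
# The upstream gradient of y defaults to ones, so the result is 5 * 1.
grad_x = tape.gradient(y, x)
print(y.numpy(), grad_x.numpy())  # [3.] [5.]
```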
Example with a layer that clips gradients in the backward pass:
@tf.custom_gradient
def clip_grad_layer(x):
    def grad(dy):
        return tf.clip_by_value(dy, -0.1, 0.1)
    return tf.identity(x), grad
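The TF1 comparison between clipped and unclipped gradients can be reproduced in eager mode roughly like this. It is a sketch that uses a constant input instead of a variable and a persistent tape so that two gradients can be taken from one recording:

```python
import tensorflow as tf

@tf.custom_gradient
def clip_grad_layer(x):
    def grad(dy):
        # Backward pass: clip the incoming gradient to [-0.1, 0.1].
        return tf.clip_by_value(dy, -0.1, 0.1)
    # Forward pass: identity.
    return tf.identity(x), grad

x = tf.constant([3.0])
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    output_clip = clip_grad_layer(x)
    # Output without gradient clipping, for comparison.
    output = tf.identity(x)

grad_clip = tape.gradient(output_clip, x)
grad = tape.gradient(output, x)
del tape  # release the persistent tape's resources

print("with clipping:", grad_clip.numpy())  # [0.1]
print("without clipping:", grad.numpy())    # [1.]
```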
Solution 2:
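Assuming the forward computation is
y = f(x)
and you want it to backpropagate as if it were
y = b(x)
a simple hack is:
y = b(x) + tf.stop_gradient(f(x) - b(x))
In the forward pass the two b(x) terms cancel, so y equals f(x); in the backward pass tf.stop_gradient blocks the gradient of the bracketed term, so only the gradient of b(x) flows back to x.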
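One common use of this trick is a straight-through estimator. In the sketch below, the choice of f as tf.round (whose true gradient is zero almost everywhere) and b as the identity is an illustrative assumption; any f and differentiable b combine the same way:

```python
import tensorflow as tf

def f(x):
    # Forward computation: rounding.
    return tf.round(x)

def b(x):
    # Surrogate used only for the backward pass: identity, gradient 1.
    return x

x = tf.constant([0.3, 1.6, 2.2])
with tf.GradientTape() as tape:
    tape.watch(x)
    # Value equals f(x); gradient equals that of b(x).
    y = b(x) + tf.stop_gradient(f(x) - b(x))
dy_dx = tape.gradient(y, x)

print(y.numpy())      # [0. 2. 2.]
print(dy_dx.numpy())  # [1. 1. 1.]
```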