How to use stop_gradient in TensorFlow
tf.stop_gradient provides a way to exclude some variables from the gradient computation during back-propagation.

For example, in the code below we have three variables, w1, w2, and w3, and an input x. The loss is square(x.dot(w1) - x.dot(w2 * w3)). We want to minimize this loss with respect to w1, but keep w2 and w3 fixed. To achieve this, we simply wrap the second term in tf.stop_gradient(tf.matmul(x, w2 * w3)).
In the figure below, I plotted how w1, w2, and w3 deviate from their initial values as a function of training iterations. It can be seen that w2 and w3 remain fixed while w1 changes until it becomes equal to w2 * w3.
[Image: w1 is the only variable that learns; w2 and w3 stay at their initial values.]
import tensorflow as tf
import numpy as np

# Three trainable variables; the goal is to learn w1 while keeping w2 and w3 fixed.
w1 = tf.get_variable("w1", shape=[5, 1], initializer=tf.truncated_normal_initializer())
w2 = tf.get_variable("w2", shape=[5, 1], initializer=tf.truncated_normal_initializer())
w3 = tf.get_variable("w3", shape=[5, 1], initializer=tf.truncated_normal_initializer())
x = tf.placeholder(tf.float32, shape=[None, 5], name="x")

a1 = tf.matmul(x, w1)
a2 = tf.matmul(x, w2 * w3)
# Identity in the forward pass; blocks gradients to w2 and w3 in the backward pass.
a2 = tf.stop_gradient(a2)
loss = tf.reduce_mean(tf.square(a1 - a2))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# compute_gradients yields None for w2 and w3; apply_gradients skips those entries.
gradients = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(gradients)
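The snippet above only builds the graph. A minimal training loop to reproduce the plot might look like the following sketch (the random input batch and iteration count are my own choices, not part of the original experiment):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x_val = np.random.rand(100, 5).astype(np.float32)  # assumed random input
    for _ in range(1000):
        sess.run(train_op, feed_dict={x: x_val})
    # w1 has been driven toward w2 * w3, which themselves never moved.
    print(sess.run(tf.reduce_max(tf.abs(w1 - w2 * w3))))  # close to 0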
tf.gradients(loss, embed) computes the partial derivative of the tensor loss with respect to the tensor embed. TensorFlow computes this partial derivative by backpropagation, so it is expected behavior that evaluating the result of tf.gradients(...) performs backpropagation. However, evaluating that tensor does not perform any variable updates, because the expression does not include any assignment operations.
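To make the distinction concrete, here is a small sketch using the variables from the example above (I substitute w1 for embed, since embed is not defined in this example):

grad_w1 = tf.gradients(loss, [w1])[0]  # a gradient tensor, not an update op
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(w1)
    sess.run(grad_w1, feed_dict={x: np.random.rand(10, 5).astype(np.float32)})
    after = sess.run(w1)
    # Backpropagation ran, but no variable changed:
    print(np.allclose(before, after))  # True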
tf.stop_gradient() is an operation that acts as the identity function in the forward direction but stops the accumulated gradient from flowing through that operator in the backward direction. It does not prevent backpropagation altogether; rather, it prevents an individual tensor from contributing to the gradients that are computed for an expression. The documentation for the operation has more details, including when to use it.
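A quick sketch of that behavior (the variable v is hypothetical, introduced only for this illustration):

v = tf.get_variable("v", shape=[3], initializer=tf.ones_initializer())
y = tf.stop_gradient(v * 2.0)  # forward value is identical to v * 2.0
# There is no gradient path from y back to v, so tf.gradients returns [None]:
print(tf.gradients(tf.reduce_sum(y), [v]))   # [None]
# Without stop_gradient, the gradient would be a tensor of twos:
print(tf.gradients(tf.reduce_sum(v * 2.0), [v]))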