How to use stop_gradient in TensorFlow

tf.stop_gradient lets you exclude part of a computation from back-propagation, so that no gradient is computed with respect to the variables that feed into it.

For example, in the code below, we have three variables, w1, w2, w3, and an input x. The loss is square(x.dot(w1) - x.dot(w2 * w3)), averaged over the batch. We want to minimize this loss with respect to w1 while keeping w2 and w3 fixed. To achieve this, we wrap the second term: tf.stop_gradient(tf.matmul(x, w2*w3)).

In the figure below, I plotted how w1, w2, and w3 evolve from their initial values as a function of training iterations. It can be seen that w2 and w3 remain fixed while w1 changes until it becomes equal to the element-wise product w2 * w3.

[Figure: only w1 learns; w2 and w3 do not.]

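# TensorFlow 1.x graph-mode API (tf.get_variable, tf.placeholder, tf.train).
import tensorflow as tf
import numpy as np

# Three trainable [5, 1] weight vectors and a batch of 5-dimensional inputs.
w1 = tf.get_variable("w1", shape=[5, 1], initializer=tf.truncated_normal_initializer())
w2 = tf.get_variable("w2", shape=[5, 1], initializer=tf.truncated_normal_initializer())
w3 = tf.get_variable("w3", shape=[5, 1], initializer=tf.truncated_normal_initializer())
x = tf.placeholder(tf.float32, shape=[None, 5], name="x")

a1 = tf.matmul(x, w1)
a2 = tf.matmul(x, w2 * w3)
# Treat a2 as a constant target: no gradient flows back to w2 or w3.
a2 = tf.stop_gradient(a2)
loss = tf.reduce_mean(tf.square(a1 - a2))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# compute_gradients returns (gradient, variable) pairs; the gradient is None for w2 and w3.
gradients = optimizer.compute_gradients(loss)
# apply_gradients skips the None entries, so only w1 is updated.
train_op = optimizer.apply_gradients(gradients)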
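
The graph above only defines the training step; it does not run it. As a rough, framework-free illustration of what repeated training steps do, here is a minimal NumPy sketch that computes the same loss and applies plain gradient descent to w1 only, treating a2 as a constant. The random seed, the batch of 256 inputs, and the 500 iterations are assumptions made for the sketch, not values from the original experiment, and normal (rather than truncated normal) initial values are used for simplicity.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same shapes as the graph above: three [5, 1] weights and a batch of inputs.
w1 = rng.normal(size=(5, 1))
w2 = rng.normal(size=(5, 1))
w3 = rng.normal(size=(5, 1))
w2_init, w3_init = w2.copy(), w3.copy()
x = rng.normal(size=(256, 5))  # hypothetical training batch

learning_rate = 0.1
losses = []
for step in range(500):
    a1 = x @ w1
    a2 = x @ (w2 * w3)  # held constant, mimicking stop_gradient(a2)
    diff = a1 - a2
    losses.append(float(np.mean(diff ** 2)))
    # Gradient of mean((a1 - a2)^2) with respect to w1, with a2 held constant.
    grad_w1 = 2.0 * x.T @ diff / x.shape[0]
    w1 = w1 - learning_rate * grad_w1
    # w2 and w3 are never updated: no gradient reaches them through a2.

print("initial loss:", losses[0], "final loss:", losses[-1])
print("max |w1 - w2*w3|:", float(np.max(np.abs(w1 - w2 * w3))))
```

In this sketch w1 moves toward the element-wise product w2 * w3 while w2 and w3 keep their initial values, which is the behavior the figure describes.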

tf.gradients(loss, embed) computes the partial derivative of the tensor loss with respect to some tensor embed. TensorFlow computes this partial derivative by backpropagation, so it is expected behavior that evaluating the result of tf.gradients(...) performs backpropagation. However, evaluating that tensor does not perform any variable updates, because the expression does not include any assignment operations.

tf.stop_gradient() is an operation that acts as the identity function in the forward direction but stops the accumulated gradient from flowing through that operator in the backward direction. It does not prevent backpropagation altogether, but instead prevents an individual tensor from contributing to the gradients that are computed for an expression. The documentation for the operation has more details, including when to use it.
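
To make the "identity forward, blocked backward" behavior concrete, the following NumPy sketch writes out by hand the gradients for the loss used earlier, mean((x.dot(w1) - x.dot(w2 * w3))^2). It is an illustration under assumptions (a hypothetical batch of 8 inputs and a fixed seed), not TensorFlow's implementation. One difference to keep in mind: the sketch represents the blocked gradient as zeros, whereas TensorFlow reports None for a variable that has no gradient path to the loss.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 5))  # hypothetical input batch
w1, w2, w3 = (rng.normal(size=(5, 1)) for _ in range(3))


def loss_fn(w1, w2, w3):
    # Forward pass: stop_gradient is the identity here, so the loss value
    # is the same with or without it.
    return float(np.mean((x @ w1 - x @ (w2 * w3)) ** 2))


n = x.shape[0]
diff = x @ w1 - x @ (w2 * w3)

# Gradient that reaches w1. Stopping the gradient at a2 does not change it.
grad_w1 = 2.0 * x.T @ diff / n

# Gradient w2 would receive if a2 were NOT wrapped in stop_gradient
# (chain rule through the element-wise product w2 * w3).
grad_w2_unblocked = -2.0 * (x.T @ diff) * w3 / n

# With the stop in place, a2 passes nothing backward, so w2 gets no update.
grad_w2_blocked = np.zeros_like(w2)

print("grad w1:", grad_w1.ravel())
print("grad w2 without stop_gradient:", grad_w2_unblocked.ravel())
print("grad w2 with stop_gradient:", grad_w2_blocked.ravel())
```

The path to w1 is untouched, which is why backpropagation still runs; only the contribution through a2 is removed.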