How to flatten a gradient (a list of tensors) inside a `tf.function` (graph mode)
Solution 1:
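Iterate directly over your list of tensors instead of fetching individual tensors by their index. Inside a `tf.function`, a plain Python `for` loop over a Python list is unrolled when the function is traced, so each `g` in the loop is an ordinary tensor: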
import tensorflow as tf
# Example input shaped like a tape.gradient result: one entry per variable,
# here a 2x10 matrix followed by a length-3 vector (23 elements in total).
grad = [tf.ones([2, 10]), tf.ones([3])]
@tf.function
def flatten_grad1(grad):
    """Flatten a list of gradient tensors into a single rank-1 tensor.

    Each tensor is reshaped to 1-D (in the order ``tf.reshape`` uses) and
    the pieces are concatenated in list order. The result's length is
    therefore the sum of the inputs' element counts.

    Args:
        grad: A Python list of tensors, e.g. the output of
            ``tape.gradient``. The list is walked with a plain Python
            ``for``, which ``tf.function`` unrolls at trace time. All
            entries must share one dtype (``tf.concat`` requires it).
            ``None`` entries are not supported (``tape.gradient`` returns
            these for unconnected variables).

    Returns:
        A 1-D tensor. For an empty list, an empty ``float32`` tensor.
    """
    if not grad:
        # tf.concat cannot concatenate zero tensors. The list length is
        # fixed when the function is traced, so this is a plain Python check.
        return tf.zeros([0], dtype=tf.float32)
    # -1 tells reshape to infer the element count, which replaces the
    # explicit reduce_prod(shape(g)) computation.
    return tf.concat([tf.reshape(g, [-1]) for g in grad], axis=0)
# Show the flattened example gradient produced by the list-based version.
print(flatten_grad1(grad=grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)
Alternatively, collect the flattened pieces in a `tf.TensorArray` and concatenate them at the end:
@tf.function
def flatten_grad2(grad):
    """Flatten a list of gradient tensors using a ``tf.TensorArray``.

    Each tensor is reshaped to 1-D and written to its own slot of a
    TensorArray. The slots are then concatenated in list order into a
    single rank-1 tensor.

    Args:
        grad: A Python list of tensors, e.g. the output of
            ``tape.gradient``. All entries must share one dtype. ``None``
            entries are not supported.

    Returns:
        A 1-D tensor whose dtype matches the inputs. For an empty list,
        an empty ``float32`` tensor.
    """
    if not grad:
        # Nothing to concatenate. The list length is known at trace time.
        return tf.zeros([0], dtype=tf.float32)
    # Take the dtype from the inputs instead of hard-coding float32, so
    # float16/float64 gradients can be flattened too. The number of
    # pieces is static, so no dynamic sizing is needed.
    # infer_shape=False allows the flattened pieces to have different lengths.
    pieces = tf.TensorArray(grad[0].dtype, size=len(grad), infer_shape=False)
    for i, g in enumerate(grad):
        pieces = pieces.write(i, tf.reshape(g, [-1]))
    return pieces.concat()
# Show the flattened example gradient produced by the TensorArray version.
print(flatten_grad2(grad=grad))