How to flatten a gradient (list of tensors) in tf.function (graph mode)

Solution 1:

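Try iterating directly over the list of tensors instead of accessing individual tensors by index. Inside tf.function, a Python for loop over a Python list is unrolled when the function is traced, so it works in graph mode: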

import tensorflow as tf

# Example input: two tensors of different shapes, similar to what tape.gradient can return.
grad = [tf.ones((2,10)), tf.ones((3,))]

@tf.function
def flatten_grad1(grad):
    """Flatten a list of tensors into a single 1-D tensor.

    Every entry of ``grad`` is reshaped to rank 1 and the pieces are
    concatenated in list order.

    Args:
        grad: A Python sequence of tensors, for example the list returned by
            ``tape.gradient``. The entries are passed to ``tf.concat``
            together, so they are expected to share one dtype.

    Returns:
        A 1-D tensor containing all elements of all entries of ``grad``.
    """
    # A target shape of [-1] lets tf.reshape infer the element count itself,
    # so there is no need to compute reduce_prod(tf.shape(g)) by hand.
    flat_parts = [tf.reshape(g, [-1]) for g in grad]
    return tf.concat(flat_parts, axis=0)
# The example holds 2*10 + 3 = 23 elements, so this prints a 1-D tensor of 23 ones.
print(flatten_grad1(grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)

Alternatively, the same result can be built with a tf.TensorArray:

@tf.function
def flatten_grad2(grad):
    """Flatten a list of tensors into a single 1-D tensor via a TensorArray.

    Every entry of ``grad`` is reshaped to rank 1, appended to a dynamically
    sized ``tf.TensorArray``, and the array is concatenated at the end.

    Args:
        grad: An iterable of tensors, for example the list returned by
            ``tape.gradient``. The array is created with dtype
            ``tf.float32``, so the entries are expected to be float32.

    Returns:
        A 1-D ``tf.float32`` tensor containing all elements of all entries of
        ``grad``, in iteration order.
    """
    # infer_shape=False because the flattened pieces generally differ in
    # length (e.g. 20 and 3 elements for the example input).
    flat_parts = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
    for g in grad:
        # Writing at the current size appends; [-1] lets tf.reshape infer the
        # element count instead of computing reduce_prod(tf.shape(g)).
        flat_parts = flat_parts.write(flat_parts.size(), tf.reshape(g, [-1]))
    return flat_parts.concat()

# Same example input as above; the TensorArray version is expected to print the same 23 ones.
print(flatten_grad2(grad))