How to build a Tensor that counts how many times each index appears in another Tensor?

Solution 1:

Found the answer! You need to use tf.math.bincount with axis=-1, which counts the occurrences separately for each row of I:

A = tf.math.bincount(I, axis=-1)
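To make it concrete what that call computes, here is a plain-Python sketch of the same row-wise counting. It is only an illustration of the semantics, not how TensorFlow implements it; the helper name count_indexes_per_row and the sample data are made up for this example.

```python
def count_indexes_per_row(rows, num_bins):
    """For each row, count how often each index 0..num_bins-1 occurs."""
    counts = []
    for row in rows:
        row_counts = [0] * num_bins
        for index in row:
            # Assumption for this sketch: indexes outside the range are skipped.
            if 0 <= index < num_bins:
                row_counts[index] += 1
        counts.append(row_counts)
    return counts


# Hypothetical sample data: two rows of indexes.
I_rows = [[0, 2, 2], [1, 1, 1]]
A_rows = count_indexes_per_row(I_rows, num_bins=3)
print(A_rows)  # [[1, 0, 2], [0, 3, 0]]
```

Row 0 contains index 0 once and index 2 twice, so its counts are [1, 0, 2]; row 1 contains index 1 three times, so its counts are [0, 3, 0].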

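Note: by default the second dimension of A is only as large as the biggest index in I plus one. If you need it to be exactly max_indexes, set both minlength and maxlength (indexes greater than or equal to maxlength are ignored):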

A = tf.math.bincount(I, axis=-1, minlength=max_indexes, maxlength=max_indexes)
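A small worked example of both calls, assuming TensorFlow 2.x in eager mode. The input values and max_indexes = 5 are made up for illustration:

```python
import tensorflow as tf

# Hypothetical input: each row holds the indexes to count.
I = tf.constant([[0, 2, 2], [1, 1, 1]], dtype=tf.int32)
max_indexes = 5  # assumed width wanted for the second dimension

# Without padding, the width is the largest index in I plus one.
A = tf.math.bincount(I, axis=-1)

# With minlength and maxlength both set, the width is exactly max_indexes.
A_padded = tf.math.bincount(
    I, axis=-1, minlength=max_indexes, maxlength=max_indexes
)

print(A.shape)         # (2, 3)
print(A_padded.shape)  # (2, 5)
print(A_padded.numpy())
```

Here A_padded comes out as [[1, 0, 2, 0, 0], [0, 3, 0, 0, 0]]: the same counts as A, with zero columns added for the indexes that never occur.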