How to select rows from a 3-D Tensor in TensorFlow?
This is possible in TensorFlow, but slightly inconvenient, because tf.gather() currently only works with one-dimensional indices and only selects slices from the 0th dimension of a tensor. However, you can still solve the problem efficiently by transforming the arguments so that they can be passed to tf.gather():
import tensorflow as tf

logits = ...  # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])
# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]
# Offset to add to each row in indices. We use `tf.expand_dims()` to make
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)
# Convert indices and logits into appropriate form for `tf.gather()`.
flattened_indices = tf.reshape(indices + offset, [-1])
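# Note: `tf.concat(values, axis)` and `tf.stack()` are the TensorFlow 1.0+
# spellings of the older `tf.concat(axis, values)` and `tf.pack()`.
flattened_logits = tf.reshape(logits, tf.concat([[-1], tf.shape(logits)[2:]], 0))
selected_rows = tf.gather(flattened_logits, flattened_indices)
result = tf.reshape(selected_rows,
                    tf.concat([tf.stack([batch_size, indices_per_batch]),
                               tf.shape(logits)[2:]], 0))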
Note that, since this uses tf.reshape() and not tf.transpose(), it doesn't need to modify the (potentially large) data in the logits tensor, so it should be fairly efficient.
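To see why adding an offset to the indices works, here is a minimal plain-Python sketch of the same flatten-and-offset idea. It is not TensorFlow code: `gather_rows_flat` and the labelled `batches` data are hypothetical names made up for illustration, and nested lists stand in for tensors.

```python
# Plain-Python analogue of the flatten-and-offset trick (not TensorFlow code).
def gather_rows_flat(batches, indices):
    """Select rows per batch by flattening the batch dimension away."""
    rows_per_batch = len(batches[0])
    indices_per_batch = len(indices[0])

    # Counterpart of reshaping logits to [batch_size * rows_per_batch, ...].
    flat_rows = [row for batch in batches for row in batch]

    # Counterpart of `indices + offset`: row i of batch b sits at
    # position b * rows_per_batch + i in the flattened list.
    flat_indices = [b * rows_per_batch + i
                    for b, batch_indices in enumerate(indices)
                    for i in batch_indices]

    # Counterpart of tf.gather() on the flattened data.
    selected = [flat_rows[k] for k in flat_indices]

    # Counterpart of the final reshape to [batch_size, indices_per_batch, ...].
    return [selected[start:start + indices_per_batch]
            for start in range(0, len(selected), indices_per_batch)]


# Two batches of four rows each; rows are labelled so the result is easy to read.
batches = [[["a0"], ["a1"], ["a2"], ["a3"]],
           [["b0"], ["b1"], ["b2"], ["b3"]]]
indices = [[0, 1], [1, 3]]

result = gather_rows_flat(batches, indices)
print(result)  # [[['a0'], ['a1']], [['b1'], ['b3']]]
```

With four rows per batch, the per-batch indices [[0, 1], [1, 3]] become the flat positions 0, 1, 5 and 7, which is exactly what the broadcast `offset` achieves in the TensorFlow version.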
mrry's answer is great, but I think the problem can be solved with far fewer lines of code using tf.gather_nd (this function was probably not yet available when mrry wrote that answer):
import tensorflow as tf

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
# Each innermost pair is [batch index, row index].
indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])
result = tf.gather_nd(logits, indices)
with tf.Session() as sess:
    print(sess.run(result))
This will print:
[[[ 10.  10.  20.  20.]
  [ 11.  10.  10.  30.]]

 [[ 15.  11.  11.  21.]
  [ 17.  11.  11.  21.]]]
tf.gather_nd should be available as of v0.10. See the related GitHub issue for more discussion.
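The indices passed to tf.gather_nd in this answer are the per-batch row indices from the question, each paired with its batch number. The plain-Python sketch below shows that pairing; the helper names `to_gather_nd_indices` and `gather_nd_rows` are hypothetical, and the second function is only a stand-in for what tf.gather_nd does with [batch, row] pairs, not its implementation.

```python
def to_gather_nd_indices(indices):
    """Pair each row index with the number of the batch it belongs to."""
    return [[[b, i] for i in batch_indices]
            for b, batch_indices in enumerate(indices)]


def gather_nd_rows(batches, nd_indices):
    """Plain-Python stand-in for gather_nd with [batch, row] index pairs."""
    return [[batches[b][i] for b, i in pairs] for pairs in nd_indices]


# Row indices per batch, as in the question: rows 0 and 1 from batch 0,
# rows 1 and 3 from batch 1.
nd_indices = to_gather_nd_indices([[0, 1], [1, 3]])
print(nd_indices)  # [[[0, 0], [0, 1]], [[1, 1], [1, 3]]]

# Labelled rows (made-up data) make the selection easy to check by eye.
batches = [["a0", "a1", "a2", "a3"],
           ["b0", "b1", "b2", "b3"]]
picked = gather_nd_rows(batches, nd_indices)
print(picked)  # [['a0', 'a1'], ['b1', 'b3']]
```

In TensorFlow itself, the same pairing could presumably be built from a tiled tf.range over the batch dimension stacked with the row indices along the last axis, rather than written out by hand.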