How to Do NumPy-Like Index Selection in TensorFlow?
Solution 1:
The simplest solution would be to use tf.concat, although it is probably not very efficient:
import numpy as np
import tensorflow as tf
# Demo data: a (2, 2, 20) array and the positions to pick along the last axis.
sample_array = np.random.uniform(size=(2, 2, 20))
to_select = [5, 6, 9, 4]
sample_tensor = tf.convert_to_tensor(value=sample_array)

# NumPy reference result: fancy indexing along the last axis.
numpy_way = sample_array[:, :, to_select]

# TensorFlow equivalent: slice one index at a time from the *tensor* (not the
# NumPy array), restore the axis that scalar indexing removed, and concatenate
# the slices along the last axis.
# NOTE: tf.gather(sample_tensor, to_select, axis=-1) does the same in one op.
tf_way = tf.concat(
    [tf.expand_dims(sample_tensor[:, :, s], axis=-1) for s in to_select],
    axis=-1,
)

print(numpy_way)
print(tf_way)
[[[0.81208086 0.03873406 0.89959868 0.97896671]
[0.57569184 0.33659472 0.32566287 0.58383079]]
[[0.59984846 0.43405048 0.42366314 0.25505199]
[0.16180442 0.5903358 0.21302399 0.86569914]]]
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
[0.57569184 0.33659472 0.32566287 0.58383079]]
[[0.59984846 0.43405048 0.42366314 0.25505199]
[0.16180442 0.5903358 0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)
A more complicated but more efficient solution involves using tf.meshgrid and tf.gather_nd. Check this post or this post and finally this. Here is an example based on your question:
# Positions to pick along the last axis of the rank-3 `sample_tensor`.
to_select = tf.constant([5, 6, 9, 4], dtype=tf.int32)
sample_tensor_shape = tf.shape(sample_tensor)

# Build the full coordinate grid: for every (i, j) in the first two axes and
# every selected index k, produce the triple (i, j, k).  With indexing='ij'
# each grid has shape (rows, cols, num_selected), so no manual tiling or
# reshaping is needed and the code works for any leading dimensions and any
# number of selected indices.
ii, jj, kk = tf.meshgrid(
    tf.range(sample_tensor_shape[0], dtype=tf.int32),
    tf.range(sample_tensor_shape[1], dtype=tf.int32),
    to_select,
    indexing='ij')

# Stack into (rows, cols, num_selected, 3); the trailing axis holds the
# coordinates that tf.gather_nd looks up in `sample_tensor`.
gather_indices = tf.stack([ii, jj, kk], axis=-1)

# result[i, j, n] == sample_tensor[i, j, to_select[n]]
result = tf.gather_nd(sample_tensor, gather_indices)
tf.Tensor(
[[[0.81208086 0.03873406 0.89959868 0.97896671]
[0.57569184 0.33659472 0.32566287 0.58383079]]
[[0.59984846 0.43405048 0.42366314 0.25505199]
[0.16180442 0.5903358 0.21302399 0.86569914]]], shape=(2, 2, 4), dtype=float64)