How can I retrieve elements in a multidimensional pytorch tensor by a list of indices?
Solution 1:
I imagine you tried something like
indices = scores.argmax(dim=1)
selection = lists[:, indices]
This does not work because the slice keeps every row of dimension 0 and then applies all of the indices to each of those rows, so the final shape is (x, x, 4).
To perform the correct selection, you need to replace the slice with a range, so that each row index is paired with its own column index.
indices = scores.argmax(dim=1)
selection = lists[range(indices.size(0)), indices]
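To see the difference between the two forms, here is a small self-contained sketch. The sizes (x = 3, 5 candidates per row, 4 values per candidate) and the random scores are assumptions made for illustration, since the question's actual data isn't shown. It uses torch.arange in place of range, which produces the same row indices as a tensor.

```python
import torch

# Hypothetical sizes: x rows, n candidates per row, d values per candidate.
x, n, d = 3, 5, 4
lists = torch.arange(x * n * d).reshape(x, n, d)
scores = torch.rand(x, n)

# Index of the best-scoring candidate in each row, shape (x,).
indices = scores.argmax(dim=1)

# Slice on dim 0: every row is indexed with ALL of the indices.
wrong = lists[:, indices]

# Paired indexing: row i is indexed only with indices[i].
selection = lists[torch.arange(x), indices]

# Reference result built with an explicit loop, for comparison.
expected = torch.stack([lists[i, indices[i]] for i in range(x)])

print(wrong.shape)      # shape (x, x, d)
print(selection.shape)  # shape (x, d)
```

The loop-based `expected` is only there as a check that the paired indexing picks one entry per row; in practice the indexed form is all you need.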