How can I retrieve elements in a multidimensional PyTorch tensor by a list of indices?

Solution 1:

I imagine you tried something like

indices = scores.argmax(dim=1)
selection = lists[:, indices]

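This does not work because `lists[:, indices]` applies every index in `indices` to every element along dimension 0. With `lists` of shape (x, y, 4) and `indices` of shape (x,), the result has shape (x, x, 4) instead of the (x, 4) you want.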
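A small self-contained reproduction of the problem. It assumes `lists` has shape (x, y, 4) and `scores` has shape (x, y); the sizes and random data below are made up for illustration. Note that the entry you actually want for row i sits on the "diagonal" `wrong[i, i]`, which is why the extra dimension is pure waste.

```python
import torch

x, y = 3, 5
lists = torch.randn(x, y, 4)   # hypothetical data: x rows, y candidates, 4 values each
scores = torch.randn(x, y)     # hypothetical score for each candidate

indices = scores.argmax(dim=1)  # shape (x,): best candidate per row
wrong = lists[:, indices]       # every row is indexed with ALL x indices
print(wrong.shape)              # torch.Size([3, 3, 4])

# The wanted element for row i is wrong[i, i] == lists[i, indices[i]]
```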

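To perform the correct selection, you need to replace the slice `:` with a range over dimension 0. Advanced indexing then pairs the two index sequences element-wise, so row i is combined with `indices[i]`.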

indices = scores.argmax(dim=1)
selection = lists[range(indices.size(0)), indices]
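A self-contained sketch of the fix, under the same assumed shapes and made-up data as before. It uses `torch.arange`, which plays the same role as `range` here but keeps the row indices as a tensor. As a cross-check it also computes the result with `torch.gather`; that is a swapped-in alternative, not part of the answer above.

```python
import torch

x, y = 3, 5
lists = torch.randn(x, y, 4)   # hypothetical data: x rows, y candidates, 4 values each
scores = torch.randn(x, y)     # hypothetical score for each candidate

indices = scores.argmax(dim=1)          # shape (x,), dtype int64
rows = torch.arange(indices.size(0))    # tensor([0, 1, 2])
selection = lists[rows, indices]        # pairs rows[i] with indices[i]
print(selection.shape)                  # torch.Size([3, 4])

# Alternative: torch.gather along dim 1. The index must have the same number
# of dimensions as `lists`, so reshape to (x, 1, 1) and expand to (x, 1, 4),
# then drop the singleton dimension.
index = indices.view(-1, 1, 1).expand(-1, 1, lists.size(2))
gathered = lists.gather(1, index).squeeze(1)
print(torch.equal(selection, gathered))  # True
```

The indexing version is usually the more readable of the two; `gather` can be handy when the selection has to happen along a dimension that is not next to the batch dimension.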