How to set a value at a per-row tensor index in a batch
I have a tensor of batch size N.
t = [[...], [....], [....] .... ]
In a second tensor, indices, I have N indices: for each row of t, the position of the element I want to change.
indices = [i0, i1, i2 .... ]
So I want to create t0 from t via:
t0 = [[ set X at i0 ], [ set X at i1 ], [ set X at i2 ] .... ]
How can I do this in PyTorch?
It seems like you're looking for the following:
t[torch.arange(N), indices] = X
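Note that the assignment above modifies t in place. If you want a separate t0 and need to keep t unchanged, a minimal sketch is to clone first. The helper name set_per_row and the sample shapes and values are hypothetical, chosen only for illustration. The sketch also shows Tensor.scatter as an out-of-place alternative, which expects an int64 index tensor shaped (N, 1) for dim=1.

```python
import torch

def set_per_row(t: torch.Tensor, indices: torch.Tensor, x: float) -> torch.Tensor:
    """Hypothetical helper: return a copy of t where row k has x at column indices[k]."""
    t0 = t.clone()  # leave the original t untouched
    # Row k is paired element-wise with indices[k]; keep arange on t's device
    rows = torch.arange(t.shape[0], device=t.device)
    t0[rows, indices] = x
    return t0

t = torch.zeros(3, 4)
indices = torch.tensor([1, 3, 0])  # int64 by default
t0 = set_per_row(t, indices, 5.0)

# Out-of-place alternative: scatter a scalar along dim 1 using an (N, 1) index
t0_scatter = t.scatter(1, indices.unsqueeze(1), 5.0)

print(t0)
print(torch.equal(t0, t0_scatter))
```

Both approaches write exactly one element per row. The clone-based version reads closer to the indexing answer, while scatter avoids the explicit copy-then-assign step.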
As an example:
import torch
a = torch.zeros((3,3))
a[torch.arange(3),[0,2,1]] = 0.2
print(a)
Output:
tensor([[0.2000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.2000],
[0.0000, 0.2000, 0.0000]])
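Note: this behavior is the same as NumPy's integer array indexing. The two index arrays are paired element-wise, so element k of the first selects the row and element k of the second selects the column within that row.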
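For comparison, here is a minimal sketch of the same assignment in NumPy, assuming NumPy is installed. It mirrors the PyTorch example above, with np.arange in place of torch.arange.

```python
import numpy as np

a = np.zeros((3, 3))
# Integer array indexing pairs the index arrays element-wise,
# so this sets a[0, 0], a[1, 2] and a[2, 1]
a[np.arange(3), [0, 2, 1]] = 0.2
print(a)
```

By contrast, a[:, [0, 2, 1]] would select entire columns rather than one element per row, which is why the explicit row index is needed.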