Shuffling along a given axis in PyTorch

Solution 1:

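You can use torch.randperm to generate a random permutation of the indices along the axis you want to shuffle, and then index the tensor with that permutation.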

For a tensor t, to shuffle along dimension 1 (the second axis), you can use:

t[:,torch.randperm(t.shape[1]),:]
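The same idea works for any axis. Below is a minimal sketch that uses Tensor.index_select so you don't have to write out the slicing for each dimension. The helper name shuffle_along_dim and its optional generator argument are hypothetical, chosen here for illustration. The permutation is drawn on the CPU and then moved to the tensor's device, which assumes a CPU torch.Generator if you pass one.

```python
import torch

def shuffle_along_dim(t: torch.Tensor, dim: int, generator=None) -> torch.Tensor:
    """Return a copy of t with its entries permuted along `dim` (hypothetical helper).

    A single random permutation is drawn and applied to every slice, so for
    dim=1 this behaves like t[:, torch.randperm(t.shape[1]), :].
    """
    # Random ordering of 0..size-1 for the chosen dimension, drawn on the CPU
    # (so a CPU generator can be used) and then moved to t's device.
    perm = torch.randperm(t.size(dim), generator=generator).to(t.device)
    # index_select takes the entries along `dim` in the order given by perm.
    return t.index_select(dim, perm)

t = torch.arange(18).reshape(3, 3, 2)
shuffled = shuffle_along_dim(t, dim=1)  # same as t[:, torch.randperm(3), :]
```

Passing a seeded torch.Generator makes the shuffle reproducible without touching the global random state.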

For your example (the exact order will differ from run to run, since the permutation is random):

>>> t = torch.tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
>>> t
tensor([[[1, 1],
         [2, 2],
         [3, 3]],

        [[4, 4],
         [5, 5],
         [6, 6]],

        [[7, 7],
         [8, 8],
         [9, 9]]])
>>> t[:,torch.randperm(t.shape[1]),:]
tensor([[[2, 2],
         [3, 3],
         [1, 1]],

        [[5, 5],
         [6, 6],
         [4, 4]],

        [[8, 8],
         [9, 9],
         [7, 7]]])