Python numpy sort array according to another array and broadcast over an axis
I have done this a million times before: sorting one array according to another. But this time it is slightly more complicated, and I have been stumped on how to do it. Let me explain. I have two arrays, say A:
[[1.59956565 1.16421459]
[1.21548342 1.63884363]
[0.73023302 0.54681896]
[2.02628432 1.32127994]
[0.2132793 0.26559821]
[0.38242608 0.30073228]]
and B:
[[[ 0.93634073 0.35109262]
[-0.35109262 0.93634073]]
[[-0.63561769 0.77200398]
[ 0.77200398 0.63561769]]
[[ 0.8331935 0.55298155]
[-0.55298155 0.8331935 ]]
[[ 0.96691332 0.25510513]
[-0.25510513 0.96691332]]
[[-0.41372983 0.91039971]
[ 0.91039971 0.41372983]]
[[ 0.84228545 0.53903174]
[-0.53903174 0.84228545]]]
i.e., B has one more dimension than A. I want to first sort A along the last axis:
[[1.16421459 1.59956565]
[1.21548342 1.63884363]
[0.54681896 0.73023302]
[1.32127994 2.02628432]
[0.2132793 0.26559821]
[0.30073228 0.38242608]]
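For reference, the per-row permutation that produces the sorted A above comes from argsort along the last axis; this index array is what needs to be reused on B. A minimal sketch using the values printed above:

```python
import numpy as np

A = np.array([[1.59956565, 1.16421459],
              [1.21548342, 1.63884363],
              [0.73023302, 0.54681896],
              [2.02628432, 1.32127994],
              [0.2132793, 0.26559821],
              [0.38242608, 0.30073228]])

# Per-row permutation that sorts A along its last axis
order = A.argsort(axis=-1)
print(order.tolist())  # [[1, 0], [0, 1], [1, 0], [1, 0], [0, 1], [1, 0]]

# Applying the permutation reproduces np.sort along the same axis
A_sorted = np.take_along_axis(A, order, axis=-1)
assert np.array_equal(A_sorted, np.sort(A, axis=-1))
```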
and then also sort the middle axis of B according to this sort, so B should become:
[[[-0.35109262 0.93634073]
[ 0.93634073 0.35109262]]
[[-0.63561769 0.77200398]
[ 0.77200398 0.63561769]]
[[-0.55298155 0.8331935 ]
[ 0.8331935 0.55298155]]
[[-0.25510513 0.96691332]
[ 0.96691332 0.25510513]]
[[-0.41372983 0.91039971]
[ 0.91039971 0.41372983]]
[[-0.53903174 0.84228545]
[ 0.84228545 0.53903174]]]
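Stated as a plain (non-vectorized) loop, the desired operation is: for each i, permute the rows of B[i] (B's middle axis) by the permutation that sorts A[i]. The sketch below is only a reference for checking vectorized answers against, not an efficient solution:

```python
import numpy as np

A = np.array([[1.59956565, 1.16421459],
              [1.21548342, 1.63884363],
              [0.73023302, 0.54681896],
              [2.02628432, 1.32127994],
              [0.2132793, 0.26559821],
              [0.38242608, 0.30073228]])
B = np.array([[[0.93634073, 0.35109262], [-0.35109262, 0.93634073]],
              [[-0.63561769, 0.77200398], [0.77200398, 0.63561769]],
              [[0.8331935, 0.55298155], [-0.55298155, 0.8331935]],
              [[0.96691332, 0.25510513], [-0.25510513, 0.96691332]],
              [[-0.41372983, 0.91039971], [0.91039971, 0.41372983]],
              [[0.84228545, 0.53903174], [-0.53903174, 0.84228545]]])

order = A.argsort(axis=-1)

# For each leading index i, reorder the rows of B[i] with order[i]
B_expected = np.empty_like(B)
for i in range(B.shape[0]):
    B_expected[i] = B[i][order[i]]

print(B_expected[0])
```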
How can I do this with a view or a slice using the argsort of A? I tried but got nowhere, because there is one more axis in B.
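On the view question: a view uses a fixed stride per axis, so in general a per-row permutation (swap the rows of B[0], leave B[1] alone, and so on) cannot be expressed as a view or basic slice, and any solution will produce a copy. One way to cope with B's extra axis, shown here as an alternative to take_along_axis, is integer-array indexing with a broadcast row index. A sketch assuming the A and B values above:

```python
import numpy as np

A = np.array([[1.59956565, 1.16421459],
              [1.21548342, 1.63884363],
              [0.73023302, 0.54681896],
              [2.02628432, 1.32127994],
              [0.2132793, 0.26559821],
              [0.38242608, 0.30073228]])
B = np.array([[[0.93634073, 0.35109262], [-0.35109262, 0.93634073]],
              [[-0.63561769, 0.77200398], [0.77200398, 0.63561769]],
              [[0.8331935, 0.55298155], [-0.55298155, 0.8331935]],
              [[0.96691332, 0.25510513], [-0.25510513, 0.96691332]],
              [[-0.41372983, 0.91039971], [0.91039971, 0.41372983]],
              [[0.84228545, 0.53903174], [-0.53903174, 0.84228545]]])

order = A.argsort(axis=-1)                 # shape (6, 2)
rows = np.arange(A.shape[0])[:, None]      # shape (6, 1), broadcasts against order

# Integer-array indexing on the first two axes; B's last axis is kept whole,
# so B[rows, order][i, j, k] == B[i, order[i, j], k]
A_sorted = A[rows, order]
B_sorted = B[rows, order]
print(B_sorted.shape)                      # (6, 2, 2)

# Advanced indexing returns a new array, not a view of B
print(np.shares_memory(B_sorted, B))       # False
```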
Here is one way using take_along_axis:
np.take_along_axis(B, A.argsort(1)[:, :, None], axis=1)
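The [:, :, None] is what handles B's extra axis: take_along_axis requires the index array to have the same number of dimensions as B, and the non-axis dimensions only need to broadcast. A sketch of the shapes involved, checked against an explicit loop (A and B as in the question):

```python
import numpy as np

A = np.array([[1.59956565, 1.16421459],
              [1.21548342, 1.63884363],
              [0.73023302, 0.54681896],
              [2.02628432, 1.32127994],
              [0.2132793, 0.26559821],
              [0.38242608, 0.30073228]])
B = np.array([[[0.93634073, 0.35109262], [-0.35109262, 0.93634073]],
              [[-0.63561769, 0.77200398], [0.77200398, 0.63561769]],
              [[0.8331935, 0.55298155], [-0.55298155, 0.8331935]],
              [[0.96691332, 0.25510513], [-0.25510513, 0.96691332]],
              [[-0.41372983, 0.91039971], [0.91039971, 0.41372983]],
              [[0.84228545, 0.53903174], [-0.53903174, 0.84228545]]])

order = A.argsort(1)           # shape (6, 2)
idx = order[:, :, None]        # shape (6, 2, 1): same ndim as B
print(order.shape, idx.shape)  # (6, 2) (6, 2, 1)

# Along axis=1: B_sorted[i, j, k] = B[i, idx[i, j, 0], k];
# idx's length-1 last axis broadcasts over B's last axis (size 2)
B_sorted = np.take_along_axis(B, idx, axis=1)

# Same result as reordering each 2x2 block with a loop
loop = np.stack([B[i][order[i]] for i in range(len(B))])
assert np.array_equal(B_sorted, loop)
```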
Output:
array([[[-0.35109262, 0.93634073],
[ 0.93634073, 0.35109262]],
[[-0.63561769, 0.77200398],
[ 0.77200398, 0.63561769]],
[[-0.55298155, 0.8331935 ],
[ 0.8331935 , 0.55298155]],
[[-0.25510513, 0.96691332],
[ 0.96691332, 0.25510513]],
[[-0.41372983, 0.91039971],
[ 0.91039971, 0.41372983]],
[[-0.53903174, 0.84228545],
[ 0.84228545, 0.53903174]]])
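The same call generalizes to any number of leading axes by counting axes from the end: argsort along axis=-1 and take along axis=-2 of B. Below is a sketch assuming B's shape is A's shape plus one trailing axis; `sort_with_partner` is a hypothetical helper name, not a NumPy function. Like the other approaches, take_along_axis returns a new array, so the result is a sorted copy rather than a view.

```python
import numpy as np

def sort_with_partner(key, arr):
    """Hypothetical helper: sort `key` along its last axis and reorder
    `arr`'s second-to-last axis with the same permutation.

    Assumes arr.shape == key.shape + (k,) for some trailing size k.
    """
    order = np.argsort(key, axis=-1)
    key_sorted = np.take_along_axis(key, order, axis=-1)
    # Add a trailing length-1 axis so the indices have arr's ndim;
    # it broadcasts over arr's last axis
    arr_sorted = np.take_along_axis(arr, order[..., None], axis=-2)
    return key_sorted, arr_sorted

# Usage on random data with two leading axes instead of one
rng = np.random.default_rng(0)
key = rng.random((3, 4, 5))
arr = rng.random((3, 4, 5, 2))
key_sorted, arr_sorted = sort_with_partner(key, arr)

# Check against a straightforward loop over the leading axes
for i in range(3):
    for j in range(4):
        o = np.argsort(key[i, j])
        assert np.array_equal(key_sorted[i, j], key[i, j][o])
        assert np.array_equal(arr_sorted[i, j], arr[i, j][o])
```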