Python numpy sort array according to another array and broadcast over an axis

I have done this a million times before: sorting one array according to another. This time it is slightly more complicated, though, and I am stumped on how to do it. Let me explain. I have two arrays, say A:

[[1.59956565 1.16421459]
 [1.21548342 1.63884363]
 [0.73023302 0.54681896]
 [2.02628432 1.32127994]
 [0.2132793  0.26559821]
 [0.38242608 0.30073228]]

and B:

[[[ 0.93634073  0.35109262]
  [-0.35109262  0.93634073]]

 [[-0.63561769  0.77200398]
  [ 0.77200398  0.63561769]]

 [[ 0.8331935   0.55298155]
  [-0.55298155  0.8331935 ]]

 [[ 0.96691332  0.25510513]
  [-0.25510513  0.96691332]]

 [[-0.41372983  0.91039971]
  [ 0.91039971  0.41372983]]

 [[ 0.84228545  0.53903174]
  [-0.53903174  0.84228545]]]

That is, B has one more dimension than A (A has shape (6, 2), B has shape (6, 2, 2)). I want to first sort A along the last axis:

[[1.16421459 1.59956565]
 [1.21548342 1.63884363]
 [0.54681896 0.73023302]
 [1.32127994 2.02628432]
 [0.2132793  0.26559821]
 [0.30073228 0.38242608]]

and then reorder the middle axis of B with the same per-row permutation, i.e. B should become:

[[[-0.35109262  0.93634073]
  [ 0.93634073  0.35109262]]

 [[-0.63561769  0.77200398]
  [ 0.77200398  0.63561769]]

 [[-0.55298155  0.8331935 ]
  [ 0.8331935   0.55298155]]

 [[-0.25510513  0.96691332]
  [ 0.96691332  0.25510513]]

 [[-0.41372983  0.91039971]
  [ 0.91039971  0.41372983]]

 [[-0.53903174  0.84228545]
  [ 0.84228545  0.53903174]]]

How can I do this with a view or a slice, using the argsort of A? I tried, but got nowhere because B has one more axis than A.
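For reference, here is a plain Python loop that spells out the intended reordering: take the argsort of each row of A, then use it to reorder the rows of the matching 2x2 matrix in B. It is a sketch for checking a vectorized answer against, not a fast solution; the names order, A_sorted and B_loop are just illustrative.

```python
import numpy as np

A = np.array([[1.59956565, 1.16421459],
              [1.21548342, 1.63884363],
              [0.73023302, 0.54681896],
              [2.02628432, 1.32127994],
              [0.2132793,  0.26559821],
              [0.38242608, 0.30073228]])
B = np.array([[[ 0.93634073,  0.35109262], [-0.35109262,  0.93634073]],
              [[-0.63561769,  0.77200398], [ 0.77200398,  0.63561769]],
              [[ 0.8331935,   0.55298155], [-0.55298155,  0.8331935 ]],
              [[ 0.96691332,  0.25510513], [-0.25510513,  0.96691332]],
              [[-0.41372983,  0.91039971], [ 0.91039971,  0.41372983]],
              [[ 0.84228545,  0.53903174], [-0.53903174,  0.84228545]]])

# Per-row permutation that sorts A along its last axis, shape (6, 2)
order = A.argsort(axis=1)

# Applying that permutation to A itself reproduces np.sort(A, axis=1)
A_sorted = np.take_along_axis(A, order, axis=1)

# Loop reference: reorder the rows of each 2x2 matrix B[i] with order[i]
B_loop = np.empty_like(B)
for i in range(B.shape[0]):
    B_loop[i] = B[i][order[i]]

print(order.tolist())  # [[1, 0], [0, 1], [1, 0], [1, 0], [0, 1], [1, 0]]
```

Rows 1 and 4 of A are already sorted, so their permutation is [0, 1] and the corresponding matrices in B are left unchanged.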


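Here is one way, using np.take_along_axis. The argsort indices get a trailing axis so that they have the same number of dimensions as B and broadcast over its last axis: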

import numpy as np
np.take_along_axis(B, A.argsort(axis=1)[:, :, None], axis=1)

Output:

array([[[-0.35109262,  0.93634073],
        [ 0.93634073,  0.35109262]],

       [[-0.63561769,  0.77200398],
        [ 0.77200398,  0.63561769]],

       [[-0.55298155,  0.8331935 ],
        [ 0.8331935 ,  0.55298155]],

       [[-0.25510513,  0.96691332],
        [ 0.96691332,  0.25510513]],

       [[-0.41372983,  0.91039971],
        [ 0.91039971,  0.41372983]],

       [[-0.53903174,  0.84228545],
        [ 0.84228545,  0.53903174]]])
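The same reordering can also be written with ordinary advanced indexing, pairing each row index i with order[i, j]. The sketch below checks that both forms agree and also addresses the "view" part of the question: each row is permuted differently, which cannot be expressed with fixed strides, so both versions return a new array (a copy) rather than a view of B. The names order, rows, B1 and B2 are only illustrative.

```python
import numpy as np

A = np.array([[1.59956565, 1.16421459],
              [1.21548342, 1.63884363],
              [0.73023302, 0.54681896],
              [2.02628432, 1.32127994],
              [0.2132793,  0.26559821],
              [0.38242608, 0.30073228]])
B = np.array([[[ 0.93634073,  0.35109262], [-0.35109262,  0.93634073]],
              [[-0.63561769,  0.77200398], [ 0.77200398,  0.63561769]],
              [[ 0.8331935,   0.55298155], [-0.55298155,  0.8331935 ]],
              [[ 0.96691332,  0.25510513], [-0.25510513,  0.96691332]],
              [[-0.41372983,  0.91039971], [ 0.91039971,  0.41372983]],
              [[ 0.84228545,  0.53903174], [-0.53903174,  0.84228545]]])

order = A.argsort(axis=1)                 # shape (6, 2)

# take_along_axis: indices of shape (6, 2, 1) broadcast over B's last axis
B1 = np.take_along_axis(B, order[:, :, None], axis=1)

# Advanced indexing: rows has shape (6, 1) and broadcasts against order (6, 2),
# so B2[i, j, :] = B[i, order[i, j], :]
rows = np.arange(B.shape[0])[:, None]
B2 = B[rows, order]

assert np.array_equal(B1, B2)

# Neither result shares memory with B: both are copies
print(np.shares_memory(B1, B), np.shares_memory(B2, B))  # False False
```

If B itself should hold the reordered data, assign the result back, for example with B[...] = B1.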