Sort invariant for numpy.argsort with multiple dimensions

Solution 1:

NumPy issue #8708 includes a sample implementation of take_along_axis that does what I need. I have not checked how efficient it is for large arrays, but it gives the right result.

import numpy as np


def take_along_axis(arr, ind, axis):
    """
    ... here means a "pack" of dimensions, possibly empty

    arr: array_like of shape (A..., M, B...)
        source array
    ind: array_like of shape (A..., K..., B...)
        indices to take along each 1d slice of `arr`
    axis: int
        index of the axis with dimension M

    out: array_like of shape (A..., K..., B...)
        out[a..., k..., b...] = arr[a..., ind[a..., k..., b...], b...]
    """
    if axis < 0:
        if axis >= -arr.ndim:
            axis += arr.ndim
        else:
            raise IndexError('axis out of range')
    ind_shape = (1,) * ind.ndim
    ins_ndim = ind.ndim - (arr.ndim - 1)  # number of inserted dimensions (K...)

    dest_dims = list(range(axis)) + [None] + list(range(axis+ins_ndim, ind.ndim))

    # could also call np.ix_ here with some dummy arguments, then throw those results away
    inds = []
    for dim, n in zip(dest_dims, arr.shape):
        if dim is None:
            inds.append(ind)
        else:
            ind_shape_dim = ind_shape[:dim] + (-1,) + ind_shape[dim+1:]
            inds.append(np.arange(n).reshape(ind_shape_dim))

    return arr[tuple(inds)]

which yields

>>> A = np.array([[3,2,1],[4,0,6]])
>>> B = np.array([[3,1,4],[1,5,9]])
>>> i = A.argsort(axis=-1)
>>> take_along_axis(A,i,axis=-1)
array([[1, 2, 3],
       [0, 4, 6]])
>>> take_along_axis(B,i,axis=-1)
array([[4, 1, 3],
       [5, 1, 9]])
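For comparison, NumPy 1.15 and later ship a built-in np.take_along_axis. The sketch below assumes such a version is installed and reruns the same A, B and i through it. Unlike the sample implementation above, the built-in expects the index array to have the same number of dimensions as the source array.

```python
import numpy as np

A = np.array([[3, 2, 1], [4, 0, 6]])
B = np.array([[3, 1, 4], [1, 5, 9]])

# Sort order of each row of A
i = A.argsort(axis=-1)

# Apply that order to A itself, then reuse it to reorder B the same way
sorted_A = np.take_along_axis(A, i, axis=-1)
reordered_B = np.take_along_axis(B, i, axis=-1)

print(sorted_A)
print(reordered_B)
```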

Solution 2:

For a (3,2) array A, this argsort produces a (3,2) index array:

In [453]: idx=np.argsort(A,axis=-1)
In [454]: idx
Out[454]: 
array([[0, 1],
       [1, 0],
       [0, 1]], dtype=int32)

As you note, applying this to A to get the equivalent of np.sort(A, axis=-1) isn't obvious. The iterative solution is to sort each row (a 1d case) with:

In [459]: np.array([x[i] for i,x in zip(idx,A)])
Out[459]: 
array([[-1.0856306 ,  0.99734545],
       [-1.50629471,  0.2829785 ],
       [-0.57860025,  1.65143654]])

While probably not the fastest, it is the clearest solution, and a good starting point for working out a better one.

The tuple(inds) built by the take_along_axis solution is:

(array([[0],
        [1],
        [2]]), 
 array([[0, 1],
        [1, 0],
        [0, 1]], dtype=int32))
In [470]: A[_]
Out[470]: 
array([[-1.0856306 ,  0.99734545],
       [-1.50629471,  0.2829785 ],
       [-0.57860025,  1.65143654]])

In other words:

In [472]: A[np.arange(3)[:,None], idx]
Out[472]: 
array([[-1.0856306 ,  0.99734545],
       [-1.50629471,  0.2829785 ],
       [-0.57860025,  1.65143654]])

The first index, np.arange(3)[:,None], is what np.ix_ would construct, but np.ix_ accepts only 1d sequences and so rejects the 2d idx.
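The steps above can be checked in one self-contained script. This is a minimal sketch: the random (3,2) array and the seed are stand-ins for the A used in the session, so the values will differ from the output shown.

```python
import numpy as np

rng = np.random.default_rng(0)      # arbitrary seed, only for repeatability
A = rng.standard_normal((3, 2))     # stand-in for the (3,2) array A
idx = np.argsort(A, axis=-1)

# Row-by-row version: each row x is reordered by its own index row i
by_loop = np.array([x[i] for i, x in zip(idx, A)])

# Vectorized version: a (3,1) column of row numbers broadcasts against the (3,2) idx
rows = np.arange(A.shape[0])[:, None]
by_index = A[rows, idx]

expected = np.sort(A, axis=-1)

# np.ix_ builds the same kind of column index, but only from 1d inputs
try:
    np.ix_(np.arange(3), idx)
    ix_accepts_2d = True
except ValueError:
    ix_accepts_2d = False
```

Both by_loop and by_index should equal expected, and ix_accepts_2d should come out False.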


Looks like I explored this topic a couple of years ago, in "argsort for a multidimensional ndarray":

a[np.arange(np.shape(a)[0])[:,np.newaxis], np.argsort(a)]

There I tried to explain what is going on. The take_along_axis function does the same sort of thing, but constructs the indexing tuple for the general case (any number of dimensions and any axis). Generalizing to more dimensions, but still with axis=-1, should be easy.
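As a rough illustration of that generalization, here is a 3d case with axis=-1. The shape (2,3,4) and the seed are arbitrary choices; each leading axis gets its own arange, reshaped so that it broadcasts against the index array.

```python
import numpy as np

rng = np.random.default_rng(1)          # arbitrary seed
a = rng.standard_normal((2, 3, 4))      # arbitrary 3d example
idx = np.argsort(a, axis=-1)            # shape (2, 3, 4)

# One "open" index per leading axis, shaped to broadcast against idx
i0 = np.arange(a.shape[0])[:, None, None]   # shape (2, 1, 1)
i1 = np.arange(a.shape[1])[None, :, None]   # shape (1, 3, 1)

# sorted_last[p, q, r] = a[p, q, idx[p, q, r]]
sorted_last = a[i0, i1, idx]
```

The result should match np.sort(a, axis=-1).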

For the first axis, A[np.argsort(A,axis=0),np.arange(2)] works.
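A quick check of the first-axis form, again on a stand-in random (3,2) array. Here the sort indices go in the first position and a plain 1d arange over the columns broadcasts against them.

```python
import numpy as np

rng = np.random.default_rng(2)      # arbitrary seed
A = rng.standard_normal((3, 2))     # stand-in for the (3,2) array A

idx0 = np.argsort(A, axis=0)        # shape (3, 2): sort order within each column
cols = np.arange(A.shape[1])        # shape (2,), broadcasts against (3, 2)

# sorted_first[i, j] = A[idx0[i, j], j]
sorted_first = A[idx0, cols]
```

The result should match np.sort(A, axis=0).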