Faster alternative to numpy.where?

I have a 3D array filled with integers from 0 to N. I need a list of the indices where the array is equal to 1, 2, 3, ... N. I can do it with np.where as follows:

import numpy as np

N = 300
shape = (1000,1000,10)
data = np.random.randint(0,N+1,shape)
indx = [np.where(data == i_id) for i_id in range(1,data.max()+1)]

but this is quite slow. According to the question "fast python numpy where functionality?", it should be possible to speed up the index search quite a lot, but I haven't been able to adapt the methods proposed there to my problem of getting the actual indices. What would be the best way to speed up the above code?

As an add-on: I want to store the indices later, for which it makes sense to use np.ravel_multi_index to reduce the size from saving 3 indices to only 1, i.e. using:

indx = [np.ravel_multi_index(np.where(data == i_id), data.shape) for i_id in range(1, data.max()+1)]

which is closer to e.g. Matlab's find function. Can this be directly incorporated in a solution that doesn't use np.where?


I think that a standard vectorized approach to this problem would end up being very memory intensive – for int64 data, it would require O(8 * N * data.size) bytes, or ~22 gigs of memory for the example you gave above. I'm assuming that is not an option.
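As a quick check of that estimate, the arithmetic can be reproduced directly. This is only a back-of-the-envelope sketch: it assumes 8 bytes per element (int64) and one array of data.size elements per ID, as stated above.

```python
# Back-of-the-envelope memory estimate (assumes int64, i.e. 8 bytes per element)
N = 300
shape = (1000, 1000, 10)

data_size = shape[0] * shape[1] * shape[2]   # number of elements in data
n_bytes = 8 * N * data_size                  # one data.size-sized int64 array per ID
n_gib = n_bytes / 2**30                      # convert bytes to GiB

print(f"{n_bytes:,} bytes, about {n_gib:.1f} GiB")
```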

You might make some progress by using a sparse matrix to store the locations of the unique values. For example:

import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_indices_sparse(data):
    M = compute_M(data)
    return [np.unravel_index(row.data, data.shape) for row in M]

This takes advantage of fast code within the sparse matrix constructor to organize the data in a useful way, constructing a sparse matrix where row i contains just the indices where the flattened data equals i.
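To make the layout concrete, here is a minimal sketch on a tiny array, built the same way as compute_M. The names data_tiny, M_tiny, dense and row2_cols are made up for this illustration.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Tiny 2x2 example; flattened it reads [0, 2, 1, 2]
data_tiny = np.array([[0, 2],
                      [1, 2]])

cols = np.arange(data_tiny.size)              # flat indices 0..3
M_tiny = csr_matrix((cols, (data_tiny.ravel(), cols)),
                    shape=(data_tiny.max() + 1, data_tiny.size))

# Row i holds, at column j, the flat index j wherever data_tiny.ravel()[j] == i.
# The entry for flat index 0 has the value 0, so it does not stand out here.
dense = M_tiny.toarray()
print(dense)

# Column indices of the stored entries in row 2: the flat positions of the 2s
row2_cols = M_tiny.indices[M_tiny.indptr[2]:M_tiny.indptr[3]]
print(row2_cols)
```

Reading the column indices via indices and indptr avoids relying on the stored values themselves, which matters for flat index 0.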

To test it out, I'll also define a function that does your straightforward method:

def get_indices_simple(data):
    return [np.where(data == i) for i in range(0, data.max() + 1)]

The two functions give the same results for the same input:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_sparse(data_small)))
# True

And the sparse method is an order of magnitude faster than the simple method for your dataset:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.1 s, sys: 638 ms, total: 14.7 s
# Wall time: 14.8 s

%time ind = get_indices_sparse(data)
# CPU times: user 881 ms, sys: 301 ms, total: 1.18 s
# Wall time: 1.18 s

%time M = compute_M(data)
# CPU times: user 216 ms, sys: 148 ms, total: 365 ms
# Wall time: 363 ms

The other benefit of the sparse method is that the matrix M ends up being a very compact and efficient way to store all the relevant information for later use, as mentioned in the add-on part of your question. Hope that's useful!
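For the add-on about storing a single flat index per element, the CSR structure already holds what np.ravel_multi_index would produce. The following is a rough sketch, not part of the timed code above: get_flat_indices_sparse and data_demo are hypothetical names, and it assumes the column indices within each row come out sorted, which is what the construction from ascending cols should give.

```python
import numpy as np
from scipy.sparse import csr_matrix

def get_flat_indices_sparse(data):
    """Return a list whose i-th entry holds the flat indices where data == i."""
    cols = np.arange(data.size)
    M = csr_matrix((cols, (data.ravel(), cols)),
                   shape=(data.max() + 1, data.size))
    # M.indptr[i]:M.indptr[i+1] delimits row i, and M.indices holds the
    # column (= flat) index of each stored entry, so one split covers all rows.
    return np.split(M.indices, M.indptr[1:-1])

rng = np.random.default_rng(0)
data_demo = rng.integers(0, 6, size=(4, 5, 3))
flat = get_flat_indices_sparse(data_demo)

# Flat indices convert back to 3-D indices when needed
v = int(data_demo.max())
i, j, k = np.unravel_index(flat[v], data_demo.shape)
```

Entry 0 of the returned list corresponds to the value 0, so drop it with flat[1:] if only IDs 1 to N are wanted.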


Edit: I realized there was a bug in the initial version: it failed if any values in the range didn't appear in the data. That's now fixed above.


I was mulling this over and realized that there's a more intuitive (but slightly slower) way to solve this using Pandas groupby(). Consider this:

import numpy as np
import pandas as pd

def get_indices_pandas(data):
    d = data.ravel()
    f = lambda x: np.unravel_index(x.index, data.shape)
    return pd.Series(d).groupby(d).apply(f)

This returns the same result as get_indices_simple from my previous answer:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_pandas(data_small)))
# True

And this Pandas approach is just slightly slower than the less intuitive matrix approach:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.2 s, sys: 665 ms, total: 14.8 s
# Wall time: 14.9 s

%time ind = get_indices_sparse(data)
# CPU times: user 842 ms, sys: 277 ms, total: 1.12 s
# Wall time: 1.12 s

%time ind = get_indices_pandas(data)
# CPU times: user 1.16 s, sys: 326 ms, total: 1.49 s
# Wall time: 1.49 s

Here's one vectorized approach:

# Mask of matches for data elements against all IDs from 1 to data.max()
mask = data == np.arange(1,data.max()+1)[:,None,None,None]

# Indices of matches across all IDs and their linear indices
idx = np.argwhere(mask.reshape(mask.shape[0],-1))

# Get cut indices where IDs shift
_,cut_idx = np.unique(idx[:,0],return_index=True)

# Cut at shifts to give us the final indx output
out = np.hsplit(idx[:,1],cut_idx[1:])
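
Note that the np.unique-based cut points assume every ID from 1 to data.max() occurs at least once; if an ID is missing, the pieces no longer line up with the IDs. The sketch below is a variation on the same mask idea, not the code above: it computes the cut points with np.searchsorted so that missing IDs yield empty arrays. The function name get_flat_indices_mask and the demo array are made up for illustration, and the memory cost of the mask is unchanged.

```python
import numpy as np

def get_flat_indices_mask(data):
    """Flat indices for IDs 1..data.max(); assumes data.max() >= 1."""
    n_ids = int(data.max())
    ids = np.arange(1, n_ids + 1)
    # Boolean mask of shape (n_ids, data.size): memory grows with n_ids * data.size
    mask = data.ravel() == ids[:, None]
    # Each row of idx is (ID row, flat index), sorted by ID row first
    idx = np.argwhere(mask)
    # Position where each ID row 1..n_ids-1 starts (or would start if absent)
    cut_idx = np.searchsorted(idx[:, 0], np.arange(1, n_ids))
    return np.split(idx[:, 1], cut_idx)

# ID 2 never occurs here, so its entry should come back empty
data_demo = np.array([[0, 1, 3],
                      [3, 1, 0]])
out = get_flat_indices_mask(data_demo)
```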