Efficiently count zero elements in numpy array?

Solution 1:

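A faster approach is to call np.count_nonzero() on the boolean mask for the condition you need, here arr == 0. In the timings below it is about 1.6x faster than func1 and about 2x faster than np.where().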

In [3]: arr
Out[3]: 
array([[1, 2, 0, 3],
      [3, 9, 0, 4]])

In [4]: np.count_nonzero(arr==0)
Out[4]: 2

In [5]: def func_cnt():
   ...:     # arrs is a list of NumPy arrays
   ...:     for arr in arrs:
   ...:         zero_els = np.count_nonzero(arr == 0)
   ...:         # zero_els now holds the number of zeros in arr
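
As a self-contained sketch of the same idea (the variable names here are illustrative, not from the original session): counting zeros through a boolean mask, the equivalent arr.size - np.count_nonzero(arr) form, which avoids allocating the temporary mask, and per-row counts through the axis argument.

```python
import numpy as np

arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])

# Count zeros by counting the True entries of the boolean mask.
zeros_mask = np.count_nonzero(arr == 0)

# Same result without building the mask: total elements minus nonzero ones.
zeros_size = arr.size - np.count_nonzero(arr)

# Per-row zero counts (the axis argument needs NumPy 1.12 or later).
zeros_per_row = np.count_nonzero(arr == 0, axis=1)

print(zeros_mask, zeros_size, zeros_per_row)
```

Whether the mask-free form is actually faster probably depends on the dtype and array size, so it is worth timing on your own data.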

You can also use np.where(), but it's slower than np.count_nonzero().

In [6]: np.where( arr == 0)
Out[6]: (array([0, 1]), array([2, 2]))

In [7]: len(np.where(arr == 0)[0])   # length of one index array, not of the tuple
Out[7]: 2
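
One thing to watch with np.where(): it returns a tuple with one index array per dimension, so the length of the tuple is the number of dimensions, while the number of matches is the length of any one of those arrays. A small sketch with a hypothetical array (not the one above) where the two differ:

```python
import numpy as np

# Hypothetical 2-D array containing three zeros.
a = np.array([[0, 1, 0],
              [2, 0, 3]])

idx = np.where(a == 0)      # tuple of (row indices, column indices)

n_dims = len(idx)           # one index array per dimension
n_zeros = len(idx[0])       # one entry per matching element

print(n_dims, n_zeros)      # prints: 2 3
```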

Efficiency (fastest first):

In [8]: %timeit func_cnt()
10 loops, best of 3: 29.2 ms per loop

In [9]: %timeit func1()
10 loops, best of 3: 46.5 ms per loop

In [10]: %timeit func_where()
10 loops, best of 3: 61.2 ms per loop
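
To reproduce a comparison like this outside IPython, something along these lines could be used. It is only a sketch: the input setup mirrors the 1000 arrays of 10000 random integers used in the JAX example below, func_where is my guess at what the np.where() variant looked like, and func1 is not shown here, so it is left out. Absolute timings will depend on the machine and NumPy version.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
# 1000 arrays of 10000 integers drawn from [-5, 5)
arrs = [rng.integers(-5, 5, 10000) for _ in range(1000)]

def func_cnt():
    # Count zeros via a boolean mask.
    return [np.count_nonzero(arr == 0) for arr in arrs]

def func_where():
    # Assumed np.where() variant: count the indices of matching elements.
    return [len(np.where(arr == 0)[0]) for arr in arrs]

for f in (func_cnt, func_where):
    # Best of 3 repeats, 5 calls per repeat, reported per call.
    best = min(timeit.repeat(f, number=5, repeat=3)) / 5
    print(f"{f.__name__}: {best * 1e3:.2f} ms per call")
```

Both functions return the counts, so the work cannot be skipped and the results can be checked against each other.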

More speedups with accelerators

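With access to an accelerator (GPU/TPU), JAX can push this much further: in the benchmark below, the jitted function runs in 15.6 µs versus 29.2 ms for the NumPy version, a speedup of more than three orders of magnitude. Another advantage of JAX is that NumPy code needs very little modification to become JAX-compatible. One caveat: the jitted function below returns nothing and closes over arrs, so the compiler is free to discard the counting work, and the timing likely reflects mostly call overhead. For a like-for-like comparison, return the counts and block on the result (for example with jax.block_until_ready). Below is a reproducible example: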

In [1]: import numpy as np
   ...: import jax.numpy as jnp
In [2]: from jax import jit

# set up inputs
In [3]: arrs = []
In [4]: for _ in range(1000):
   ...:     arrs.append(np.random.randint(-5, 5, 10000))

# JIT'd function that performs the counting task
In [5]: @jit
   ...: def func_cnt():
   ...:     for arr in arrs:
   ...:         zero_els = jnp.count_nonzero(arr==0)

# efficiency test
In [8]: %timeit func_cnt()
15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)