Comparing numpy arrays containing NaN

For my unittest, I want to check if two arrays are identical. Reduced example:

import numpy as np

a = np.array([1, 2, np.nan])
b = np.array([1, 2, np.nan])

if np.all(a == b):
    print('arrays are equal')

This does not work because nan != nan. What is the best way to proceed?


For versions of numpy prior to 1.19, this is probably the best approach in situations that don't specifically involve unit tests:

>>> ((a == b) | (np.isnan(a) & np.isnan(b))).all()
True
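The expression above can be wrapped in a small helper. This is only a sketch: the function name arrays_equal_with_nan is hypothetical, and it assumes both inputs can be converted to float arrays.

```python
import numpy as np

def arrays_equal_with_nan(a, b):
    """Return True if a and b match elementwise, treating NaN as equal to NaN."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Arrays of different shape are never considered equal here.
    if a.shape != b.shape:
        return False
    # A position matches if the values are equal or if both are NaN.
    return bool(((a == b) | (np.isnan(a) & np.isnan(b))).all())
```

The explicit shape check avoids silently broadcasting arrays of different shapes against each other.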

However, NumPy 1.19 and later provide numpy.array_equal with a keyword argument, equal_nan, which fits the bill exactly.
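A minimal illustration of the equal_nan argument, assuming NumPy 1.19 or later is installed (the variable names are only for this example):

```python
import numpy as np

a = np.array([1, 2, np.nan])
b = np.array([1, 2, np.nan])

# Default behaviour: NaN positions count as mismatches.
plain = np.array_equal(a, b)

# With equal_nan=True, NaNs in the same positions count as matches.
with_nan = np.array_equal(a, b, equal_nan=True)

print(plain, with_nan)  # False True
```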

This was first pointed out by flyingdutchman; see his answer below for details.


Alternatively you can use numpy.testing.assert_equal or numpy.testing.assert_array_equal with a try/except:

In : import numpy as np

In : def nan_equal(a,b):
...:     try:
...:         np.testing.assert_equal(a,b)
...:     except AssertionError:
...:         return False
...:     return True

In : a=np.array([1, 2, np.nan])

In : b=np.array([1, 2, np.nan])

In : nan_equal(a,b)
Out: True

In : a=np.array([1, 2, np.nan])

In : b=np.array([3, 2, np.nan])

In : nan_equal(a,b)
Out: False

Edit

Since you are using this for unit testing, calling the assert function directly (instead of wrapping it to return True/False) might be more natural.
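For instance, a test case might call the assertion directly. This is a sketch only; the class and method names are made up for illustration.

```python
import unittest
import numpy as np

class NanArrayTest(unittest.TestCase):
    def test_arrays_match(self):
        a = np.array([1, 2, np.nan])
        b = np.array([1, 2, np.nan])
        # Raises AssertionError on mismatch; NaNs in the same
        # positions are treated as equal by this function.
        np.testing.assert_array_equal(a, b)
```

When the arrays differ, the AssertionError message reports the mismatch, which is usually more helpful in a test log than a bare False. The test can be run with python -m unittest.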


The easiest way is to use the numpy.allclose() function, which lets you specify how NaN values are handled. Your example then looks like this:

a = np.array([1, 2, np.nan])
b = np.array([1, 2, np.nan])

if np.allclose(a, b, equal_nan=True):
    print('arrays are equal')

This prints 'arrays are equal'.

See the numpy.allclose documentation for details on the equal_nan parameter.
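One thing to keep in mind is that allclose compares within a tolerance rather than exactly. The sketch below contrasts it with array_equal; it assumes NumPy 1.19 or later for the equal_nan argument of array_equal, and the array c is made up for the example.

```python
import numpy as np

a = np.array([1.0, 2.0, np.nan])
c = np.array([1.0, 2.0 + 1e-9, np.nan])  # differs slightly in one element

# allclose accepts small differences (default rtol=1e-05, atol=1e-08).
close = np.allclose(a, c, equal_nan=True)

# array_equal requires exact equality of the non-NaN values.
exact = np.array_equal(a, c, equal_nan=True)

print(close, exact)  # True False
```

If the test needs exact equality, array_equal or the numpy.testing assertions may be the better fit.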


You could use numpy masked arrays, mask the NaN values and then use numpy.ma.all or numpy.ma.allclose:

http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.all.html

http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.allclose.html

For example:

a = np.array([1, 2, np.nan])
b = np.array([1, 2, np.nan])
np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b))  # True
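Because masked positions are skipped in the comparison, it may be safer to check that both arrays are masked in the same places as well. The helper below is a sketch under that assumption; the name masked_nan_equal is hypothetical.

```python
import numpy as np

def masked_nan_equal(a, b):
    """Compare two arrays with invalid entries masked out."""
    # masked_invalid masks NaN and also +/-inf, so infinities at
    # matching positions are treated as equal by this sketch.
    am = np.ma.masked_invalid(a)
    bm = np.ma.masked_invalid(b)
    # Require identical shapes and identical mask positions first.
    if not np.array_equal(np.ma.getmaskarray(am), np.ma.getmaskarray(bm)):
        return False
    # Treat the masked positions as matching, then check the rest.
    return bool((am == bm).filled(True).all())
```

If infinities should be compared by value, np.ma.masked_where(np.isnan(a), a) could be used in place of masked_invalid.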