Numpy: get the first different value along axis

I have a 2D array, and for each row I would like the index of the first value that differs from that row's first value.

Here is an example:

import numpy as np

arr = np.array([[0, 0, 1, 1, 1],
                [2, 2, 2, 3, 3],
                [9, np.nan, np.nan, 8, 8],
                [5, 5, 5, 5, 0]])

The solution I'm looking for would yield:

array([2, 3, 1, 4])

since in the first row the first value different from 0 is the third one (index 2), in the second row the first value different from 2 is the fourth one (index 3), and so on.
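To pin down exactly what is being asked, here is a slow but explicit plain-Python sketch of the same definition. The helper name first_diff_loop is hypothetical, and returning None for a row with no differing value is an assumption (the question does not cover that case). Note that nan != x is always True, which is why the nan in the third row counts as different from 9.

```python
import numpy as np

arr = np.array([[0, 0, 1, 1, 1],
                [2, 2, 2, 3, 3],
                [9, np.nan, np.nan, 8, 8],
                [5, 5, 5, 5, 0]])

def first_diff_loop(a):
    """Hypothetical reference helper: for each row, the index of the first
    element that differs from that row's first element (None if none does)."""
    result = []
    for row in a:
        idx = None
        for j in range(1, len(row)):
            # nan != anything is True, so a nan counts as "different"
            if row[j] != row[0]:
                idx = j
                break
        result.append(idx)
    return result

print(first_diff_loop(arr))  # [2, 3, 1, 4]
```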

Thanks for your help!


Solution 1:

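Compare every value in the array with the first column (reshaped with [:, None] so that it broadcasts across each row), then take the index of the first True in each row. This works because argmax returns the index of the first occurrence of the maximum, and True is greater than False.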

(arr != arr[:,0][:, None]).argmax(1)

Out[]: array([2, 3, 1, 4], dtype=int64)
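One caveat: if every value in a row equals the first one, the comparison row is all False and argmax returns 0, which can be confused with a real result. The sketch below (the function name first_diff_index and the missing=-1 sentinel are hypothetical choices, not part of the answer above) skips column 0 and uses any() to flag rows with no differing value. Because nan never compares equal, a row that starts with nan reports index 1.

```python
import numpy as np

def first_diff_index(a, missing=-1):
    """Hypothetical helper: index of the first element in each row that differs
    from the row's first element, or `missing` if every element equals it."""
    a = np.asarray(a)
    if a.shape[1] < 2:
        # No columns to compare against; argmax on an empty axis would raise.
        return np.full(a.shape[0], missing)
    # Compare columns 1.. with column 0; a[:, :1] keeps a 2D shape for broadcasting.
    diff = a[:, 1:] != a[:, :1]
    # First True in each row; +1 compensates for skipping column 0.
    idx = diff.argmax(axis=1) + 1
    # Rows without any True would wrongly report 1, so replace them.
    return np.where(diff.any(axis=1), idx, missing)

arr = np.array([[0, 0, 1, 1, 1],
                [2, 2, 2, 3, 3],
                [9, np.nan, np.nan, 8, 8],
                [5, 5, 5, 5, 0],
                [7, 7, 7, 7, 7]])
print(first_diff_index(arr).tolist())  # [2, 3, 1, 4, -1]
```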