Can someone explain exactly what the axis parameter in NumPy does?

I am terribly confused.

I'm trying to use the method myArray.sum(axis=num).

At first I thought that if the array is 3-dimensional, axis=0 would return three elements, each being the sum of all nested items in that same position. If each dimension contained five items, I expected axis=1 to return five items, and so on.

However, this is not the case, and the documentation does not do a good job of helping me out (they use a 3x3x3 array, so it's hard to tell what's happening).

Here's what I did:

>>> import numpy as np
>>> e = np.array([[[1, 0], [0, 0]], [[1, 1], [1, 0]], [[1, 0], [0, 1]]])
>>> e
array([[[1, 0],
        [0, 0]],

       [[1, 1],
        [1, 0]],

       [[1, 0],
        [0, 1]]])
>>> e.sum(axis=0)
array([[3, 1],
       [1, 1]])
>>> e.sum(axis=1)
array([[1, 0],
       [2, 1],
       [1, 1]])
>>> e.sum(axis=2)
array([[1, 0],
       [2, 1],
       [1, 1]])
>>>

Clearly the result is not intuitive.


Clearly,

e.shape == (3, 2, 2)

Summing over an axis is a reduction operation, so the specified axis disappears. Hence,

e.sum(axis=0).shape == (2, 2)
e.sum(axis=1).shape == (3, 2)
e.sum(axis=2).shape == (3, 2)
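A minimal sketch, assuming NumPy is installed, that rebuilds e from the question and checks this shape rule for every axis. It also shows two standard options of ndarray.sum: keepdims=True, which keeps the reduced axis with length 1 instead of dropping it, and a tuple of axes. Summing over axes 1 and 2 together leaves only axis 0, which gives the three per-block totals the question originally expected from axis=0.

```python
import numpy as np

# The array from the question: shape (3, 2, 2)
e = np.array([[[1, 0], [0, 0]],
              [[1, 1], [1, 0]],
              [[1, 0], [0, 1]]])

for axis in range(e.ndim):
    reduced = e.sum(axis=axis)
    # The reduced axis is dropped from the shape tuple
    expected = e.shape[:axis] + e.shape[axis + 1:]
    print(axis, reduced.shape, expected)
    assert reduced.shape == expected

# keepdims=True keeps the reduced axis, with length 1
print(e.sum(axis=1, keepdims=True).shape)  # (3, 1, 2)

# Reducing axes 1 and 2 at once leaves only axis 0:
# one total per top-level 2x2 block, i.e. three elements
print(e.sum(axis=(1, 2)))  # [1 3 2]
```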

Intuitively, we are "squashing" the array along the chosen axis, and summing the numbers that get squashed together.
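The squashing picture can be made concrete: sum(axis=k) adds together the slices of the array taken along axis k. The helper squash_sum below is hypothetical, written only for illustration; it pulls out each slice with np.take and is checked against e.sum. The second part explains why e.sum(axis=1) and e.sum(axis=2) print the same values in the question: every 2x2 block of e is symmetric, so adding its columns and adding its rows happen to give the same numbers. The array f is a made-up, non-symmetric example where the two axes differ.

```python
import numpy as np

e = np.array([[[1, 0], [0, 0]],
              [[1, 1], [1, 0]],
              [[1, 0], [0, 1]]])

def squash_sum(a, axis):
    """Hypothetical helper: add up the slices of `a` taken along `axis`."""
    # np.take with a scalar index returns the slice with `axis` removed
    total = np.zeros_like(np.take(a, 0, axis=axis))
    for i in range(a.shape[axis]):
        total = total + np.take(a, i, axis=axis)
    return total

for axis in range(e.ndim):
    assert np.array_equal(squash_sum(e, axis), e.sum(axis=axis))

# Each 2x2 block of e is symmetric, so axis=1 and axis=2 coincide for e.
# A non-symmetric block (shape (1, 2, 2)) shows the difference:
f = np.array([[[1, 2],
               [3, 4]]])
print(f.sum(axis=1))  # [[4 6]]  columns added: 1+3, 2+4
print(f.sum(axis=2))  # [[3 7]]  rows added: 1+2, 3+4
```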


To understand the axis intuitively, refer to the picture below (source: Physics Dept., Cornell University).


The shape of the (boolean) array in the above figure is (8, 3). ndarray.shape returns a tuple whose entries are the lengths of the corresponding dimensions. In our example, 8 is the length of axis 0, whereas 3 is the length of axis 1.
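A short sketch of the (8, 3) case, assuming a made-up boolean array since the figure's actual values are not reproduced here. Summing booleans counts the True values, which makes it easy to see which axis gets collapsed: reducing axis 0 leaves one count per column (3 values), and reducing axis 1 leaves one count per row (8 values).

```python
import numpy as np

# Hypothetical 8x3 boolean array (NOT the values from the figure):
# True wherever the entry of arange(24).reshape(8, 3) is even
mask = (np.arange(24).reshape(8, 3) % 2 == 0)

print(mask.shape)  # (8, 3): length 8 along axis 0, length 3 along axis 1
print(len(mask))   # 8: len() reports the length of axis 0

# axis=0 collapses the 8 rows, leaving one count per column
print(mask.sum(axis=0))  # [4 4 4]
# axis=1 collapses the 3 columns, leaving one count per row
print(mask.sum(axis=1))  # [2 1 2 1 2 1 2 1]
```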