How does the axis parameter from NumPy work?
Can someone explain exactly what the axis parameter in NumPy does? I am terribly confused.
I'm trying to use the function myArray.sum(axis=num).
At first I thought that if the array itself has 3 dimensions, axis=0 would return three elements, each consisting of the sum of all nested items in that same position. If each dimension contained five dimensions, I expected axis=1 to return a result of five items, and so on.
However, this is not the case, and the documentation does not do a good job of helping me out (it uses a 3x3x3 array, so it's hard to tell what's happening).
Here's what I did:
>>> e
array([[[1, 0],
        [0, 0]],
       [[1, 1],
        [1, 0]],
       [[1, 0],
        [0, 1]]])
>>> e.sum(axis=0)
array([[3, 1],
       [1, 1]])
>>> e.sum(axis=1)
array([[1, 0],
       [2, 1],
       [1, 1]])
>>> e.sum(axis=2)
array([[1, 0],
       [2, 1],
       [1, 1]])
Clearly the result is not intuitive.
Clearly,
e.shape == (3, 2, 2)
Sum over an axis is a reduction operation so the specified axis disappears. Hence,
e.sum(axis=0).shape == (2, 2)
e.sum(axis=1).shape == (3, 2)
e.sum(axis=2).shape == (3, 2)
Intuitively, we are "squashing" the array along the chosen axis, and summing the numbers that get squashed together.
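A minimal sketch of the "squashing" idea, using the array e from the question. Summing over an axis should give the same result as adding up the slices taken along that axis by hand. The names sum0, manual0, kept, and so on are only illustrative; keepdims=True is a standard argument of ndarray.sum that leaves the reduced axis in place with length 1, which makes it easier to see which axis went away.

```python
import numpy as np

# The array from the question, shape (3, 2, 2)
e = np.array([[[1, 0], [0, 0]],
              [[1, 1], [1, 0]],
              [[1, 0], [0, 1]]])

# Summing over axis 0 adds the three 2x2 blocks elementwise
sum0 = e.sum(axis=0)
manual0 = e[0] + e[1] + e[2]

# Summing over axis 1 adds the two rows inside each block
sum1 = e.sum(axis=1)
manual1 = e[:, 0, :] + e[:, 1, :]

# Summing over axis 2 adds the two entries inside each row
sum2 = e.sum(axis=2)
manual2 = e[:, :, 0] + e[:, :, 1]

# keepdims=True leaves the reduced axis in place with length 1
kept = e.sum(axis=1, keepdims=True)

print(sum0.shape, sum1.shape, sum2.shape)  # (2, 2) (3, 2) (3, 2)
print(kept.shape)                          # (3, 1, 2)
```

For this particular e, the axis=1 and axis=2 results happen to contain the same numbers, which is part of why the output in the question looks confusing; the two sums still combine different elements.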
To understand the axis intuitively, refer to the picture below (source: Physics Dept, Cornell Uni).
The shape of the (boolean) array in the above figure is shape=(8, 3). ndarray.shape returns a tuple whose entries are the lengths of the corresponding dimensions. In our example, 8 is the length of axis 0, whereas 3 is the length of axis 1.
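As a rough illustration, here is a boolean array with the same shape as the one described. The contents are made up (the values in the figure are not reproduced here), and the names a, col_counts, and row_counts are only placeholders. Reducing over axis 0 removes the length-8 axis and leaves 3 values; reducing over axis 1 removes the length-3 axis and leaves 8.

```python
import numpy as np

# Hypothetical stand-in for the array in the figure: 8 rows, 3 columns
a = np.zeros((8, 3), dtype=bool)
a[0, 1] = True
a[5, 1] = True
a[5, 2] = True

n_rows, n_cols = a.shape      # lengths of axis 0 and axis 1

col_counts = a.sum(axis=0)    # axis 0 (length 8) is reduced, 3 values remain
row_counts = a.sum(axis=1)    # axis 1 (length 3) is reduced, 8 values remain

print(a.shape)            # (8, 3)
print(col_counts)         # [0 2 1]
print(row_counts.shape)   # (8,)
```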