Understanding matplotlib.pyplot.subplots() in Python
Solution 1:
The different return types are due to the squeeze keyword argument of plt.subplots(), which is set to True by default.
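A minimal sketch of the default behaviour (squeeze=True). It assumes a recent matplotlib and the non-interactive Agg backend so it runs without a display; the variable names are only illustrative:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is opened
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# Default squeeze=True: the return type depends on the grid shape
fig1, ax = plt.subplots()            # 1x1 grid -> a single Axes object
fig2, axs_row = plt.subplots(1, 3)   # 1x3 grid -> 1D numpy array of Axes
fig3, axs_grid = plt.subplots(2, 3)  # 2x3 grid -> 2D numpy array of Axes

print(isinstance(ax, Axes), isinstance(ax, np.ndarray))  # True False
print(type(axs_row).__name__, axs_row.shape)             # ndarray (3,)
print(type(axs_grid).__name__, axs_grid.shape)           # ndarray (2, 3)
```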
Here is the documentation, extended with the corresponding unpacking for each case:
squeeze : bool, optional, default: True
If True, extra dimensions are squeezed out from the returned Axes object:
- if only one subplot is constructed (nrows=ncols=1), the resulting single Axes object is returned as a scalar.
fig, ax = plt.subplots()
- for Nx1 or 1xN subplots, the returned object is a 1D numpy object array of Axes objects.
fig, (ax1, ..., axN) = plt.subplots(nrows=N, ncols=1)  (for Nx1)
fig, (ax1, ..., axN) = plt.subplots(nrows=1, ncols=N)  (for 1xN)
- for NxM, subplots with N>1 and M>1 are returned as a 2D array.
fig, ((ax11, .., ax1M),..,(axN1, .., axNM)) = plt.subplots(nrows=N, ncols=M)
- If False, no squeezing at all is done: the returned Axes object is always a 2D array containing Axes instances, even if it ends up being 1x1.
fig, ((ax,),) = plt.subplots(nrows=1, ncols=1, squeeze=False)
fig, ((ax1,), .., (axN,)) = plt.subplots(nrows=N, ncols=1, squeeze=False)  (for Nx1)
fig, ((ax1, .., axN),) = plt.subplots(nrows=1, ncols=N, squeeze=False)  (for 1xN)
fig, ((ax11, .., ax1M), .., (axN1, .., axNM)) = plt.subplots(nrows=N, ncols=M, squeeze=False)  (for NxM)
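To check the squeeze=False cases, the sketch below (same assumptions: Agg backend, illustrative names, N = 3 picked arbitrarily) records the array shape for several grid sizes and then applies the unpacking patterns listed above:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# With squeeze=False the result is always 2D, with shape (nrows, ncols)
shapes = {}
for nrows, ncols in [(1, 1), (3, 1), (1, 3), (2, 3)]:
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    shapes[(nrows, ncols)] = axs.shape
    plt.close(fig)  # free the figure; only the shape is kept
print(shapes)  # {(1, 1): (1, 1), (3, 1): (3, 1), (1, 3): (1, 3), (2, 3): (2, 3)}

# The squeeze=False unpacking patterns, with N = 3
fig, ((ax,),) = plt.subplots(nrows=1, ncols=1, squeeze=False)
fig, ((ax1,), (ax2,), (ax3,)) = plt.subplots(nrows=3, ncols=1, squeeze=False)
fig, ((axa, axb, axc),) = plt.subplots(nrows=1, ncols=3, squeeze=False)
```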
Alternatively, you may always use the packed version
fig, ax_arr = plt.subplots(nrows=N, ncols=M, squeeze=False)
and index the array to obtain the axes, e.g. ax_arr[1,2].plot(..).
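A minimal sketch of why the packed form is convenient: with squeeze=False, code that indexes ax_arr[i, j] works unchanged for any grid size, including 1x1 and 1xN. The helper label_grid is hypothetical, written only for this illustration, and the plotted data are arbitrary:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def label_grid(nrows, ncols):
    """Hypothetical helper: title every subplot with its (row, col) index."""
    fig, ax_arr = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for i in range(nrows):      # row index, 0-based
        for j in range(ncols):  # column index, 0-based
            ax_arr[i, j].set_title(f"({i}, {j})")
    return fig, ax_arr

fig, ax_arr = label_grid(2, 3)
ax_arr[1, 2].plot([0, 1, 2], [0, 1, 4])  # bottom-right axes, arbitrary data

# The same 2D indexing still works for a single row or a single cell
fig_row, row_arr = label_grid(1, 3)
fig_one, one_arr = label_grid(1, 1)
```

Without squeeze=False, the 1x3 and 1x1 calls would return a 1D array and a bare Axes, and the ax_arr[i, j] indexing inside the helper would fail.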
So for a 2 x 3 grid it doesn't actually matter whether you set squeeze to False; the result is always a 2D array. You may unpack it as
fig, ((ax1, ax2, ax3),(ax4, ax5, ax6)) = plt.subplots(nrows=2, ncols=3)
to get ax1, ..., ax6 as the individual matplotlib Axes objects, or you may use the packed version
import matplotlib.pyplot as plt
fig, ax_arr = plt.subplots(nrows=2, ncols=3)
ax_arr[0, 0].plot([1, 2, 3])  # plot (sample data) to the first, top-left axes
ax_arr[1, 2].plot([3, 2, 1])  # plot (sample data) to the last, bottom-right axes
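To confirm that the two styles give the same objects, the sketch below (assuming the Agg backend; names are illustrative) unpacks the 2D array returned for a 2 x 3 grid and compares it with the indexed and flattened forms. Both tuple unpacking and ax_arr.flat walk the grid row by row:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax_arr = plt.subplots(nrows=2, ncols=3)
(ax1, ax2, ax3), (ax4, ax5, ax6) = ax_arr  # same as unpacking in the call

print(ax1 is ax_arr[0, 0])  # True: top-left
print(ax6 is ax_arr[1, 2])  # True: bottom-right

# ax_arr.flat yields the Axes in row-major order: ax1, ..., ax6
unpacked = [ax1, ax2, ax3, ax4, ax5, ax6]
print(all(a is b for a, b in zip(ax_arr.flat, unpacked)))  # True

# For a 2 x 3 grid, squeeze=False changes nothing
fig2, ax_arr2 = plt.subplots(nrows=2, ncols=3, squeeze=False)
print(ax_arr.shape == ax_arr2.shape)  # True
```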