Plotting images side by side using matplotlib

I was wondering how I can plot images side by side using matplotlib, for example something like this:

[image: desired layout]

The closest I got is this:

[image: current result]

This was produced by using this code:

f, axarr = plt.subplots(2,2)
axarr[0,0] = plt.imshow(image_datas[0])
axarr[0,1] = plt.imshow(image_datas[1])
axarr[1,0] = plt.imshow(image_datas[2])
axarr[1,1] = plt.imshow(image_datas[3])

But I can't seem to get the other images to show. I'm thinking there must be a better way to do this, as I imagine managing the indexes by hand would be a pain. I have looked through the documentation, although I have a feeling I may be looking at the wrong one. Would anyone be able to provide an example or point me in the right direction?

EDIT:

See the answer from @duhaime if you want a function to automatically determine the grid size.


Solution 1:

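The problem is that you assign the return value of imshow (a matplotlib.image.AxesImage) to an existing axes object. The assignment only overwrites the entry in axarr; plt.imshow itself always draws on the current axes, so the other subplots stay empty.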

The correct way of plotting image data to the different axes in axarr would be

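import matplotlib.pyplot as plt

# Call imshow on each axes object instead of assigning to it
f, axarr = plt.subplots(2,2)
axarr[0,0].imshow(image_datas[0])
axarr[0,1].imshow(image_datas[1])
axarr[1,0].imshow(image_datas[2])
axarr[1,1].imshow(image_datas[3])
plt.show()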

The concept is the same for all subplots, and in most cases an axes instance provides the same methods as the pyplot (plt) interface. For example, if ax is one of your subplot axes, you'd use ax.plot(..) instead of plt.plot() for a normal line plot. This can be found in the source of the page you link to.
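
To make the axes-method pattern concrete, here is a minimal self-contained sketch. It assumes random 8x8 arrays as stand-ins for image_datas and selects the non-interactive Agg backend so it runs without a display; neither is part of the original question. In an interactive session you would drop the backend line and call plt.show() at the end.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Assumption: four random 8x8 arrays stand in for the question's image_datas.
rng = np.random.default_rng(0)
image_datas = [rng.random((8, 8)) for _ in range(4)]

f, axarr = plt.subplots(2, 2)

# Call imshow on each axes object; axarr.flat walks the 2x2 grid row by row.
for ax, data in zip(axarr.flat, image_datas):
    ax.imshow(data)

# The same pattern applies to other plot types: ax.plot mirrors plt.plot.
f2, ax2 = plt.subplots()
ax2.plot([0, 1, 2], [0, 1, 4])

# Each subplot now holds exactly one image.
images_per_axes = [len(ax.images) for ax in axarr.flat]
```

Iterating over axarr.flat also avoids writing out each index pair by hand, which addresses the concern about managing indexes.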

Solution 2:

One approach I found quite helpful for showing all images in a loop:

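import matplotlib.pyplot as plt

# imgs is a list of images; n_row * n_col should be at least len(imgs)
_, axs = plt.subplots(n_row, n_col, figsize=(12, 12))
axs = axs.flatten()  # turn the 2D grid of axes into a flat array
for img, ax in zip(imgs, axs):
    ax.imshow(img)
plt.show()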
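
The loop above can be wrapped in a small helper that derives the number of rows from the number of images. The sketch below is one possible way to do this, not the function from the answer mentioned in the question's edit. The name plot_image_grid, its n_col default, and the random test images are all hypothetical, and the Agg backend is an assumption so the code runs without a display.

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this for interactive use
import matplotlib.pyplot as plt
import numpy as np


def plot_image_grid(imgs, n_col=3):
    """Hypothetical helper: show imgs in a grid with n_col columns."""
    # Enough rows to hold every image; at least one so subplots() gets a valid grid.
    n_row = max(1, math.ceil(len(imgs) / n_col))
    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so flatten() always works.
    fig, axs = plt.subplots(n_row, n_col, figsize=(12, 12), squeeze=False)
    axs = axs.flatten()
    for img, ax in zip(imgs, axs):
        ax.imshow(img)
    # Hide the leftover axes when the images do not fill the grid.
    for ax in axs[len(imgs):]:
        ax.axis("off")
    return fig, axs


# Assumption: five random 8x8 arrays stand in for real images.
rng = np.random.default_rng(0)
imgs = [rng.random((8, 8)) for _ in range(5)]
fig, axs = plot_image_grid(imgs, n_col=3)
```

With five images and three columns this produces a 2x3 grid whose last cell is switched off. Passing squeeze=False is a deliberate choice here: without it, plt.subplots returns a single Axes object for a 1x1 grid, and calling flatten() on that would fail.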