How to plot multiple Seaborn Jointplot in Subplot
Solution 1:
Moving axes between figures in matplotlib is not as easy as it used to be in previous versions. The code below works with the matplotlib version that was current when this answer was written.
As has been pointed out in several places (this question, also this issue), several of the seaborn commands create their own figure automatically. This is hardcoded into the seaborn code, so there is currently no way to produce such plots in existing figures. Those commands are `PairGrid`, `FacetGrid`, `JointGrid`, `pairplot`, `jointplot` and `lmplot`.
There is a seaborn fork available which allows you to supply a subplot grid to the respective classes, so that the plot is created in a preexisting figure. To use it, you would need to copy `axisgrid.py` from the fork to the seaborn folder. Note that this is currently restricted to matplotlib 2.1 (possibly 2.0 as well).
An alternative is to create a seaborn figure and copy the axes to another figure. The principle of this is shown in this answer and can be extended to Seaborn plots. The implementation is a bit more complicated than I had initially expected. The following is a class `SeabornFig2Grid` that can be called with a seaborn grid instance (the return value of any of the above commands), a matplotlib figure and a `subplot_spec`, which is a position in a `gridspec` grid.
Note: This is a proof of concept; it may work for most simple cases, but I would not recommend using it in production code.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import numpy as np
class SeabornFig2Grid():
    """Move the axes of a seaborn grid into one cell of a matplotlib figure.

    Seaborn's grid classes (and the figure-level functions built on them,
    such as ``lmplot``, ``pairplot`` and ``jointplot``) always create their
    own figure. This class takes such a grid, re-parents its axes onto
    ``fig`` at the position given by ``subplot_spec``, and then closes the
    original seaborn figure.

    This is a proof of concept. It re-assigns ``ax.figure``, which is not
    public matplotlib API, so it may break with other matplotlib versions.

    Parameters
    ----------
    seaborngrid : seaborn FacetGrid, PairGrid or JointGrid
        The grid object, e.g. the return value of ``lmplot`` or ``jointplot``.
    fig : matplotlib.figure.Figure
        The target figure.
    subplot_spec : matplotlib.gridspec.SubplotSpec
        The cell of a gridspec on ``fig`` that the grid should occupy.

    Raises
    ------
    TypeError
        If ``seaborngrid`` is not a FacetGrid, PairGrid or JointGrid. In
        that case the seaborn figure is left untouched.
    """

    def __init__(self, seaborngrid, fig, subplot_spec):
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec
        if isinstance(self.sg, (sns.FacetGrid, sns.PairGrid)):
            self._movegrid()
        elif isinstance(self.sg, sns.JointGrid):
            self._movejointgrid()
        else:
            # Fail loudly rather than fall through to _finalize(), which would
            # close the seaborn figure without having moved any axes.
            raise TypeError(
                "SeabornFig2Grid expects a FacetGrid, PairGrid or JointGrid, "
                f"got {type(seaborngrid).__name__}"
            )
        self._finalize()

    def _movegrid(self):
        """Move the axes of a PairGrid or FacetGrid into a sub-gridspec."""
        self._resize()
        axes = np.asarray(self.sg.axes)
        if axes.ndim == 1:
            # A FacetGrid with col_wrap exposes a flat, row-major array of
            # axes. Recover the grid shape from the gridspec the axes were
            # created on.
            n, m = axes[0].get_subplotspec().get_gridspec().get_geometry()
            cells = [divmod(k, m) for k in range(axes.size)]
        else:
            n, m = axes.shape
            cells = [(i, j) for i in range(n) for j in range(m)]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(
            n, m, subplot_spec=self.subplot)
        for ax, (i, j) in zip(axes.flat, cells):
            # Cells without axes are None (e.g. the hidden upper triangle of a
            # PairGrid created with corner=True); there is nothing to move.
            if ax is not None:
                self._moveaxes(ax, self.subgrid[i, j])

    def _movejointgrid(self):
        """Move the joint and both marginal axes of a JointGrid."""
        # The joint/marginal height ratio determines how many sub-rows the
        # joint axes spans relative to the single row of a marginal axes.
        h = self.sg.ax_joint.get_position().height
        h2 = self.sg.ax_marg_x.get_position().height
        r = int(np.round(h / h2))
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(
            r + 1, r + 1, subplot_spec=self.subplot)
        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        """Re-parent ``ax`` onto ``self.fig`` and place it at subplot spec ``gs``."""
        # Technique from https://stackoverflow.com/a/46906599/4124317
        ax.remove()              # detach from the seaborn figure
        ax.figure = self.fig     # add_axes() only accepts axes of its own figure
        self.fig.add_axes(ax)
        ax.set_position(gs.get_position(self.fig))
        ax.set_subplotspec(gs)

    def _finalize(self):
        """Close the seaborn figure and redraw the target figure."""
        plt.close(self.sg.fig)
        # Keep the seaborn figure's size in sync with the target figure
        # whenever the target canvas is resized.
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()

    def _resize(self, evt=None):
        """Give the seaborn figure the same size as the target figure.

        ``evt`` is the (unused) canvas resize event when called as a callback.
        NOTE(review): presumably needed because child artists of the moved
        axes may still reference the seaborn figure; confirm.
        """
        self.sg.fig.set_size_inches(self.fig.get_size_inches())
Using this class looks like this (the example assumes the class above is saved in a module `SeabornFig2Grid.py`):
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
# Assumes the SeabornFig2Grid class is saved as SeabornFig2Grid.py on the path.
import SeabornFig2Grid as sfg

sns.set()

iris = sns.load_dataset("iris")
tips = sns.load_dataset("tips")

# An lmplot
g0 = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
                palette=dict(Yes="g", No="m"))
# A PairGrid
g1 = sns.PairGrid(iris, hue="species")
g1.map(plt.scatter, s=5)
# A FacetGrid
g2 = sns.FacetGrid(tips, col="time", hue="smoker")
g2.map(plt.scatter, "total_bill", "tip", edgecolor="w")
# A JointGrid. x and y are passed by keyword: recent seaborn versions reject
# them as positional arguments.
g3 = sns.jointplot(x="sepal_width", y="petal_length", data=iris,
                   kind="kde", space=0, color="g")

fig = plt.figure(figsize=(13, 8))
gs = gridspec.GridSpec(2, 2)

# Move each seaborn grid into one cell of the 2x2 gridspec. Keep references
# to the returned objects: matplotlib holds bound-method callbacks (the
# resize handler) only weakly.
mg0 = sfg.SeabornFig2Grid(g0, fig, gs[0])
mg1 = sfg.SeabornFig2Grid(g1, fig, gs[1])
mg2 = sfg.SeabornFig2Grid(g2, fig, gs[3])
mg3 = sfg.SeabornFig2Grid(g3, fig, gs[2])

gs.tight_layout(fig)
plt.show()
Note that copying axes may have several drawbacks, and the above has not (yet) been tested thoroughly.
Solution 2:
It cannot easily be done without hacking. `jointplot` calls `JointGrid`, which in turn creates a new `figure` object every time it is called.
Therefore, the hack is to make two jointplots (`JG1`, `JG2`), then make a new figure, then migrate the axes objects from `JG1` and `JG2` to the newly created figure.
Finally, we adjust the sizes and positions of the subplots in the new figure.
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): `df` is assumed to be a DataFrame with numeric columns "C1"
# and "C2" (from the question); it is not defined in this snippet.
# x and y are passed by keyword: recent seaborn versions reject them
# positionally.
JG1 = sns.jointplot(x="C1", y="C2", data=df, kind='reg')
JG2 = sns.jointplot(x="C1", y="C2", data=df, kind='kde')

#subplots migration
# HACK: uses the private Figure._axstack / Figure._make_key API, which may
# change or disappear between matplotlib versions.
f = plt.figure()
for J in [JG1, JG2]:
    for A in J.fig.axes:
        f._axstack.add(f._make_key(A), A)

#subplots size adjustment
# Three axes per jointplot, in insertion order. Judging by the positions
# below: joint plot, top marginal, right marginal. Left column is JG1,
# right column is JG2.
f.axes[0].set_position([0.05, 0.05, 0.4, 0.4])
f.axes[1].set_position([0.05, 0.45, 0.4, 0.05])
f.axes[2].set_position([0.45, 0.05, 0.05, 0.4])
f.axes[3].set_position([0.55, 0.05, 0.4, 0.4])
f.axes[4].set_position([0.55, 0.45, 0.4, 0.05])
f.axes[5].set_position([0.95, 0.05, 0.05, 0.4])
It is a hack because we are now using the private `_axstack` attribute and `_make_key` method, which may or may not stay the same in future matplotlib versions.
Solution 3:
If you get into trouble despite the elegant solution of @ImportanceOfBeingErnest, you can still render the seaborn plots as images and use them to build your custom figure. If you need a higher resolution, pass a larger `dpi` to `savefig`.
Here is the example shown above, reproduced with this nasty (but working) approach:
import io

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import seaborn as sns

# data
iris = sns.load_dataset("iris")
tips = sns.load_dataset("tips")

############### 1. CREATE PLOTS
# An lmplot
g0 = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
                palette=dict(Yes="g", No="m"))
# A PairGrid
g1 = sns.PairGrid(iris, hue="species")
g1.map(plt.scatter, s=5)
# A FacetGrid
g2 = sns.FacetGrid(tips, col="time", hue="smoker")
g2.map(plt.scatter, "total_bill", "tip", edgecolor="w")
# A JointGrid (x and y by keyword: recent seaborn rejects them positionally)
g3 = sns.jointplot(x="sepal_width", y="petal_length", data=iris,
                   kind="kde", space=0, color="g")

############### 2. RENDER PLOTS TO PNG IMAGES HELD IN MEMORY
# Using in-memory buffers avoids leaving g0.png ... g3.png files on disk.
# Pass dpi=... to savefig for a higher-resolution result.
images = []
for grid in (g0, g1, g2, g3):
    buf = io.BytesIO()
    grid.savefig(buf, format="png")
    plt.close(grid.fig)
    buf.seek(0)
    images.append(mpimg.imread(buf, format="png"))

############### 3. CREATE YOUR SUBPLOTS FROM THE RENDERED IMAGES
f, axarr = plt.subplots(2, 2, figsize=(25, 16))
axarr[0, 0].imshow(images[0])  # g0
axarr[0, 1].imshow(images[1])  # g1
axarr[1, 0].imshow(images[3])  # g3
axarr[1, 1].imshow(images[2])  # g2
# turn off x and y axis
for ax in axarr.ravel():
    ax.set_axis_off()
plt.tight_layout()
plt.show()
Solution 4:
I have recently been developing patchworklib, a subplot manager for matplotlib inspired by patchwork.
It allows you to quickly arrange multiple gridded seaborn plots using only the `/` and `|` operators.
Here is the example code, which you can also run on Google Colab:
import seaborn as sns
import patchworklib as pw

sns.set_theme()
# Called before any seaborn grid is created; presumably this patches seaborn's
# axisgrid classes for patchworklib (see the patchworklib docs).
pw.overwrite_axisgrid()

iris = sns.load_dataset("iris")
tips = sns.load_dataset("tips")

# An lmplot
g0 = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
                palette=dict(Yes="g", No="m"))
g0 = pw.load_seaborngrid(g0, label="g0")

# A Pairplot
g1 = sns.pairplot(iris, hue="species")
g1 = pw.load_seaborngrid(g1, label="g1")

# A relplot
g2 = sns.relplot(data=tips, x="total_bill", y="tip", col="time", hue="time",
                 size="size", style="sex", palette=["b", "r"], sizes=(10, 100))
g2 = pw.load_seaborngrid(g2, label="g2")

# A JointGrid (x and y by keyword: recent seaborn rejects them positionally)
g3 = sns.jointplot(x="sepal_width", y="petal_length", data=iris,
                   kind="kde", space=0, color="g")
g3 = pw.load_seaborngrid(g3, label="g3")

# Compose the layout with patchworklib's | and / operators; ["label"] refers
# to a previously loaded brick by its label (see patchworklib docs).
(((g0|g1)["g0"]/g3)["g3"]|g2).savefig("seaborn_subplots.png")