Matplotlib: save plot to numpy array
In Python and Matplotlib, it is easy to either display the plot as a popup window or save the plot as a PNG file. How can I instead save the plot to a numpy array in RGB format?
Solution 1:
This is a handy trick for unit tests and the like, when you need to do a pixel-to-pixel comparison with a saved plot.
One way is to use fig.canvas.tostring_rgb and then numpy.frombuffer with the appropriate dtype. There are other ways as well, but this is the one I tend to use. Note that newer Matplotlib releases deprecate tostring_rgb in favor of fig.canvas.buffer_rgba.
E.g.
import matplotlib.pyplot as plt
import numpy as np
# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)
# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()
# Now we can save it to a numpy array.
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
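On Matplotlib versions where tostring_rgb is deprecated or unavailable, the same idea can be sketched with buffer_rgba. This is a minimal sketch, assuming the Agg backend (forced here so it runs without a display); the plotted data and the names rgba and rgb are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless Agg backend
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot([0, 1, 2], [0, 1, 0])  # illustrative data

# Render the figure so the canvas buffer is populated.
fig.canvas.draw()

# buffer_rgba() exposes the rendered pixels as (height, width, 4) uint8.
rgba = np.asarray(fig.canvas.buffer_rgba())

# Drop the alpha channel and copy so the array owns its own memory.
rgb = rgba[..., :3].copy()
print(rgb.shape, rgb.dtype)
```

The copy matters if the figure is redrawn later, since the array returned by np.asarray may share memory with the canvas.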
Solution 2:
There is a somewhat simpler variant of @JUN_NETWORKS's answer. Instead of saving the figure as png, one can use another format, such as raw or rgba, and skip the cv2 decoding step.
In other words the actual plot-to-numpy conversion boils down to:
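import io
import numpy as np

# fig is an existing Figure; DPI is the resolution to render at.
# Note: fig.bbox reflects fig.dpi, so this reshape assumes DPI == fig.dpi.
io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                     (int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()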
Hope this helps.
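If the dpi passed to savefig differs from the figure's own dpi, the target shape has to be computed from that dpi rather than from fig.bbox. A minimal sketch of one way to do this, assuming the Agg backend and the default (non-tight) bounding box; fig_to_rgba is a hypothetical helper name, not part of Matplotlib:

```python
import io
import matplotlib
matplotlib.use("Agg")  # assumption: headless Agg backend
import matplotlib.pyplot as plt
import numpy as np

def fig_to_rgba(fig, dpi):
    """Render fig at the given dpi and return an (h, w, 4) uint8 array."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="raw", dpi=dpi)
        data = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    # Pixel size follows from the figure size in inches times the dpi.
    width_in, height_in = fig.get_size_inches()
    w, h = int(width_in * dpi), int(height_in * dpi)
    # reshape raises ValueError if the computed size does not match the data.
    return data.reshape(h, w, 4)

fig = plt.figure(figsize=(4, 3), dpi=100)
fig.add_subplot(111).plot([0, 1], [0, 1])
arr = fig_to_rgba(fig, dpi=50)
print(arr.shape)
```

The raw format writes RGBA pixels, so the last axis has four channels; slice with arr[..., :3] if only RGB is needed.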
Solution 3:
Some people propose a method like this:
np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
Of course, this code works, but the resulting numpy image array has a low resolution.
My proposed code is this:
import io
import cv2
import numpy as np
import matplotlib.pyplot as plt
# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)
x = np.linspace(-np.pi, np.pi)
ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.plot(x, np.sin(x), label="sin")
ax.legend()
ax.set_title("sin(x)")
# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)
This code works well.
You can get a high-resolution image as a numpy array by passing a larger value for the dpi argument.
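To make the effect of dpi concrete, here is a minimal sketch showing how the pixel dimensions of the resulting array scale with it. It decodes the PNG with plt.imread instead of cv2 (a swap made only to keep the example dependency-free), assumes the Agg backend, and png_shape is a hypothetical helper name.

```python
import io
import matplotlib
matplotlib.use("Agg")  # assumption: headless Agg backend
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(4, 3))  # size in inches
ax = fig.add_subplot(111)
ax.plot([0, 1, 2], [0, 1, 0])  # illustrative data

def png_shape(fig, dpi):
    """Save fig as PNG in memory at the given dpi and return the array shape."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi)
        buf.seek(0)
        return plt.imread(buf).shape

low = png_shape(fig, dpi=50)
high = png_shape(fig, dpi=200)
print(low[:2], high[:2])
```

The height and width in pixels are the figure size in inches multiplied by dpi, so the 4 x 3 inch figure here comes out as 200 x 150 pixels at dpi=50 and 800 x 600 pixels at dpi=200.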
Solution 4:
Time to benchmark your solutions.
import io
import matplotlib
matplotlib.use('agg') # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
ax.plot(range(10))
def plot1():
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))

def plot2():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='png')
        buff.seek(0)
        im = plt.imread(buff)

def plot3():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In this scenario, saving to an in-memory buffer in raw format is the fastest way to convert a matplotlib figure to a numpy array.
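The %timeit lines above are IPython magics. Outside IPython, a comparable measurement can be sketched with the standard timeit module. This is only a rough harness under the Agg backend; via_png and via_raw are illustrative names mirroring plot2 and plot3, and the numbers will vary by machine and Matplotlib version.

```python
import io
import timeit
import matplotlib
matplotlib.use("Agg")  # assumption: headless Agg backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(range(10))

def via_png():
    # Encode to PNG in memory, then decode back into an array.
    with io.BytesIO() as buff:
        fig.savefig(buff, format="png")
        buff.seek(0)
        return plt.imread(buff)

def via_raw():
    # Write raw RGBA bytes in memory and reinterpret them as an array.
    with io.BytesIO() as buff:
        fig.savefig(buff, format="raw")
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    return data.reshape(h, w, -1)

for func in (via_png, via_raw):
    # Take the best of 3 repeats, 3 calls each, and report time per call.
    seconds = min(timeit.repeat(func, number=3, repeat=3)) / 3
    print(f"{func.__name__}: {seconds * 1e3:.1f} ms per call")
```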
Additional remarks:
- If you don't have access to the figure, you can always get it from the axes: fig = ax.figure
- If you need the array in channel x height x width format, do im = im.transpose((2, 0, 1)).
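As a quick illustration of the transpose remark, here is a minimal sketch using a dummy array in place of a rendered figure; the 480 x 640 size is just an example.

```python
import numpy as np

# Stand-in for an image in height x width x channel layout.
im = np.zeros((480, 640, 3), dtype=np.uint8)

# Move the channel axis to the front: channel x height x width.
chw = im.transpose((2, 0, 1))
print(chw.shape)  # (3, 480, 640)

# transpose returns a view; make it contiguous if a library requires that.
chw_contiguous = np.ascontiguousarray(chw)
```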