这是一种无需调用 imshow() 即可直接在图像上绘图的方法。
它先将绘图渲染为一张与输入图像尺寸（高和宽）相同的透明图像，然后通过 alpha 合成将其叠加到输入图像上。
import numpy as np
import matplotlib.pyplot as plt
import imageio
from contextlib import contextmanager
@contextmanager
def plot_over(img, extent=None, origin="upper", dpi=100):
    """Draw with Matplotlib directly onto an RGB image, modifying it in place.

    Yields an Axes that covers the whole image. Everything drawn on it is
    rendered to a transparent RGBA buffer with the same height and width as
    *img*. When the ``with`` block exits normally, that buffer is
    alpha-composited onto *img*. If the block raises, *img* is left untouched
    and the figure is still closed.

    Parameters
    ----------
    img : numpy.ndarray
        ``(h, w, 3)`` array of dtype ``uint8``. Modified in place.
    extent : tuple of float, optional
        ``(xmin, xmax, ymin, ymax)`` data limits of the axes. Defaults to
        ``(-0.5, w - 0.5, -0.5, h - 0.5)``, so the centre of pixel
        ``(col, row)`` is at data coordinate ``(col, row)``. This matches
        ``imshow``'s default extent.
    origin : {"upper", "lower"}
        With "upper", the y-axis is flipped so that ``ymin`` is at the top of
        the image (row 0 at the top). With "lower", ``ymin`` is at the bottom.
    dpi : float
        Resolution of the off-screen figure. It only affects sizes given in
        points (marker sizes, line widths, font sizes).

    Yields
    ------
    matplotlib.axes.Axes
        Axes spanning the full image, with axis decorations turned off.

    Raises
    ------
    ValueError
        If *img* is not an ``(h, w, 3)`` ``uint8`` array, or *origin* is not
        "upper" or "lower".
    RuntimeError
        If the rendered canvas does not have the same pixel size as *img*
        (e.g. a backend that rescales the figure).
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"img must have shape (h, w, 3), got {img.shape}")
    if img.dtype != np.uint8:
        # The compositing below assumes 0..255 channel values.
        raise ValueError(f"img must have dtype uint8, got {img.dtype}")
    h, w = img.shape[:2]

    if extent is None:
        # Pixel centres sit on integer coordinates, so the image reaches half
        # a pixel beyond them on each side: w pixels span exactly w units.
        xmin, xmax, ymin, ymax = -0.5, w - 0.5, -0.5, h - 0.5
    else:
        xmin, xmax, ymin, ymax = extent
    if origin == "upper":
        ymin, ymax = ymax, ymin
    elif origin != "lower":
        raise ValueError("origin must be 'upper' or 'lower'")

    # Agg truncates figsize * dpi to an int. w / dpi * dpi can land just
    # below w (e.g. 29 / 100 * 100 == 28.999...), which would drop a pixel
    # column. A 1e-3 px pad makes truncation give exactly (w, h) while
    # shifting the data-to-pixel mapping by a negligible amount.
    fig = plt.figure(figsize=((w + 1e-3) / dpi, (h + 1e-3) / dpi), dpi=dpi)
    try:
        ax = plt.Axes(fig, (0, 0, 1, 1))
        ax.set_axis_off()
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)
        fig.add_axes(ax)
        # Fully transparent background: only what the caller draws has alpha.
        fig.set_facecolor((0, 0, 0, 0))
        yield ax
        fig.canvas.draw()
        buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    finally:
        # Runs even if the with-body raised, so pyplot doesn't accumulate
        # open figures.
        plt.close(fig)

    if buf.size != h * w * 4:
        raise RuntimeError(
            f"rendered canvas has {buf.size // 4} pixels, expected {h * w} "
            f"({h}x{w})"
        )
    plot = buf.reshape(h, w, 4)

    rgb = plot[..., :3].astype(np.uint16)
    alpha = plot[..., 3:].astype(np.uint16)  # keep a trailing axis to broadcast
    # "Over" compositing in integer arithmetic, assuming straight
    # (non-premultiplied) RGBA. Adding 127 rounds to nearest instead of
    # flooring. The largest intermediate, 255 * 255 + 127, fits in uint16.
    img[...] = ((255 - alpha) * img.astype(np.uint16) + alpha * rgb + 127) // 255
if __name__ == "__main__":
    # Example usage. The guard keeps file I/O from running when this module
    # is imported for plot_over.
    img = imageio.imread("image.jpg")
    # Draw on a copy so the original pixels stay available.
    img_with_plot = img.copy()
    with plot_over(img_with_plot) as ax:
        # Concrete example overlay. Any Axes drawing calls work here, in
        # pixel coordinates (x = column, y = row, row 0 at the top).
        h, w = img_with_plot.shape[:2]
        ax.scatter([w / 4, w / 2, 3 * w / 4], [h / 2, h / 4, h / 2], s=40, c="red")
    imageio.imwrite("result.png", img_with_plot)