【发布时间】:2021-10-31 18:02:17
【问题描述】:
我正在尝试使用自定义颜色图来显示 ConfusionMatrixDisplay 对象,使其在 0 和 50 之间的范围比使用 this answer 的 50 和 100 之间的范围更小。
from sklearn.datasets import make_classification
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams["figure.figsize"] = (15, 15)
font = {'family' : 'DejaVu Sans',
'weight' : 'bold',
'size' : 22}
plt.rc('font', **font)
class nlcmap(LinearSegmentedColormap):
def __init__(self, cmap, levels):
self.cmap = cmap
self.N = cmap.N
self.monochrome = self.cmap.monochrome
self.levels = np.asarray(levels, dtype='float64')
self._x = self.levels
self.levmax = self.levels.max()
self.transformed_levels = np.linspace(0.0, self.levmax, len(self.levels))
def __call__(self, xi, alpha=1.0, **kw):
yi = np.interp(xi, self._x, self.transformed_levels)
return self.cmap(yi / self.levmax, alpha)
levels = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100]
cmap_nonlin = nlcmap(plt.cm.viridis, levels)
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y,
random_state=0)
clf = SVC(random_state=0)
clf.fit(X_train, y_train)
SVC(random_state=0)
predictions = clf.predict(X_test)
cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=clf.classes_)
lin_cmap = plt.cm.viridis
levels = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100]
cmap_nonlin = nlcmap(plt.cm.viridis, levels)
fig, ax = plt.subplots()
im = disp.plot(cmap=cmap_nonlin, colorbar=False)
disp.ax_.get_images()[0].set_clim(0, 100)
disp.figure_.colorbar(disp.im_, orientation="horizontal", pad=0.1)
plt.savefig("test.png")
产生以下错误:
Traceback (most recent call last):
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/backends/backend_macosx.py", line 61, in _draw
self.figure.draw(renderer)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/artist.py", line 41, in draw_wrapper
return draw(artist, renderer, *args, **kwargs)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/figure.py", line 1864, in draw
renderer, self, artists, self.suppressComposite)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/image.py", line 131, in _draw_list_compositing_images
a.draw(renderer)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/artist.py", line 41, in draw_wrapper
return draw(artist, renderer, *args, **kwargs)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py", line 411, in wrapper
return func(*inner_args, **inner_kwargs)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/axes/_base.py", line 2747, in draw
mimage._draw_list_compositing_images(renderer, self, artists)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/image.py", line 131, in _draw_list_compositing_images
a.draw(renderer)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/artist.py", line 41, in draw_wrapper
return draw(artist, renderer, *args, **kwargs)
File "/Users/me/anaconda3/envs/myenv/lib/python3.6/site-packages/matplotlib/image.py", line 646, in draw
renderer.draw_image(gc, l, b, im)
TypeError: Cannot cast array data from dtype('float64') to dtype('uint8') according to the rule 'safe'
似乎错误与 imshow 和自定义颜色图有关,因为我可以在没有 sklearn 的情况下重现:
fig, ax = plt.subplots()
ax.imshow(np.array([[10, 15], [20, 30]]), cmap=cmap_nonlin)
有什么想法吗?如果可能的话,我希望修改颜色图而不是数据本身。
【问题讨论】:
-
您是否有理由像这样深入了解,而不是直接将您的颜色图传递给
ConfusionMatrixDisplay(这需要cmapkwarg。 -
将 cmap=cmap_nonlin 直接添加到 ConfusionMatrixDisplay init 不起作用,因为它未被识别为有效的 kwargs,我随后将其传递给 .plot 调用,但它也不起作用,可能是因为 cmap无效(出于某种我不理解的原因)。
-
到达胆量之后发生的原因有两个:将颜色条移到绘图下方并固定它的限制。
-
要移动下面的颜色条,colorbar 有一个
location和一个orientation参数,无需更改颜色图。要更改限制,BoundaryNorm 可能会有所帮助。请注意,您链接到的博文已经有将近 10 年的历史了(提到 2006 年的代码),并且 matplotlib 已经得到了极大的扩展。
标签: matplotlib scikit-learn confusion-matrix colormap imshow