【问题标题】:Scikit learn with digitsScikit 用数字学习
【发布时间】:2018-05-28 04:41:09
【问题描述】:

我不明白

我正在尝试将 scikit learn 与带有数字数据集的 matplotlib 一起使用

这是我的代码

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import svm
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()

clf = svm.SVC(gamma=0.001, C=100.)

X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=2017)

clf.fit(X_train, y_train)

pred = clf.predict(X_test)

print("Prediction: {}".format(pred))    
plt.imshow(digits.images[-1], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

matplotlib 绘制图像编号 4,当我尝试打印预测时,它每次都会向我打印此输出

Prediction: [8 1 3 8 5 8 1 4 9 7 5 2 1 4 3 2 4 9 5 4 1 9 2 4 7 8 9 3 1 7 5 7 6 2 0 5 7
 1 6 1 9 4 4 5 3 7 3 6 3 3 9 8 5 2 6 1 1 1 4 5 4 2 8 2 7 2 9 7 8 9 1 2 8 0
 7 8 9 0 1 5 4 0 0 9 2 6 7 8 6 5 1 3 1 8 7 7 2 2 2 6 7 4 1 7 2 5 8 3 4 2 3
 7 6 1 1 0 3 0 2 5 9 3 1 6 9 5 6 2 0 3 2 7 4 6 5 3 9 5 1 5 6 0 1 8 6 5 1 6
 2 1 2 5 0 2 3 4 2 4 9 4 4 2 3 9 2 9 8 2 5 9 9 7 3 7 8 1 4 9 2 9 5 1 8 7 4
 8 2 7 6 9 8 8 3 7 1 9 1 4 5 7 0 5 9 3 5 0 5 0 5 5 2 1 3 5 3 2 8 4 7 4 7 3
 7 2 9 5 6 2 8 0 5 0 2 1 9 2 9 6 1 0 1 7 6 3 1 0 3 2 4 0 6 1 2 1 6 2 8 2 7
 1 5 6 6 9 2 1 4 4 8 0 7 6 2 5 0 4 5 5 5 5 7 4 8 1 0 8 4 8 7 2 5 5 7 3 2 4
 4 7 8 2 0 7 1 4 0 9 6 1 8 5 5 1 5 6 1 7 1 5 5 8 4 6 6 0 6 5 0 9 8 0 8 0 9
 2 0 9 5 7 0 8 1 7 0 6 7 7 0 0 7 7 5 0 3 2 2 8 8 7 7 0]

我正在尝试在 matplotlib 中打印数字,但它显示了这个输出

我希望打印prediction: 8

【问题讨论】:

  • 你有什么不明白的?请解释清楚。 pred 数组返回 X_test 中存在的所有样本的预测
  • 您不明白为什么您的代码每次都打印相同的预测,是吗?
  • 就是这样,它每次都打印相同的预测
  • 我试图打印 matplotlib 图像中的数字,但它打印了这个输出
  • 因为您将 X_test 参数传递给 predict()。

标签: python matplotlib machine-learning scikit-learn


【解决方案1】:

好的,这是 cmets 中的人想告诉你的:

X_test 不是单个数据点,而是测试集的整个。它包括多少个样本?

X_test.shape
# (360, 64)

所以,它包括 360 个样本;因此,您的 pred 变量还必须包含所有这 360 个样本的预测。确实:

pred.shape
# (360,)

您想在X_test 中检查您的第一个样本的预测吗?

pred[0]
# 8

这个预测的基本事实是什么?

y_test[0]
# 8

看来您的第一个测试样本确实得到了正确的预测。你想绘制这个样本 (X_test[0])?您应该首先将其重新整形为 (8,8),因为 train_test_split load_digits() 函数将原始 8x8 图像展平为长度为 64 的一维数组:

plt.imshow(X_test[0].reshape(8,8), cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

嗯,它看起来像一个 8(有点)......

【讨论】:

  • 稍微修正一下。 train_test_split 不要压扁数字的形状。 load_digits().data返回的数组已经是(n_samples, 64)的数组
猜你喜欢
  • 2014-04-20
  • 2018-10-21
  • 2016-12-20
  • 2015-12-22
  • 2016-07-16
  • 1970-01-01
  • 2016-06-21
  • 2015-08-05
  • 1970-01-01
相关资源
最近更新 更多