【问题标题】:Strange error in fitting classifier拟合分类器的奇怪错误
【发布时间】:2020-07-09 16:55:39
【问题描述】:

我正在研究 O'Reilly 的使用 Scikit-Learn 和 Tensorflow 进行机器学习的动手实践

我正在针对 MNIST 数据集训练分类器,但出现错误

ValueError: The number of classes has to be greater than one; got 1 class

这是我的代码

mnist = fetch_openml('mnist_784', version=1, cache=True)

X, y = mnist["data"], mnist["target"]

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

y_train_5 = (y_train == 9)
y_test_5 = (y_test == 9)

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

我已经对我的代码进行了三次检查,但我仍然不确定发生了什么。

【问题讨论】:

  • 我不习惯 python 但 y_train_5 = (y_train == 9) 是做什么的? iy_train_5 是否仅包含 9 的值?
  • 这是我的错误。所以,我完全按照书中给我的“9”而不是“5”,所以我只是重写了变量名并忘记了那里。此外,我不得不使用不同版本的 MNIST,因为这本书给出的链接已损坏。老实说,我对此以及之前甚至无法运行的代码感到有些不安。

标签: scikit-learn


【解决方案1】:

来自sklearn 中的 MNIST 数据集的标签包含字符串,而不是整数。所以,设置

y_train_5 = (y_train == '9')
y_test_5 = (y_test == '9')

当你检查一个整数时,它都会得到False 并且 Python 会警告你你只有 1 个类。

【讨论】:

    【解决方案2】:

    过程都是正确的,只是将数字转换成字符串,因为scikit中的标签需要字符串。

    y_train_5=(y_train == '5')
    y_test_5=(y_test == '5')
    

    【讨论】:

    • 请在您的代码周围使用代码框。阅读此meta post 了解更多详情。
    猜你喜欢
    • 2015-08-23
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多