【发布时间】:2020-07-09 16:55:39
【问题描述】:
我正在研究 O'Reilly 的使用 Scikit-Learn 和 Tensorflow 进行机器学习的动手实践。
我正在针对 MNIST 数据集训练分类器,但出现错误
ValueError: The number of classes has to be greater than one; got 1 class
这是我的代码
mnist = fetch_openml('mnist_784', version=1, cache=True)
X, y = mnist["data"], mnist["target"]
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
y_train_5 = (y_train == 9)
y_test_5 = (y_test == 9)
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
我已经对我的代码进行了三次检查,但我仍然不确定发生了什么。
【问题讨论】:
-
我不习惯 python 但 y_train_5 = (y_train == 9) 是做什么的? iy_train_5 是否仅包含 9 的值?
-
这是我的错误。所以,我完全按照书中给我的“9”而不是“5”,所以我只是重写了变量名并忘记了那里。此外,我不得不使用不同版本的 MNIST,因为这本书给出的链接已损坏。老实说,我对此以及之前甚至无法运行的代码感到有些不安。
标签: scikit-learn