【发布时间】:2016-06-24 23:13:43
【问题描述】:
我有一个关于 sklearn 上的糖尿病数据集的问题。我正在尝试绘制一种估计器的学习曲线,但不知何故我有警告:
"D:\Users\XXXX\Anaconda2\lib\site-packages\sklearn\cross_validation.p ing: y 中人口最少的类只有 1 个成员,这也是 任何类的最小标签数不能小于 n_folds=3。"
并且代码正在绘制一个奇怪的结果。训练数据集的得分非常高(总是 1,这可能是有道理的,因为它是一棵树),但测试得分表现很差(最好是 0.03125)
我在不同的数据集(数字)中尝试过,效果很好。我的代码如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.learning_curve import learning_curve
from sklearn import tree
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target
estimator = tree.DecisionTreeClassifier()
estimator.fit(X, y)
title = "Learning Curves Decision Tree"
plt.figure(1)
plt.title(title)
plt.xlabel("Training examples")
plt.ylabel("Score")
train_sizes, train_scores, test_scores = learning_curve(estimator, X, y)
print train_sizes
print train_scores
print test_scores
plt.grid()
plt.plot(train_sizes, train_scores, 'o-', color="r",label="Training score")
plt.plot(train_sizes, test_scores, 'o-', color="g",label="Cross-validation score")
plt.legend(loc="best")
plt.show()
谁能解释一下为什么会这样?谢谢
【问题讨论】:
标签: python machine-learning scikit-learn supervised-learning