【问题标题】:Why Decision Tree code written in python predicts differently than the code written in R?为什么用 python 编写的决策树代码的预测与用 R 编写的代码不同?
【发布时间】:2017-11-21 00:24:18
【问题描述】:

我正在使用 python 和 R 中来自 sklearn 的 load_iris 数据集(在 R 中它只是称为 iris)。

我使用“gini”索引以两种语言构建了模型,并且当测试数据直接取自 iris 数据集时,我能够正确地测试这两种语言的模型。

但是,如果我给一个新的数据集作为测试输入,对于同一个 python 和 R 将它分为不同的类别。

我不确定我在这里遗漏了什么或做错了什么,因此非常感谢任何指导。

代码如下: Python 2.7:

from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
model = tree.DecisionTreeClassifier(criterion='gini')
model.fit(iris.data, iris.target)
model.score(iris.data, iris.target)
print iris.data[49],model.predict([iris.data[49]])
print iris.data[99],model.predict([iris.data[99]])
print iris.data[100],model.predict([iris.data[100]])
print iris.data[149],model.predict([iris.data[149]])
print [6.3,2.8,6,1.3],model.predict([[6.3,2.8,6,1.3]])

运行 3.3.2 32 位的 R-Rstudio:

library(rpart)
iris<- iris
x_train = iris[c('Sepal.Length','Sepal.Width','Petal.Length','Petal.Width')]
y_train = as.matrix(cbind(iris['Species']))
x <- cbind(x_train,y_train)
fit <- rpart(y_train ~ ., data = x_train,method="class",parms = list(split = "gini"))
summary(fit)
x_test = x[149,]
x_test[,1]=6.3
x_test[,2]=2.8
x_test[,3]=6
x_test[,4]=1.3
predicted1= predict(fit,x[49,]) # same as python result
predicted2= predict(fit,x[100,]) # same as python result 
predicted3= predict(fit,x[101,]) # same as python result
predicted4= predict(fit,x[149,]) # same as python result
predicted5= predict(fit,x_test) ## this value does not match with pythons result

我的 python 输出是:

[ 5.   3.3  1.4  0.2] [0]
[ 5.7  2.8  4.1  1.3] [1]
[ 6.3  3.3  6.   2.5] [2]
[ 5.9  3.   5.1  1.8] [2]
[6.3, 2.8, 6, 1.3] [2] -----> this means it's putting the test data into virginica bucket

R 输出为:

> predicted1
   setosa versicolor virginica
49      1          0         0
> predicted2
    setosa versicolor  virginica
100      0  0.9074074 0.09259259
> predicted3
    setosa versicolor virginica
101      0 0.02173913 0.9782609
> predicted4
    setosa versicolor virginica
149      0 0.02173913 0.9782609
> predicted5
    setosa versicolor  virginica
149      0  0.9074074 0.09259259 --> this means it's putting the test data into versicolor bucket

请帮忙。谢谢。

【问题讨论】:

  • 你能公布你的R树的树参数和值吗?

标签: python r decision-tree


【解决方案1】:

决策树涉及很多参数(最小/最大叶大小、树的深度、何时拆分等),不同的包可能有不同的默认设置。如果您想获得相同的结果,则需要确保隐式默认值相似。例如,尝试运行以下命令:

fit <- rpart(y_train ~ ., data = x_train,method="class",
             parms = list(split = "gini"), 
             control = rpart.control(minsplit = 2, minbucket = 1, xval=0, maxdepth = 30))

(predicted5= predict(fit,x_test))
    setosa versicolor virginica
149      0  0.3333333 0.6666667

这里,选择minsplit = 2, minbucket = 1, xval=0maxdepth = 30 选项以与sklearn 选项相同,请参阅heremaxdepth = 30是最大价值rpart会让你拥有; sklearn 此处没有限制)。如果您希望概率等相同,您可能还想使用cp 参数。

同样,

model = tree.DecisionTreeClassifier(criterion='gini', 
                                    min_samples_split=20, 
                                    min_samples_leaf=round(20.0/3.0), max_depth=30)
model.fit(iris.data, iris.target)

我明白了

print model.predict([iris.data[49]])
print model.predict([iris.data[99]])
print model.predict([iris.data[100]])
print model.predict([iris.data[149]])
print model.predict([[6.3,2.8,6,1.3]])

[0]
[1]
[2]
[2]
[1]

这看起来与您最初的 R 输出非常相似。

不用说,当您的预测(在训练集上)看起来“非常好”时要小心,因为您可能会过度拟合数据。例如,查看model.predict_proba(...),它为您提供sklearn 中的概率(而不是预测的类)。您应该看到,使用您当前的 Python 代码/设置,您几乎肯定会过度拟合。

【讨论】:

  • 一个非常描述性和信息丰富的答案@coffeinjunkey。谢谢你。我只想补充一件事:min_samples_leaf=round(20.0/3.0) 扔了一个ValueError: min_samples_leaf must be at least 1 or in (0, 0.5], got 7.0。为了解决这个问题,我用 int() 包裹了圆形函数,一切都很好。所以,我最终的更新代码形式如下:Python 应该有:model = tree.DecisionTreeClassifier(criterion='gini', min_samples_split=20, min_samples_leaf=int(round(20.0/3.0)), max_depth=30)
【解决方案2】:

除了@coffeeinjunky 的回答,你还需要注意参数random_state(这是Python 的参数,不知道R 中叫什么)。树本身的生成是伪随机的,因此您需要指定两个模型具有相同的种子值。否则,您将使用相同的模型进行拟合/预测,并在每次运行时得到不同的结果,因为每次使用的树都不同。

查看 Mueller 和 Guido 中关于决策树的部分——“Python 用于机器学习”。它在视觉上解释不同的参数方面做得很好,如果你只是尝试谷歌搜索,pdf 就会在互联网上流传。借助决策树和集成学习方法,您指定的参数将对预测产生有意义的影响。

【讨论】:

    猜你喜欢
    • 2016-12-19
    • 2017-05-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-02-10
    • 1970-01-01
    相关资源
    最近更新 更多