【发布时间】:2014-11-30 04:40:08
【问题描述】:
我在 3 类数据集上使用 scikit-learn 决策树分类器。拟合分类器后,我访问 tree_ 属性上的所有叶节点,以获得每个类最终在给定节点中的实例数量。
clf = tree.DecisionTreeClassifier(max_depth=5)
clf.fit(X, y)
# lets assume there is a leaf node with id 5
print clf.tree_.value[5]
这将打印出来:
>>> array([[ 0., 1., 68.]])
但是...我如何知道该数组中的哪个位置属于哪个类? 分类器有一个 classes_ 属性,它也是一个列表
>>> clf.classes_
array(['CLASS_1', 'CLASS_2', 'CLASS_3'], dtype=object)
也许 value 数组的索引 1 与 classes 数组的索引 1 上的类匹配,等等?
【问题讨论】:
-
请单独发布答案,而不是将其编辑到问题中。然后,您可以接受自己的答案,将问题标记为已关闭。
-
@larsmans,这是通用规则吗?我曾经读过一篇有人这样做的帖子,并得到评论说他应该做我所做的事情。你的声望似乎足够高。我会这样做,希望没有人说相反:S
标签: python scikit-learn decision-tree