【问题标题】:Finding a corresponding leaf node for each data point in a decision tree (scikit-learn)为决策树中的每个数据点找到对应的叶节点(scikit-learn)
【发布时间】:2015-08-05 03:30:42
【问题描述】:

我正在使用 python 3.4 中 scikit-learn 包中的决策树分类器,我想为每个输入数据点获取相应的叶节点 ID。

例如,我的输入可能如下所示:

array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2]])

假设对应的叶子节点分别是16、5和45。我希望我的输出是:

leaf_node_id = array([16, 5, 45])

我已经阅读了 scikit-learn 邮件列表和有关 SF 的相关问题,但我仍然无法使用它。这是我在邮件列表中找到的一些提示,但仍然不起作用。

http://sourceforge.net/p/scikit-learn/mailman/message/31728624/

归根结底,我只想有一个函数 GetLeafNode(clf, X_valida) ,它的输出是相应叶节点的列表。下面是重现我收到的错误的代码。因此,任何建议都将不胜感激。

from sklearn.datasets import load_iris
from sklearn import tree

# load data and divide it to train and validation
iris = load_iris()

num_train = 100
X_train = iris.data[:num_train,:]
X_valida = iris.data[num_train:,:]

y_train = iris.target[:num_train]
y_valida = iris.target[num_train:]

# fit the decision tree using the train data set
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, y_train)

# Now I want to know the corresponding leaf node id for each of my training data point
clf.tree_.apply(X_train)

# This gives the error message below:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-17-2ecc95213752> in <module>()
----> 1 clf.tree_.apply(X_train)

_tree.pyx in sklearn.tree._tree.Tree.apply (sklearn/tree/_tree.c:19595)()

ValueError: Buffer dtype mismatch, expected 'DTYPE_t' but got 'double'

【问题讨论】:

    标签: python machine-learning scikit-learn decision-tree


    【解决方案1】:

    我终于让它工作了。这是基于我在 scikit-learn 邮件列表中的通信 message 的一种解决方案:

    scikit-learn 0.16.1版本之后,在clf.tree_中实现了apply方法,因此,我按照以下步骤进行:

    1. 将 scikit-learn 更新到最新版本(0.16.1),以便您可以使用来自clf.tree_apply 方法
    2. 将输入数据数组(X_trainX_valida)从float64 转换为float32,使用:X_train = X_train.astype('float32')
    3. 现在您可以通过这种方式使用apply 方法:clf.tree_.apply(X_train),您将获得每个数据点的叶节点ID。

    这是最终代码:

    from sklearn.datasets import load_iris
    from sklearn import tree
    
    # load data and divide it to train and validation
    iris = load_iris()
    
    num_train = 100
    X_train = iris.data[:num_train,:]
    X_valida = iris.data[num_train:,:]
    
    y_train = iris.target[:num_train]
    y_valida = iris.target[num_train:]
    
    # convert data to float32
    X_train = X_train.astype('float32')
    
    # fit the decision tree using the train data set
    clf = tree.DecisionTreeClassifier()
    clf = clf.fit(X_train, y_train)
    
    # Now I want to know the corresponding leaf node id for each of my training data point
    clf.tree_.apply(X_train)
    
    # This gives the leaf node id:
    array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2])
    

    【讨论】:

      【解决方案2】:

      从 scikit-learn 0.17 开始,您可以使用 DecisionTree 对象的 apply 方法来获取数据点在树中结束的叶子的索引。基于 neobot 的回答:

      from sklearn.datasets import load_iris
      from sklearn import tree
      
      # load data and divide it to train and validation
      iris = load_iris()
      
      num_train = 100
      X_train = iris.data[:num_train,:]
      X_valida = iris.data[num_train:,:]
      
      y_train = iris.target[:num_train]
      y_valida = iris.target[num_train:]
      
      # fit the decision tree using the train data set
      clf = tree.DecisionTreeClassifier()
      clf = clf.fit(X_train, y_train)
      
      # Compute the leaf node id for each of my training data points
      clf.apply(X_train)
      

      产生输出

      array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
             2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
             2, 2, 2, 2, 2, 2, 2, 2])
      

      【讨论】:

      • 小问题,如何理解哪个节点在哪个索引处
      猜你喜欢
      • 2017-03-26
      • 2017-01-21
      • 2019-12-11
      • 2020-08-29
      • 2020-07-25
      • 2017-10-24
      • 1970-01-01
      • 2020-04-07
      • 2016-08-12
      相关资源
      最近更新 更多