【问题标题】:Traversal of sklearn decision treesklearn决策树的遍历
【发布时间】:2020-07-26 21:48:36
【问题描述】:

如何进行sklearn决策树的广度优先搜索遍历?

在我的代码中,我尝试了 sklearn.tree_ 库并使用了各种函数,例如 tree_.feature 和 tree_.threshold 来理解树的结构。但是这些函数是对树进行dfs遍历的,如果我想做bfs我应该怎么做呢?

假设

clf1 = DecisionTreeClassifier( max_depth = 2 )
clf1 = clf1.fit(x_train, y_train)

这是我的分类器,生成的决策树是

然后我使用以下函数遍历了树

def encoding(clf, features):
l1 = list()
l2 = list()

for i in range(len(clf.tree_.feature)):
    if(clf.tree_.feature[i]>=0):
        l1.append( features[clf.tree_.feature[i]])
        l2.append(clf.tree_.threshold[i])
    else:
        l1.append(None)
        print(np.max(clf.tree_.value))
        l2.append(np.argmax(clf.tree_.value[i]))

l = [l1 , l2]

return np.array(l)

产生的输出是

array([['address', 'age', None, None, 'age', None, None],
       [0.5, 17.5, 2, 1, 15.5, 1, 1]], dtype=object)

其中第一个数组是节点的特征,或者如果它是叶子节点,那么它被标记为无,第二个数组是特征节点的阈值,对于类节点它是类,但这是树的 dfs 遍历我想做 bfs 遍历什么我应该这样做吗?

由于我是堆栈溢出的新手,请建议如何改进问题描述以及我应该添加哪些其他信息以进一步解释我的问题。

X_train(样本)

y_train(样本)

【问题讨论】:

  • 请提供minimal reproducible example(您可以通过使用虚拟数据修改文档中的一些示例来轻松解决此问题)。
  • 您能否提供一些背景信息为什么您想这样做?我觉得这可能是可怕的XY problem 的情况。
  • @desertnaut 如果需要更多信息请告诉我,我已经详细说明了
  • @Dion 我将这棵树用于遗传算法的初始种群,所以我想将它们编码成染色体,这就是为什么
  • 请采取最后一个额外步骤,通过在问题中提供 x_trainy_train 来使您的示例完全可重现 - 如前所述,可以是虚拟数据,并且在您可以改编的 scikit-learn 文档(或使用 make_classification)。

标签: python scikit-learn decision-tree


【解决方案1】:

应该这样做:

from collections import deque

tree = clf.tree_

stack = deque()
stack.append(0)  # push tree root to stack

while stack:
    current_node = stack.popleft()

    # do whatever you want with current node
    # ...

    left_child = tree.children_left[current_node]
    if left_child >= 0:
        stack.append(left_child)

    right_child = tree.children_right[current_node]
    if right_child >= 0:
        stack.append(right_child)

这使用deque 来保留要处理的节点堆栈。由于我们从左侧移除元素并添加到右侧,这应该代表广度优先遍历。


为了实际使用,我建议你把它变成一个生成器:

from collections import deque

def breadth_first_traversal(tree):
    stack = deque()
    stack.append(0)

    while stack:
        current_node = stack.popleft()

        yield current_node

        left_child = tree.children_left[current_node]
        if left_child >= 0:
            stack.append(left_child)

        right_child = tree.children_right[current_node]
        if right_child >= 0:
            stack.append(right_child)

然后,您只需对原始函数进行少量更改:

def encoding(clf, features):
    l1 = list()
    l2 = list()

    for i in breadth_first_traversal(clf.tree_):
        if(clf.tree_.feature[i]>=0):
            l1.append( features[clf.tree_.feature[i]])
            l2.append(clf.tree_.threshold[i])
        else:
            l1.append(None)
            print(np.max(clf.tree_.value))
            l2.append(np.argmax(clf.tree_.value[i]))

    l = [l1 , l2]

    return np.array(l)

【讨论】:

  • 成功了。我不知道使用 tree.children_left[current_node] 和 tree.children_right[current_node] 检索子节点。感谢您的回答。
  • 我也想知道我们是否可以将树存储到数组中,使其看起来像是一棵完整的二叉树,以便第 i 个节点的子节点存储在第 2i 个和第 2i +1 个位置跨度>
  • 这应该是一个新问题。
  • 好的,如果我包含所有这些细节然后添加它,我应该怎么做。我也想到了一种方法,但我认为它非常低效。
猜你喜欢
  • 2020-08-03
  • 2017-02-12
  • 2018-05-24
  • 2018-09-06
  • 2019-03-04
  • 2020-01-27
  • 2015-02-07
  • 2015-04-07
  • 2018-08-14
相关资源
最近更新 更多