【问题标题】:Prune unnecessary leaves in sklearn DecisionTreeClassifier在 sklearn DecisionTreeClassifier 中修剪不必要的叶子
【发布时间】:2018-12-26 02:01:04
【问题描述】:

我使用 sklearn.tree.DecisionTreeClassifier 来构建决策树。使用最优参数设置,我得到一棵有不必要叶子的树(见下图示例 - 我不需要概率,所以标记为红色的叶子节点是不必要的分裂)

是否有任何第三方库可以修剪这些不必要的节点?还是代码sn-p?我可以写一个,但我真的无法想象我是第一个遇到这个问题的人......

要复制的代码:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X,y)

PS:我已经尝试了多个关键字搜索,但我有点惊讶地没有找到任何东西 - sklearn 中真的没有一般的后修剪吗?

PPS:针对可能的重复:虽然the suggested question 可能会在我自己编写修剪算法时帮助我,但它回答了一个不同的问题——我想去掉不改变最终决定的叶子,而另一个问题想要一个分裂节点的最小阈值。

PPPS:显示的树是显示我的问题的示例。我知道创建树的参数设置不是最理想的。我不是在问优化这棵特定的树,我需要进行后修剪以去除叶子,如果需要类概率可能会有所帮助,但如果只对最可能的类感兴趣,则无济于事。

【问题讨论】:

  • Pruning Decision Trees的可能重复
  • @ncfirth:虽然问题也与修剪有关,但它会尝试做其他事情 - 请参阅我的编辑。
  • @ncfirth:但是,感谢您提供链接,它帮助我编写了自己的代码 (see my answer below) 用于后期修剪。

标签: python scikit-learn decision-tree pruning


【解决方案1】:

使用 ncfirth 的链接,我能够修改那里的代码,使其适合我的问题:

from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove leaves if both 
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)

在 DecisionTreeClassifier clf 上使用它:

prune_duplicate_leaves(clf)

编辑:修复了更复杂树的错误

【讨论】:

  • 请注意,此代码将就地修改树。这本身还不错,如果您想比较修剪前后的树,很高兴知道。
  • @Thomas 您可以更改此代码以按样本数量进行修剪吗?假设你修剪 n_node_samples 小于 5。我尝试通过 n_node_samples 小于阈值切换决策 [index] 但它不起作用。我已经在这个问题上浪费了几个星期,并尝试调整其他解决方案,但到目前为止没有运气。
  • 您不需要为此进行修剪,您可以通过在训练期间简单地设置 min_samples_split=5 来确保所有节点都有超过 5 个样本。例如,请参见此处:scikit-learn.org/stable/modules/generated/…
【解决方案2】:

DecisionTreeClassifier(max_leaf_nodes=8) 指定(最多)8 个叶子,因此除非树生成器有其他理由停止,否则它将达到最大值。

在所示示例中,与其他 3 个叶子 (>50) 相比,8 个叶子中的 5 个具有非常少量的样本 (min_samples_leaf 或min_samples_split 以更好地指导训练,这可能会消除有问题的叶子。例如,至少 5% 的样本使用值 0.05

【讨论】:

  • 这只是一个可重现的示例,旨在显示我的问题,显然不是我的真实代码......我知道决策树的各种设置,但是,sklearn 到目前为止只是缺少任何后修剪选项。
【解决方案3】:

我在这里发布的代码有问题,所以我对其进行了修改并不得不添加一小部分(它处理双方相同但仍然存在比较的情况):

from sklearn.tree._tree import TREE_LEAF, TREE_UNDEFINED

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        inner_tree.feature[index] = TREE_UNDEFINED
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove leaves if both 
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)

【讨论】:

    猜你喜欢
    • 2019-05-22
    • 2019-10-03
    • 2019-05-06
    • 2018-01-21
    • 1970-01-01
    • 2020-05-23
    • 2017-12-20
    • 2020-02-05
    • 2020-04-01
    相关资源
    最近更新 更多