【Question title】: Possible to modify/prune learned trees in scikit-learn?
【Posted】: 2016-12-24 10:36:21
【Question】:

With sklearn, the parameters of a fitted tree can be accessed through

tree.tree_.children_left
tree.tree_.children_right
tree.tree_.threshold
tree.tree_.feature

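and so on.

However, attempting to write to these variables raises a "not writable" exception.

Is there any way to modify a learned tree, or to get around the AttributeError (not writable)?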

    Tags: python machine-learning scikit-learn random-forest decision-tree


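    【Answer 1】:

    These attributes are arrays that cannot be overwritten as a whole; assigning a new array to one of them is what raises the "not writable" error. You can still modify the individual elements of the arrays in place. Note that doing so does not make the stored tree any smaller: the pruned nodes remain in the arrays.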
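As a minimal sketch of that distinction, the snippet below fits a small classifier on the iris data (an arbitrary choice for illustration, not part of the original question) and tries both kinds of write. The flag names `rebinding_failed` and `in_place_write_visible` are made up for this example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
tree = clf.tree_

# Rebinding the attribute itself fails: it is exposed as a read-only property.
try:
    tree.children_left = np.zeros(tree.node_count, dtype=np.int64)
    rebinding_failed = False
except AttributeError:
    rebinding_failed = True

# Writing to a single element goes through, because the returned array
# is backed by the tree's own node storage.
root_threshold_before = tree.threshold[0]
tree.threshold[0] = root_threshold_before + 1.0
in_place_write_visible = tree.threshold[0] == root_threshold_before + 1.0

# Restore the fitted value so the model is left unchanged.
tree.threshold[0] = root_threshold_before
```

The element write is visible on the next attribute access, which is what makes in-place pruning possible at all.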

    children_left : array of int, shape [node_count]
        children_left[i] holds the node id of the left child of node i.
        For leaves, children_left[i] == TREE_LEAF. Otherwise,
        children_left[i] > i. This child handles the case where
        X[:, feature[i]] <= threshold[i].
    
    children_right : array of int, shape [node_count]
        children_right[i] holds the node id of the right child of node i.
        For leaves, children_right[i] == TREE_LEAF. Otherwise,
        children_right[i] > i. This child handles the case where
        X[:, feature[i]] > threshold[i].
    
    feature : array of int, shape [node_count]
        feature[i] holds the feature to split on, for the internal node i.
    
    threshold : array of double, shape [node_count]
        threshold[i] holds the threshold for the internal node i.
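The array layout described above can be exercised by routing samples through the tree by hand. This is only a sketch: `find_leaf` is a hypothetical helper written for this example, the iris data is an arbitrary choice, and the cast to float32 assumes that scikit-learn compares inputs in that dtype internally.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

TREE_LEAF = -1  # child id stored for leaves

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
t = clf.tree_

def find_leaf(x, left, right, feature, threshold):
    """Follow the split rules from the root (node 0) down to a leaf."""
    node = 0
    while left[node] != TREE_LEAF:
        if x[feature[node]] <= threshold[node]:
            node = left[node]   # X[:, feature[i]] <= threshold[i]
        else:
            node = right[node]  # X[:, feature[i]] > threshold[i]
    return node

# Fetch each array once; every attribute access builds a new array object.
arrays = (t.children_left, t.children_right, t.feature, t.threshold)

# Compare on float32 values, matching the dtype the tree was trained on.
X32 = X.astype(np.float32)
manual_leaves = np.array([find_leaf(x, *arrays) for x in X32])
```

The result can be compared with `clf.apply(X)`, which returns the leaf id for each sample.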
    

    To prune a decision tree by the number of observations in a node, I use this function. You need to know that the TREE_LEAF constant is equal to -1.

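    TREE_LEAF = -1  # child id that scikit-learn stores for a leaf

    def prune(decisiontree, min_samples_leaf=1):
        """Turn every node holding at most min_samples_leaf samples into a leaf, in place."""
        if decisiontree.min_samples_leaf >= min_samples_leaf:
            raise ValueError('Tree already more pruned')
        decisiontree.min_samples_leaf = min_samples_leaf
        tree = decisiontree.tree_
        for i in range(tree.node_count):
            # n_node_samples[i] is the number of training samples that reached node i
            if tree.n_node_samples[i] <= min_samples_leaf:
                # Cutting both child links makes node i behave as a leaf;
                # its descendants stay in the arrays but become unreachable.
                tree.children_left[i] = TREE_LEAF
                tree.children_right[i] = TREE_LEAF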
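To illustrate why this kind of pruning does not shrink the stored data, the sketch below compares the number of stored nodes with the number of nodes still reachable from the root. `count_reachable` is a hypothetical helper for this example, and the cut-off of 100 samples is just the value used for illustration.

```python
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

TREE_LEAF = -1  # child id stored for leaves

def count_reachable(tree):
    """Count the nodes that can still be reached from the root (node 0)."""
    left, right = tree.children_left, tree.children_right
    stack, seen = [0], 0
    while stack:
        node = stack.pop()
        seen += 1
        if left[node] != TREE_LEAF:
            stack.extend((left[node], right[node]))
    return seen

X, y = load_diabetes(return_X_y=True)
dtr = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)
tree = dtr.tree_

stored_before, reachable_before = tree.node_count, count_reachable(tree)

# Same rule as the function above: nodes with few samples become leaves.
for i in range(tree.node_count):
    if tree.n_node_samples[i] <= 100:
        tree.children_left[i] = TREE_LEAF
        tree.children_right[i] = TREE_LEAF

stored_after, reachable_after = tree.node_count, count_reachable(tree)
print(stored_before, stored_after, reachable_before, reachable_after)
```

`node_count` stays the same while the reachable count drops, so the cut-off subtrees are dead weight rather than freed memory.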
    

    Here is an example that produces graphviz output before and after pruning:

    from sklearn.tree import DecisionTreeRegressor as DTR
    from sklearn.datasets import load_diabetes
    from sklearn.tree import export_graphviz as export

    bunch = load_diabetes()
    data = bunch.data
    target = bunch.target

    dtr = DTR(max_depth=4)
    dtr.fit(data, target)

    # Pass the fitted estimator itself; newer scikit-learn releases expect
    # an estimator here rather than the bare dtr.tree_ object.
    export(decision_tree=dtr, out_file='before.dot')
    prune(dtr, min_samples_leaf=100)
    export(decision_tree=dtr, out_file='after.dot')
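If graphviz is not at hand, the effect of the pruning can also be checked numerically. The following is a rough sketch, not part of the original answer: it applies the same "at most 100 samples" rule with a boolean mask and then inspects which leaves the training samples land in. The variable names are chosen for this example only.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

TREE_LEAF = -1  # child id stored for leaves

X, y = load_diabetes(return_X_y=True)
dtr = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)
leaves_before = np.unique(dtr.apply(X))

# Turn every node with at most 100 training samples into a leaf.
tree = dtr.tree_
small = tree.n_node_samples <= 100
tree.children_left[small] = TREE_LEAF
tree.children_right[small] = TREE_LEAF

# apply() returns the id of the leaf each sample now ends up in.
leaf_ids = dtr.apply(X)
leaves_after = np.unique(leaf_ids)

# For a regressor, the prediction is the value stored at that leaf.
predictions = dtr.predict(X)
print(len(leaves_before), len(leaves_after))
```

Because each node already stores its own value, a former internal node predicts sensibly as soon as its child links are cut. Note that this rule stops at the first node with at most 100 samples; it does not guarantee that every remaining leaf holds at least 100.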
    
