Question title: Accumulate conditions during recursion on a classification tree
Posted: 2020-05-04 05:07:31
Question:

I have the following function that generates code from a scikit-learn classification tree:

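from sklearn.tree import _tree  # provides the TREE_UNDEFINED sentinel

col_name = ""  # module-level accumulator shared by recurse() via `global`

def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):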

    # Remove pre-existent file
    if rm_file:
        import os
        try:
            os.remove('./tree.py')
        except OSError:
            pass

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    file = open('tree.py', 'a')
    file.write('def ' + mx_name + '(x):'+ '\n') 
    #col_name = ''
    def recurse(node, depth):
        global col_name
        indent = "    " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]

            file.write(indent +"if x['"+ name + "'] <= " + str(threshold) + ':' + '\n')
            col_name += "'"+name + '_' + '<=' + str(threshold) +"'"

            recurse(tree_.children_left[node], depth + 1)


            file.write(indent + "else: # if x['"+ name +"'] > " + str(threshold) + '\n')
            col_name += "'"+name + '_' + '>' + str(threshold) +"'"

            recurse(tree_.children_right[node], depth + 1)


        else:
            file.write(indent + 'return '+str(col_name) + '\n')
            #print(col_name)
            col_name = ""

    recurse(0, 1)
    file.close()

With this, for a given classification tree I get the following output in the file tree.py:

def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            return 'V2_>0.5'
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'

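While I can accumulate the conditions on the IF side and return them concatenated, I can't keep the accumulated conditions when an IF is followed by its ELSE (the left/right children of a tree node). This is the output I want: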

def mxTree(x):
    if x['V1'] <= 0.5:
        if x['V2'] <= 0.5:
            return 'V1_<=0.5''V2_<=0.5'
        else: # if x['V2'] > 0.5
            return 'V1_<=0.5''V2_>0.5' # 'V1<=0.5' must be added
    else: # if x['V1'] > 0.5
        return 'V1_>0.5'
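The root cause is that col_name is a single global string that every leaf resets to "", so by the time the ELSE branch is written, the ancestors' conditions are already gone. A minimal sketch of a common fix, assuming a hypothetical toy tree of nested tuples ((feature, threshold, left, right) for a split, None for a leaf) instead of a real scikit-learn tree: pass the conditions collected so far down as an argument, so each branch works on its own copy.

```python
def leaf_conditions(node, path=()):
    """Return one tuple of conditions per leaf, in left-to-right order.

    `node` is either None (a leaf) or a hypothetical split tuple
    (feature, threshold, left_child, right_child).
    """
    if node is None:
        # Leaf: the path passed down already holds every ancestor's condition.
        return [path]
    feature, threshold, left, right = node
    # Each child receives its own extended tuple; the parent's `path`
    # is never mutated, so the right branch still sees the full prefix.
    return (leaf_conditions(left, path + (f"{feature}_<={threshold}",))
            + leaf_conditions(right, path + (f"{feature}_>{threshold}",)))


# Toy version of the tree from the question: V1 at the root, V2 on its left.
toy_tree = ("V1", 0.5, ("V2", 0.5, None, None), None)
for conditions in leaf_conditions(toy_tree):
    print(conditions)
```

This prints ('V1_<=0.5', 'V2_<=0.5'), then ('V1_<=0.5', 'V2_>0.5'), then ('V1_>0.5',), which are exactly the prefixes the desired output needs.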

Any suggestions would be appreciated.

Question discussion:

    Tags: python recursion scikit-learn tree classification


    Solution 1:

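    Because the left and right subtrees of each node are handled by the same recursion, which shares a single col_name, I created an additional variable, names_list, that saves the conditions accumulated when each node is entered. Before recursing into the right branch, I concatenate that saved prefix back onto col_name: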

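    from sklearn.tree import _tree  # provides the TREE_UNDEFINED sentinel
    
    col_name = ""
    names_list = {}  # node id -> conditions accumulated on entry to that node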
    def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):
    
        # Remove pre-existent file
        if rm_file:
            import os
            try:
                os.remove('./tree.py')
            except OSError:
                pass
    
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        file = open('tree.py', 'a')
        file.write('def ' + mx_name + '(x):'+ '\n') 
    
        def recurse(node, depth):
            global col_name, names_list
            indent = "    " * depth
            names_list[node] = col_name  # snapshot of the prefix above this node
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
    
                file.write(indent +"if x['"+ name + "'] <= " + str(threshold) + ':' + '\n')
                col_name += "'"+name + '_' + '<=' + str(threshold) +"'"
    
                recurse(tree_.children_left[node], depth + 1)
    
    
                file.write(indent + "else: # if x['"+ name +"'] > " + str(threshold) + '\n')
                col_name += names_list[node]  # restore the ancestors' conditions for the right branch
                col_name += "'"+name + '_' + '>' + str(threshold) +"'"
    
                recurse(tree_.children_right[node], depth + 1)
    
    
            else:
                file.write(indent + 'return '+str(col_name) + '\n')
                col_name = ""
    
        recurse(0, 1)
        file.close()
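The names_list snapshot works because it restores the ancestors' prefix before the right branch is visited. An equivalent pattern that keeps one shared, mutable accumulator is backtracking: push a condition before descending and pop it on the way back up, so nothing has to be reset at the leaves. A sketch, assuming a hypothetical toy tree of nested tuples ((feature, threshold, left, right) for a split, None for a leaf) rather than scikit-learn's tree_ arrays; collect_paths is an illustrative name:

```python
def collect_paths(tree):
    """Collect root-to-leaf condition strings using one shared, mutable list."""
    path = []      # shared accumulator, playing the role of col_name
    results = []

    def recurse(node):
        if node is None:
            # Leaf: join the conditions currently on the stack, quoted like
            # the return expressions in the question's generated code.
            results.append("".join(f"'{c}'" for c in path))
            return
        feature, threshold, left, right = node
        path.append(f"{feature}_<={threshold}")
        recurse(left)
        path.pop()  # undo the left condition before taking the right branch
        path.append(f"{feature}_>{threshold}")
        recurse(right)
        path.pop()  # leave `path` as the caller found it

    recurse(tree)
    return results


toy_tree = ("V1", 0.5, ("V2", 0.5, None, None), None)
print(collect_paths(toy_tree))
```

Because every push is matched by a pop, the accumulator is always exactly the path from the root to the current node, with no per-node snapshot dictionary.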
    

    I wonder whether there are other approaches that work.
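One other approach, sketched under stated assumptions: build the generated source as a string and pass the accumulated conditions down as an argument of recurse, so there is no global state at all. The function only reads the feature, threshold, children_left and children_right arrays that the tree_ attribute of a fitted scikit-learn tree exposes; TREE_UNDEFINED = -2 is assumed to match sklearn.tree._tree.TREE_UNDEFINED so the sketch runs without importing it. The name tree_to_source and the SimpleNamespace stand-in for clf.tree_ are hypothetical.

```python
from types import SimpleNamespace

TREE_UNDEFINED = -2  # assumed to match sklearn.tree._tree.TREE_UNDEFINED


def tree_to_source(tree_, feature_names, func_name="mxTree"):
    """Return Python source for a function mapping x to its leaf's conditions.

    `tree_` is any object exposing the parallel arrays `feature`,
    `threshold`, `children_left` and `children_right`.
    """
    lines = [f"def {func_name}(x):"]

    def recurse(node, depth, conditions):
        indent = "    " * depth
        if tree_.feature[node] != TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            lines.append(f"{indent}if x[{name!r}] <= {threshold}:")
            # Left child: ancestors' conditions plus this split's "<=" test.
            recurse(tree_.children_left[node], depth + 1,
                    conditions + [f"{name}_<={threshold}"])
            lines.append(f"{indent}else:  # x[{name!r}] > {threshold}")
            # Right child: gets the same untouched prefix, nothing to restore.
            recurse(tree_.children_right[node], depth + 1,
                    conditions + [f"{name}_>{threshold}"])
        else:
            # Leaf: return every condition on the root-to-leaf path.
            lines.append(f"{indent}return {conditions!r}")

    recurse(0, 1, [])
    return "\n".join(lines) + "\n"


# Hypothetical stand-in for clf.tree_ of the tree in the question:
# node 0 splits on V1, node 1 splits on V2, nodes 2, 3 and 4 are leaves.
fake_tree = SimpleNamespace(
    feature=[0, 1, -2, -2, -2],
    threshold=[0.5, 0.5, -2.0, -2.0, -2.0],
    children_left=[1, 2, -1, -1, -1],
    children_right=[4, 3, -1, -1, -1],
)
source = tree_to_source(fake_tree, ["V1", "V2"])
print(source)
namespace = {}
exec(source, namespace)
print(namespace["mxTree"]({"V1": 0.0, "V2": 1.0}))  # ['V1_<=0.5', 'V2_>0.5']
```

With a real fitted classifier clf, the string returned by tree_to_source(clf.tree_, feature_names) could be written to tree.py instead of exec'd. Returning a list keeps the conditions separate; "".join(...) on it gives the single concatenated string that the question's generated code returns.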
