【发布时间】:2020-05-04 05:07:31
【问题描述】:
我有以下从 sci-kit 学习分类树生成代码的函数:
def mxTreeToCode(tree, feature_names, mx_name = 'mxTree', rm_file = False):
# Remove pre-existent file
if rm_file:
import os
try:
os.remove('./tree.py')
except OSError:
pass
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
file = open('tree.py', 'a')
file.write('def ' + mx_name + '(x):'+ '\n')
#col_name = ''
def recurse(node, depth):
global col_name
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
file.write(indent +"if x['"+ name + "'] <= " + str(threshold) + ':' + '\n')
col_name += "'"+name + '_' + '<=' + str(threshold) +"'"
recurse(tree_.children_left[node], depth + 1)
file.write(indent + "else: # if x['"+ name +"'] > " + str(threshold) + '\n')
col_name += "'"+name + '_' + '>' + str(threshold) +"'"
recurse(tree_.children_right[node], depth + 1)
else:
file.write(indent + 'return '+str(col_name) + '\n')
#print(col_name)
col_name = ""
recurse(0, 1)
file.close()
有了这个,我在给定分类树的文件“tree.py”上获得以下输出:
def mxTree(x):
if x['V1'] <= 0.5:
if x['V2'] <= 0.5:
return 'V1_<=0.5''V2_<=0.5'
else: # if x['V2'] > 0.5
return 'V2_>0.5'
else: # if x['V1'] > 0.5
return 'V1_>0.5'
虽然我可以累积 IF 侧的条件并返回添加的条件,但当 IF 和 ELSE(树节点的左侧/右侧)跟随时,我无法进行累积:
def mxTree(x):
if x['V1'] <= 0.5:
if x['V2'] <= 0.5:
return 'V1_<=0.5''V2_<=0.5'
else: # if x['V2'] > 0.5
return 'V1_<=0.5''V2_>0.5' # 'V1<=0.5' must be added
else: # if x['V1'] > 0.5
return 'V1_>0.5'
如果有任何建议,我将不胜感激。
【问题讨论】:
标签: python recursion scikit-learn tree classification