#-*- coding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import json
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from IPython.display import display, Image
import pydotplus
from sklearn import tree
from sklearn.tree import _tree
from sklearn import tree
import collections
import drawtree
import os
from sklearn.tree._tree import TREE_LEAF
def rules(clf, features, labels, node_index=0):
    """Build a nested dict describing the rules of a fitted decision tree.

    Every node is a dict with a ``'name'`` string; internal nodes also carry
    a ``'children'`` list ordered ``[left, right]``. The result contains only
    dicts, lists and str, so it can be passed straight to ``json.dumps``.

    Parameters
    ----------
    clf : DecisionTreeClassifier
        A tree that has already been fit.
    features, labels : sequences of str
        The names of the features and class labels, respectively.
    node_index : int, optional
        Index of the node to describe; ``0`` (the default) is the root.

    Returns
    -------
    dict
        Internal nodes are named ``'<feature> <= <threshold>'``, the condition
        under which a sample is routed to the left (first) child. Leaves are
        named ``'<count> of <label>, ...'``, one entry per class.
    """
    tree_ = clf.tree_
    node = {}
    left_index = tree_.children_left[node_index]
    if left_index == -1:  # -1 marks a leaf (sklearn's TREE_LEAF): no children
        values = tree_.value[node_index, 0]
        total = float(values.sum())
        # Depending on the scikit-learn version, ``value`` holds per-class
        # weighted counts or per-class fractions (NOTE(review): fractions in
        # newer releases -- confirm for the version in use). Rescaling by the
        # node's weighted sample count yields counts in both cases; the
        # factor is ~1.0 when ``value`` already holds counts.
        if total > 0:
            values = values * (tree_.weighted_n_node_samples[node_index] / total)
        # round() rather than int() truncation so 49.999... prints as 50.
        node['name'] = ', '.join(
            '{} of {}'.format(int(round(count)), label)
            for count, label in zip(values, labels))
    else:
        right_index = tree_.children_right[node_index]
        feature = features[tree_.feature[node_index]]
        threshold = tree_.threshold[node_index]
        # scikit-learn sends samples with X[feature] <= threshold to the left
        # child, so the label must use <=, and the left child comes first.
        node['name'] = '{} <= {}'.format(feature, threshold)
        node['children'] = [rules(clf, features, labels, left_index),
                            rules(clf, features, labels, right_index)]
    return node
def model_json(max_depth=3, random_state=None, json_path='structure.json'):
    """Fit a decision tree on the iris data set and dump its rules as JSON.

    Parameters
    ----------
    max_depth : int, optional
        Maximum depth of the fitted tree (default 3).
    random_state : int, RandomState instance or None, optional
        Forwarded to ``DecisionTreeClassifier``. With the default ``None``
        the fitted tree may differ between runs when several candidate
        splits are equally good; pass an int for a reproducible model.
    json_path : str, optional
        File the rule structure produced by ``rules`` is written to.

    Returns
    -------
    tuple
        ``(clf, result, X, feature_names, y, target_names)`` where ``clf``
        is the fitted classifier and ``result`` the dict built by ``rules``.
    """
    data = load_iris()
    clf = DecisionTreeClassifier(max_depth=max_depth, random_state=random_state)
    clf.fit(data.data, data.target)
    # Build the rule structure once; it is both written out and returned.
    result = rules(clf, data.feature_names, data.target_names)
    with open(json_path, 'w') as f:
        json.dump(result, f)
    # Debug listing of the loaded dataset's attributes. The parenthesized
    # single-argument form prints identically on Python 2 and Python 3.
    print(dir(data))
    return (clf, result, data.data, data.feature_names,
            data.target, data.target_names)
def draw_file(model, dot_file, png_file, X_train, feature_names,
              class_names=None):
    """Export a fitted tree to Graphviz, show it inline and save it as PNG.

    Parameters
    ----------
    model : fitted scikit-learn tree estimator
        Tree passed to ``sklearn.tree.export_graphviz``.
    dot_file : str
        Path the Graphviz DOT source is written to.
    png_file : str
        Path the rendered PNG image is written to.
    X_train : array-like
        Unused; kept so existing callers continue to work.
    feature_names : list of str
        Feature names shown in the split nodes.
    class_names : list of str or None, optional
        Forwarded to ``export_graphviz``; ``None`` is that function's default.
    """
    # With ``out_file`` set, export_graphviz writes the DOT source to that
    # file and returns None, so there is no return value to keep.
    tree.export_graphviz(model, out_file=dot_file,
                         feature_names=feature_names,
                         class_names=class_names,
                         filled=True, rounded=True,
                         special_characters=True)
    graph = pydotplus.graph_from_dot_file(dot_file)
    # Render once and reuse the bytes. The inline preview and the saved file
    # are then the same image, and Graphviz is not invoked a second time.
    png_bytes = graph.create_png()
    display(Image(png_bytes))
    with open(png_file, 'wb') as f:
        f.write(png_bytes)
if __name__ == '__main__':
    # Bind the rule structure to ``structure``, not ``model_json``, so the
    # module-level function is not shadowed by its own return value.
    model, structure, X_train, feature_names, label, class_names = model_json()
    dot_file = "unpruned.dot"
    png_file = "unpruned.png"
    draw_file(model, dot_file, png_file, X_train, feature_names)
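When you run the code above several times, you get one of the following two slightly different models: no `random_state` is passed to `DecisionTreeClassifier`, so when several splits are equally good, the one chosen can change from run to run.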
-----------------------------------------------------------------------------------------------------------------------------------------------------------