Decision tree code:
# coding=utf-8
import math
import operator
import matplotlib.pyplot as plt

'''
Decision tree classification of a marine-animal dataset.
'''

def createDataSet():
    '''
    Column 1: can survive without surfacing (no surfacing)
    Column 2: has flippers (flippers)
    Column 3: whether it is a fish (the class label)
    '''
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no']
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
Print the data:
dataSet, labels = createDataSet()
print(dataSet)
print(labels)
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']] ['no surfacing', 'flippers']
Compute the entropy of a given dataset:
# Compute the Shannon entropy of a dataset
def calcShannoEnt(dataSet):
    numEntries = len(dataSet)              # number of samples
    labelsCounts = {}                      # count per class label
    for featVec in dataSet:                # tally each label
        currentLabel = featVec[-1]
        if currentLabel not in labelsCounts.keys():
            labelsCounts[currentLabel] = 0
        labelsCounts[currentLabel] += 1
    shannonEnt = 0.0                       # accumulator for the entropy
    for key in labelsCounts:               # sum -p * log2(p) over labels
        prob = labelsCounts[key] / numEntries
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt
shannonEnt = calcShannoEnt(dataSet)
print(shannonEnt)
Output: 0.9709505944546686
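As a sanity check, the same value falls straight out of the entropy formula: the dataset has two 'yes' and three 'no' labels, so H = -(2/5)·log2(2/5) - (3/5)·log2(3/5). A minimal hand computation:

import math

# 2 of the 5 samples are 'yes', 3 are 'no'
h = -(2/5) * math.log(2/5, 2) - (3/5) * math.log(3/5, 2)
print(h)   # ≈ 0.9710, matching the output above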
Split the dataset on a given feature:
def splitDateSet(dataSet, axis, value):
    # Keep the rows whose feature at index `axis` equals `value`,
    # and strip that feature column out of each kept row
    retDateSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]             # elements before the split column
            reduceFeatVec.extend(featVec[axis + 1:])   # elements after the split column
            retDateSet.append(reduceFeatVec)
    return retDateSet
print(splitDateSet(dataSet, 0, 1))
print(splitDateSet(dataSet, 0, 0))
Output:
[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 'no'], [1, 'no']]
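Splitting on the second feature works the same way; a quick check (the expected outputs below follow directly from the five rows of the dataset):

print(splitDateSet(dataSet, 1, 1))   # [[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]
print(splitDateSet(dataSet, 1, 0))   # [[1, 'no']]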
Choose the best feature to split on:
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1        # number of features (last column is the label)
    baseEntropy = calcShannoEnt(dataSet)     # entropy of the whole dataset
    bestInfoGain = 0.0                       # best information gain found so far
    bestFeat = -1                            # index of the best feature
    for i in range(numFeatures):             # loop over the features
        featList = [example[i] for example in dataSet]   # all values of feature i
        uniqueVals = set(featList)           # the distinct values this feature takes
        newEntropy = 0.0
        for value in uniqueVals:             # weighted entropy after splitting on i
            subDataSet = splitDateSet(dataSet, i, value)
            prob = len(subDataSet) / len(dataSet)
            newEntropy += prob * calcShannoEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeat = i
    return bestFeat
print(chooseBestFeatureToSplit(dataSet))
Output: 0
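Feature 0 wins because its information gain is the larger of the two. A small loop over both features, reusing the functions defined above, makes the comparison explicit:

for i in range(2):
    newEntropy = 0.0
    for value in set(example[i] for example in dataSet):
        subDataSet = splitDateSet(dataSet, i, value)
        newEntropy += len(subDataSet) / len(dataSet) * calcShannoEnt(subDataSet)
    print(i, calcShannoEnt(dataSet) - newEntropy)
# feature 0: gain ≈ 0.4200;  feature 1: gain ≈ 0.1710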
Build the tree:
def majorityCnt(classList):
    # Return the most common class label
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    labels = labels[:]                                   # copy, so the caller's label list is not mutated
    classList = [example[-1] for example in dataSet]     # all class labels
    if classList.count(classList[0]) == len(classList):  # all labels identical: stop splitting
        return classList[0]
    if len(dataSet[0]) == 1:                             # no features left: majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)         # pick the best feature
    bestFeatLabel = labels[bestFeat]                     # label of the best feature
    myTree = {bestFeatLabel: {}}                         # nested-dict representation of the tree
    del(labels[bestFeat])                                # this feature has been used, remove it
    featValues = [example[bestFeat] for example in dataSet]   # values of the best feature
    uniqueVals = set(featValues)                         # its distinct values
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDateSet(dataSet, bestFeat, value), subLabels)
    return myTree
myTree = createTree(dataSet, labels)
print(myTree)
Output:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
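The tree is nothing but nested dicts: each internal node is a one-key dict keyed by the feature label, and that key's value maps feature values to subtrees or leaf labels. Indexing into the tree built above shows both cases:

print(myTree['no surfacing'][0])   # no  (a leaf label)
print(myTree['no surfacing'][1])   # {'flippers': {0: 'no', 1: 'yes'}}  (a subtree)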
Plot the decision tree:
def getNumLeafs(myTree):
    # Count the leaf nodes
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    # Note: in Python 3, myTree.keys()[0] raises
    # "TypeError: 'dict_keys' object does not support indexing",
    # because dict_keys is a view object. Convert it to a list first
    # and then take the index. In Python 2, myTree.keys()[0] works directly.
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # If the value is a dict it is an internal node, otherwise a leaf
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    # Compute the depth of the tree
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    # Label the edge between parent and child at its midpoint
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    # The first key tells you which feature was split on
    numLeafs = getNumLeafs(myTree)   # determines the x width of this subtree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # A dict means a subtree: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # A leaf node: plot it directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)   # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
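The plotting helpers can be exercised directly on the tree built earlier:

createPlot(myTree)
# opens a matplotlib window with 'no surfacing' at the root,
# a 'no' leaf on the 0 branch and the 'flippers' subtree on the 1 branch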
Classification with the decision tree:
# Walk the tree, looking up the test vector's value at each split node,
# until a leaf label is reached
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)   # which feature this node splits on
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):        # internal node: recurse into the subtree
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:                                    # leaf: its value is the class label
        classLabel = valueOfFeat
    return classLabel
print(classify(myTree, labels, [1, 0]))
Output: no
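A sample that does not surface and has flippers follows the other branch of the tree:

print(classify(myTree, labels, [1, 1]))   # yes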
Storing the decision tree with the pickle module:
# Serialize the decision tree to disk
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

# Load a stored decision tree
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
storeTree(myTree, 'classifierStorage.txt')
print(grabTree('classifierStorage.txt'))
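Since pickle round-trips plain dicts exactly, the loaded tree compares equal to the original:

print(grabTree('classifierStorage.txt') == myTree)   # True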
A decision tree example:
It contains many patients' eye-condition observations together with the contact lens type the doctor recommended. The lens types are hard, soft, and not suitable for contact lenses (no lenses). The data comes from the UCI database.
young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses
# coding=utf-8
import test

if __name__ == '__main__':
    fr = open('lenses.txt')
    lenses = [item.strip().split('\t') for item in fr.readlines()]
    # print(lenses)
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    myTree = test.createTree(lenses, lensesLabels)
    print(myTree)
    test.createPlot(myTree)
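Assuming the walkthrough code above is saved as test.py (which is what `import test` implies), the lenses tree can answer queries the same way as before. Since ID3 fits this small training set, classifying one of the training rows should return its recorded label:

# Hypothetical query: a young, myopic, non-astigmatic patient with a normal tear rate
print(test.classify(myTree, ['age', 'prescript', 'astigmatic', 'tearRate'],
                    ['young', 'myope', 'no', 'normal']))   # expected: soft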