decisionTree

一、引言

k-近邻算法可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内
在含义,决策树的主要优势就在于数据形式非常容易理解。
而决策树算法能够读取数据集合,构建决策树。决策树很多任务都
是为了数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,机器学习算法最终将使用这些机器从数据集中创造的规则。
专家系统中经常使用决策树,而且决策树给出结果往往可以匹敌在当前领域具有几十年工作经验的人类专家。

决策树的一般流程
(1) 收集数据:可以使用任何方法。
(2) 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
(3) 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
(4) 训练算法:构造树的数据结构。
(5) 测试算法:使用经验树计算错误率。
(6) 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据
的内在含义。

机器学习中采用的ID3算法划分数据集,每次划分数据集时我们只选取一个特征属性,如果训练集中存在20个特征,第一次我们选择哪个特征作为划分的参考属性呢?

###二、实现

优点:计算复杂度不高,输出结构易于理解,对中间值的缺失不敏感,可以处理不相关特征数据
缺点:可能会产生过度匹配问题

  1. 熵表示随机变量的不确定性。
  2. 条件熵表示在一个条件下,随机变量的不确定性
  3. 信息增益=熵-条件熵

python 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from math import log
import operator
import matplotlib.pyplot as plt #绘图
# 数据源
def CreatDataSet():
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels

# 1.0 计算信息熵 熵越多,混合的数据越多,分类越多 熵越高
def calcsShangnonEnt(dataSet):
numEntries = len(dataSet)
dic={}
for data in dataSet:
currentVec = data[-1]
if currentVec not in dic.keys():
dic[currentVec]=0
dic[currentVec]+=1
ShangnonEnt = 0.0
for key in dic:
prob = float(dic[key])/numEntries
ShangnonEnt -= prob * log(prob,2)
return ShangnonEnt

# 2.0 划分数据集
def splitDataSet(dataSet,axis,value):
retDataSet= []
for featVec in dataSet:
# print(featVec[axis])
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #去除特征 第i列就去除第i列数据
reducedFeatVec.extend(featVec[axis+1:]) #返回去除特征值后的List
retDataSet.append(reducedFeatVec)
return retDataSet


#3.0 正式使用
#满足的要求: 1.数据必须是一种由列表[]组成的列表[],而且所有的列表元素要有相同的数据长度;
# 2.数据的最后一列或者每个实例的最后一个元素是当前实例的列别标签
# 熵,条件熵增强理解: https://blog.csdn.net/xwd18280820053/article/details/70739368
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0])-1 #得到特征值 -1 ,最后一个是类别
baseEntropy = calcsShangnonEnt(dataSet) #计算整个数据集的原始香农熵,用于与划分之后的数据集计算的熵进行比较
bestInfoGain=0.0;bestFeature = -1
for i in range(numFeatures):
featList= [example[i] for example in dataSet] #将dataset的每条数据的第i列出来
uniqueVals = set(featList) #去除重复值
newEntropy = 0.0
for value in uniqueVals: #遍历当前特征值的所有唯一属性值
subDataSet = splitDataSet(dataSet,i, value) #对每个特征划分一次数据集
prob = len(subDataSet)/float(len(dataSet)) #计算p
newEntropy +=prob * calcsShangnonEnt(subDataSet)#计算条件熵
# print(newEntropy)
infoGain = baseEntropy- newEntropy #计算信息增益,在一个条件下,信息不确定性减少的程度
if(infoGain>bestInfoGain):
bestInfoGain = infoGain
bestFeature=i
return bestFeature

def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount:
classCount[vote] = 0;
classCount[vote]+=1
soredClassCount = sorted(clssCount.items(),key=operator.itemgetter(1),reverse=True)
return soredClassCount[0][0]


#递归 分类后的结果
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0])==len(classList):
return classList[0]
if len(dataSet[0])==1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLable= labels[bestFeat]
myTree = {bestFeatLable:{}}
del(labels[bestFeat])
featValues = [data[bestFeat] for data in dataSet]
uniqueVals = set(featValues)
for val in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLable][val] = createTree(splitDataSet(dataSet,bestFeat,val),subLabels)
return myTree

三、绘图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
###############绘图############

#获取叶节点
def getNumLeafs(myTree):
numLeafs =0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs +=1
return numLeafs

def getTreeDepth(myTree):
maxDepth = 0
firstStr= list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth =1+getTreeDepth((secondDict[key]))
else:
thisDepth =1
if thisDepth>maxDepth:maxDepth = thisDepth
return maxDepth
def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
arrow_args = dict(arrowstyle="<-")
leafNode = dict(boxstyle="round4", fc="0.8")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
# annotate是关于一个数据点的文本
# nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )


def plotTree(myTree,parentPt,nodeTxt):
numLeafs= getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff+(1.0 +float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff -1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff +1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff +1.0/plotTree.totalD

def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])# 定义横纵坐标轴,无内容
#createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # 绘制图像,无边框,无坐标轴
createPlot.ax1 = plt.subplot(111, frameon=False)
plotTree.totalW = float(getNumLeafs(inTree)) #全局变量宽度 = 叶子数
plotTree.totalD = float(getTreeDepth(inTree)) #全局变量高度 = 深度
#图形的大小是0-1 ,0-1
plotTree.xOff = -0.5/plotTree.totalW; #例如绘制3个叶子结点,坐标应为1/3,2/3,3/3
#但这样会使整个图形偏右因此初始的,将x值向左移一点。
plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()


##############绘图结束###############

绘图结构

坚持记录世界,分享世界