决策树
本章内容
决策树简介
在数据集中度量一致性
使用递归构造决策树
使用Matplotlib绘制树形图
决策树(Decision tree)是一种非参数的监督学习方法,它主要用于分类和回归
。决策树的目的是构造一种模型,使之能够从样本数据的特征属性中,通过学习简单的决策规则——IF THEN规则
,也可以认为是定义在特征空间与类空间上的条件概率分布。本文主要讲解如何在实际环境中应用决策树算法,同时涉及如何使用Python工具和相关机器学习术语,算法原理详见决策树原理详解。
决策树算法主要包括三个部分:特征选择
、树的生成
、树的剪枝
。常用算法有 ID3、C4.5、CART。
Github代码获取
https://github.com/Ivan020121/Machine-Learning/tree/main/Decision%20tree
决策树的构造
决策树学习的算法通常是一个递归
地选择最优特征
,并根据该特征对训练数据进行分割,使得各个子数据集有一个最好的分类的过程。这一过程对应着对特征空间的划分
,也对应着决策树的构建。
决策树
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配的问题。
适用数据类型:数值型和标称型。
首先:确定当前数据集上的决定性特征
,为了得到该决定性特征,必须评估每个特征,完成测试之后,原始数据集就被划分为几个数据子集,这些数据子集会分布在第一个决策点的所有分支上,如果某个分支下的数据属于同一类型
,则无需进一步对数据集进行分割,如果不属于同一类,则要重复划分数据子集,直到所有相同类型的数据均在一个数据子集内。
决策树的一般流程
(1) 收集数据:可以使用任何方法。比如想构建一个相亲系统,我们可以从媒婆那里,或者通过参访相亲对象获取数据。根据他们考虑的因素和最终的选择结果,就可以得到一些供我们利用的数据了。
(2)准备数据:收集完的数据,我们要进行整理,将这些所有收集的信息按照一定规则整理出来,并排版,方便我们进行后续处理。
(3) 分析数据:可以使用任何方法,决策树构造完成之后,我们可以检查决策树图形是否符合预期。
(4) 训练算法:这个过程也就是构造决策树,同样也可以说是决策树学习,就是构造一个决策树的数据结构。
(5) 测试算法:使用经验树计算错误率。当错误率达到了可接收范围,这个决策树就可以投放使用了。
(6)使用算法:此步骤可以使用适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
创建分支的伪代码createBranch()如下图所示:
检测数据集中每个子项是否属于同一类:
If so return 类标签:
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数createBranch()并增加返回结果到分支节点中
return 分支节点
信息增益
划分数据集的大原则是:将无序数据变得更加有序
,但是各种方法都有各自的优缺点,信息论是量化处理信息的分支科学,在划分数据集前后信息发生的变化称为信息增益
,获得信息增益最高的特征就是最好的选择,,集合信息的度量方式称为香农熵,或者简称熵。
计算给定数据集的香农熵
from math import log
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
# the number of unique elements and their occurance
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
# log base 2
shannonEnt -= prob * log(prob,2)
return shannonEnt
得到熵之后,我们就可以按照获取最大信息增益
的方式划分数据集。另一个度量集合无序性的方法是基尼不纯度(Gini impurity),简单地说就是从一个数据集中随机
选取子项,度量其被错误分类到其他组里的概率。
划分数据集
分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断按照哪个特征划分数据集是最好的划分方式。
按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
# chop out axis used for splitting
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
函数splitDataSet()
需要三个输入参数:待划分的数据集、划分数据集的特征、需要返回的特征的值。
接下来我们需要遍历整个数据集,循环计算香农熵和splitDataSet(),找到最好的划分方式。
选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
# the last column is used for the labels
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
# iterate over all the features
for i in range(numFeatures):
# create a list of all the examples of this feature
featList = [example[i] for example in dataSet]
# get a set of unique values
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
# calculate the info gain; ie reduction in entropy
infoGain = baseEntropy - newEntropy
# compare this to the best gain so far
if (infoGain > bestInfoGain):
# if better than current best, set to best
bestInfoGain = infoGain
bestFeature = i
# returns an integer
return bestFeature
函数chooseBestFeatureToSplit()实现选取特征,划分数据集,计算得出最好的划分数据集的特征。
递归构建决策树
从数据集构造决策树算法所需的子功能模块工作原理如下:得到原始数据集,然后基于最好的属性值
划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分,第一次划分之后,数据将被向下传递到树分支的下一个节点,在此节点在此划分数据,因此可以使用递归
的原则处理数据集。
递归结束的条件是:程序完全遍历所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类,如果所有实例具有相同的分类,则得到一个叶子节点或者终止块,任何到达叶子节点的数据必然属于叶子节点的分类。
创建树的函数代码
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
# stop splitting when all of the classes are equal
return classList[0]
# stop splitting when there are no more features in dataSet
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
# copy all of labels, so trees don't mess up existing labels
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
函数createTree()使用两个输入参数:数据集合
和标签列表
。标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但是为了给出数据明确的含义,我们将它作为一个输入参数提供。
在Python中使用Matplotlib注解绘制树形图
通过字典表示决策树非常不易于理解,而且直接绘制图形也比较困难。因此我们使用Matplotlib
库创建树形图,决策树的主要优点就是直观易于理解,如果不能将其直观的显示出来,就无法发挥其优势。
决策树的范例
Matplotlib注解
Matplotlib提供了一个非常有用的注解工具annotations
,它可以在数据图像上添加文本注解
。通过注解功能绘制树形图,它可以对文字着色并提供多种形状以供选择,而且可以反转箭头将它指向文本框而不是数据点。
使用文本注解绘制树节点
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()
代码定义了描述树节点格式的常量,然后定义plotNode()函数执行了实际的绘图功能,该函数需要一个绘图区,该区域由全局变量createPlot.ax1定义。
构造注解树
绘制一棵完整的树需要一些技巧,必须知道有多少个叶节点
,以便可以正确定义x轴的长度;我们还需要知道书有多少层,以便可以正确确定y轴的高度。
这里我们定义两个新函数getNumLeafs()
和getTreeDepth()
,来获取叶节点的数目和树的层数。
获取叶节点的数目和树的层数
def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
简单数据集绘制的树形图
plotTree函数
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = myTree.keys()[0] #the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict
函数createPlot()是我们使用的主函数,它调用了ployTree(),函数又依次调用了前面的函数plotMidText()。
测试和存储分类器
使用决策树构建分类器后,将其用于实际数据的分类。
测试算法:使用决策树执行分类
依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,需要使用决策树以及用于构造决策树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点;最后将测试数据定义为叶子节点所属的类型。
使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel
使用算法:决策树的存储
构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间
。然而用创建好的决策树解决分类问题,则可以很快完成。因此,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用Python模块picke
序列化对象,序列化对象可以在磁盘
上保存对象,并在需要的时候读取出来。任何对象都可以执行序列化操作,字典对象也不例外。
使用pickle模块存储决策树
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'w')
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
通过上面的代码,我们可以将分类器存储在硬盘上,而不用每次对数据分类时重新学习一遍,这也是决策树的优点之一,我们可以预先提炼并存储数据集中包含的知识信息,在需要对事物进行分类时再使用这些知识。
总结
优点
(1) 易于理解和解释,决策树可以可视化。
(2) 几乎不需要数据预处理。其他方法经常需要数据标准化,创建虚拟变量和删除缺失值。
(3) 使用树的花费(例如预测数据)是训练数据点(data points)数量的对数。
(4) 可以同时处理数值变量和分类变量。其他方法大都适用于分析一种变量的集合。
(5) 可以处理多值输出变量问题。
(6) 使用白盒模型。如果一个情况被观察到,使用逻辑判断容易表示这种规则。相反,如果是黑盒模型(例如人工神经网络),结果会非常难解释。
(7) 即使对真实模型来说,假设无效的情况下,也可以较好的适用。
缺点
(1) 决策树学习可能创建一个过于复杂的树,并不能很好的预测数据。也就是过拟合。修剪机制(现在不支持),设置一个叶子节点需要的最小样本数量,或者数的最大深度,可以避免过拟合。
(2) 决策树可能是不稳定的,因为即使非常小的变异,可能会产生一颗完全不同的树。这个问题通过decision trees with an ensemble来缓解。
(3) 学习一颗最优的决策树是一个NP-完全问题under several aspects of optimality and even for simple concepts。因此,传统决策树算法基于启发式算法,例如贪婪算法,即每个节点创建最优决策。这些算法不能产生一个全家最优的决策树。对样本和特征随机抽样可以降低整体效果偏差。
(4) 概念难以学习,因为决策树没有很好的解释他们,例如,XOR, parity or multiplexer problems.
(5) 如果某些分类占优势,决策树将会创建一棵有偏差的树。因此,建议在训练之前,先抽样使样本均衡。
🤯本文作者:Ivan
🔗本文链接:https://ivan020121.github.io/2022/04/01/Decisiontree/
🔁版权声明:本站所有文章除特别声明外,均采用©BY-NC-SA许可协议。转载请注明出处!