Ivan
  • 首页
  • 友链
  • 留言
  • 关于

Decision tree

Ivan 发布于 2022-04-01

  • 📒 Machine Learning
  • 📒 Classification
  • 🏷️ Python
  • 🏷️ Machine Learning
  • 🏷️ Decision tree

决策树

本章内容
决策树简介
在数据集中度量一致性
使用递归构造决策树
使用Matplotlib绘制树形图

决策树(Decision tree)是一种非参数的监督学习方法,它主要用于分类和回归。决策树的目的是构造一种模型,使之能够从样本数据的特征属性中,通过学习简单的决策规则——IF THEN规则,也可以认为是定义在特征空间与类空间上的条件概率分布。本文主要讲解如何在实际环境中应用决策树算法,同时涉及如何使用Python工具和相关机器学习术语,算法原理详见决策树原理详解。
决策树算法主要包括三个部分:特征选择、树的生成、树的剪枝。常用算法有 ID3、C4.5、CART。

Github代码获取

https://github.com/Ivan020121/Machine-Learning/tree/main/Decision%20tree

决策树的构造

决策树学习的算法通常是一个递归地选择最优特征,并根据该特征对训练数据进行分割,使得各个子数据集有一个最好的分类的过程。这一过程对应着对特征空间的划分,也对应着决策树的构建。

决策树
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配的问题。
适用数据类型:数值型和标称型。

首先:确定当前数据集上的决定性特征,为了得到该决定性特征,必须评估每个特征,完成测试之后,原始数据集就被划分为几个数据子集,这些数据子集会分布在第一个决策点的所有分支上,如果某个分支下的数据属于同一类型,则无需进一步对数据集进行分割,如果不属于同一类,则要重复划分数据子集,直到所有相同类型的数据均在一个数据子集内。

决策树的一般流程
(1) 收集数据:可以使用任何方法。比如想构建一个相亲系统,我们可以从媒婆那里,或者通过参访相亲对象获取数据。根据他们考虑的因素和最终的选择结果,就可以得到一些供我们利用的数据了。
(2)准备数据:收集完的数据,我们要进行整理,将这些所有收集的信息按照一定规则整理出来,并排版,方便我们进行后续处理。
(3) 分析数据:可以使用任何方法,决策树构造完成之后,我们可以检查决策树图形是否符合预期。
(4) 训练算法:这个过程也就是构造决策树,同样也可以说是决策树学习,就是构造一个决策树的数据结构。
(5) 测试算法:使用经验树计算错误率。当错误率达到了可接收范围,这个决策树就可以投放使用了。
(6)使用算法:此步骤可以使用适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

创建分支的伪代码createBranch()如下图所示:

检测数据集中每个子项是否属于同一类:
    If so return 类标签:
    Else
        寻找划分数据集的最好特征
        划分数据集
        创建分支节点
            for 每个划分的子集
                调用函数createBranch()并增加返回结果到分支节点中
            return 分支节点

信息增益

划分数据集的大原则是:将无序数据变得更加有序,但是各种方法都有各自的优缺点,信息论是量化处理信息的分支科学,在划分数据集前后信息发生的变化称为信息增益,获得信息增益最高的特征就是最好的选择,,集合信息的度量方式称为香农熵,或者简称熵。

计算给定数据集的香农熵

from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # the number of unique elements and their occurance
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        # log base 2
        shannonEnt -= prob * log(prob,2)
    return shannonEnt

得到熵之后,我们就可以按照获取最大信息增益的方式划分数据集。另一个度量集合无序性的方法是基尼不纯度(Gini impurity),简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他组里的概率。

划分数据集

分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断按照哪个特征划分数据集是最好的划分方式。
按照给定特征划分数据集

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # chop out axis used for splitting
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

函数splitDataSet()需要三个输入参数:待划分的数据集、划分数据集的特征、需要返回的特征的值。
接下来我们需要遍历整个数据集,循环计算香农熵和splitDataSet(),找到最好的划分方式。
选择最好的数据集划分方式

def chooseBestFeatureToSplit(dataSet):
    # the last column is used for the labels
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    # iterate over all the features
    for i in range(numFeatures):
        # create a list of all the examples of this feature
        featList = [example[i] for example in dataSet]
        # get a set of unique values
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
            # calculate the info gain; ie reduction in entropy
        infoGain = baseEntropy - newEntropy
        # compare this to the best gain so far
        if (infoGain > bestInfoGain):
            # if better than current best, set to best
            bestInfoGain = infoGain
            bestFeature = i
            # returns an integer
    return bestFeature

函数chooseBestFeatureToSplit()实现选取特征,划分数据集,计算得出最好的划分数据集的特征。

递归构建决策树

从数据集构造决策树算法所需的子功能模块工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分,第一次划分之后,数据将被向下传递到树分支的下一个节点,在此节点在此划分数据,因此可以使用递归的原则处理数据集。
递归结束的条件是:程序完全遍历所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类,如果所有实例具有相同的分类,则得到一个叶子节点或者终止块,任何到达叶子节点的数据必然属于叶子节点的分类。
创建树的函数代码

def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        # stop splitting when all of the classes are equal
        return classList[0]
    # stop splitting when there are no more features in dataSet
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        # copy all of labels, so trees don't mess up existing labels
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

函数createTree()使用两个输入参数:数据集合和标签列表。标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但是为了给出数据明确的含义,我们将它作为一个输入参数提供。

在Python中使用Matplotlib注解绘制树形图

通过字典表示决策树非常不易于理解,而且直接绘制图形也比较困难。因此我们使用Matplotlib库创建树形图,决策树的主要优点就是直观易于理解,如果不能将其直观的显示出来,就无法发挥其优势。

决策树的范例

Matplotlib注解

Matplotlib提供了一个非常有用的注解工具annotations,它可以在数据图像上添加文本注解。通过注解功能绘制树形图,它可以对文字着色并提供多种形状以供选择,而且可以反转箭头将它指向文本框而不是数据点。
使用文本注解绘制树节点

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

代码定义了描述树节点格式的常量,然后定义plotNode()函数执行了实际的绘图功能,该函数需要一个绘图区,该区域由全局变量createPlot.ax1定义。

构造注解树

绘制一棵完整的树需要一些技巧,必须知道有多少个叶节点,以便可以正确定义x轴的长度;我们还需要知道书有多少层,以便可以正确确定y轴的高度。
这里我们定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶节点的数目和树的层数。
获取叶节点的数目和树的层数

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


简单数据集绘制的树形图
plotTree函数

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

函数createPlot()是我们使用的主函数,它调用了ployTree(),函数又依次调用了前面的函数plotMidText()。

测试和存储分类器

使用决策树构建分类器后,将其用于实际数据的分类。

测试算法:使用决策树执行分类

依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,需要使用决策树以及用于构造决策树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点;最后将测试数据定义为叶子节点所属的类型。
使用决策树的分类函数

def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

使用算法:决策树的存储

构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间。然而用创建好的决策树解决分类问题,则可以很快完成。因此,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用Python模块picke序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。任何对象都可以执行序列化操作,字典对象也不例外。
使用pickle模块存储决策树

def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()
    
def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

通过上面的代码,我们可以将分类器存储在硬盘上,而不用每次对数据分类时重新学习一遍,这也是决策树的优点之一,我们可以预先提炼并存储数据集中包含的知识信息,在需要对事物进行分类时再使用这些知识。

总结

优点

(1) 易于理解和解释,决策树可以可视化。
(2) 几乎不需要数据预处理。其他方法经常需要数据标准化,创建虚拟变量和删除缺失值。
(3) 使用树的花费(例如预测数据)是训练数据点(data points)数量的对数。
(4) 可以同时处理数值变量和分类变量。其他方法大都适用于分析一种变量的集合。
(5) 可以处理多值输出变量问题。
(6) 使用白盒模型。如果一个情况被观察到,使用逻辑判断容易表示这种规则。相反,如果是黑盒模型(例如人工神经网络),结果会非常难解释。
(7) 即使对真实模型来说,假设无效的情况下,也可以较好的适用。

缺点

(1) 决策树学习可能创建一个过于复杂的树,并不能很好的预测数据。也就是过拟合。修剪机制(现在不支持),设置一个叶子节点需要的最小样本数量,或者数的最大深度,可以避免过拟合。
(2) 决策树可能是不稳定的,因为即使非常小的变异,可能会产生一颗完全不同的树。这个问题通过decision trees with an ensemble来缓解。
(3) 学习一颗最优的决策树是一个NP-完全问题under several aspects of optimality and even for simple concepts。因此,传统决策树算法基于启发式算法,例如贪婪算法,即每个节点创建最优决策。这些算法不能产生一个全家最优的决策树。对样本和特征随机抽样可以降低整体效果偏差。
(4) 概念难以学习,因为决策树没有很好的解释他们,例如,XOR, parity or multiplexer problems.
(5) 如果某些分类占优势,决策树将会创建一棵有偏差的树。因此,建议在训练之前,先抽样使样本均衡。


🤯本文作者:Ivan
🔗本文链接:https://ivan020121.github.io/2022/04/01/Decisiontree/
🔁版权声明:本站所有文章除特别声明外,均采用©BY-NC-SA许可协议。转载请注明出处!

Naive bayes

Newer

K-nearest neighbor

Older
Ivan

Ivan

学无止境

7
3
11
TOC
  1. 1. Github代码获取
  2. 2. 决策树的构造
    1. 2.1. 信息增益
    2. 2.2. 划分数据集
    3. 2.3. 递归构建决策树
  3. 3. 在Python中使用Matplotlib注解绘制树形图
    1. 3.1. Matplotlib注解
    2. 3.2. 构造注解树
  4. 4. 测试和存储分类器
    1. 4.1. 测试算法:使用决策树执行分类
    2. 4.2. 使用算法:决策树的存储
  5. 5. 总结
    1. 5.1. 优点
    2. 5.2. 缺点
NOTICE

我的动力来自盲目和愚钝。

CATEGORYS
  • Machine Learning (3)
  • Classification (3)
  • 学习笔记 (2)
TAGS
Anaconda Decision tree Linux Machine Learning Markdown MongoDB Naive bayes Python Server VSCode kNN

© 2022 Ivan

Powered by Hexo Theme - flex-block

🌞 浅色 🌛 深色 🤖️ 自动