第3章 决策树

决策树经常处理分类问题,近来的调查表明决策树也是经常使用的数据挖掘算法。

决策树的流程图:

长方形代表判断模块(decision block),椭圆形代表中止模块(terminating block),表示已经得出结论,可以中止运行。

从判断模块引出左右箭头称作分支(branch),它可以到底另一个判断模块或者中止模块。

决策树算法能够读取数据集合,构建决策树

决策树的一个重要任务是为了数据中蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,在这些机器根据数据集创建规则时,就是机器学习的过程。

3.1 决策树的构造

  1. 优点:计算复杂度不高,输出结果易于理解,对中间值确实不敏感,可以处理不相关特征数据。
  2. 缺点:可能会产生过度匹配问题。
  3. 适用数据类型:数值型和标称型。

创建分支的伪代码函数createBranch()如下所示:

检测数据集中的每个子项是否属于同一分类:
    If so return 类标签;
    Else
        寻找划分数据集的最好特征
        划分数据集
        创建分支节点
            for 每个划分的子集
                ## 递归调用createBranch
                调用createBranch并增加返回结果到分支节点中
        return 分支节点

决策树的一般流程:

  1. 收集数据:可以使用任何方法。
  2. 准备数据:树构造只是用于标称型数据,因此数值型数据必须离散化。
  3. 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
  4. 训练算法:构造树的数据结构。
  5. 测试算法:使用经验树计算错误率。
  6. 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

3.1.1 信息增益

在划分数据集之前之后信息发生的变换称为信息增益。

获得信息增益最高的特征就是最好的选择。

熵:定义为信息的期望值

如果待分类的事务划分在多个分类之中,在符号 x i x_i xi的信息定义为:
l ( x i ) = − log ⁡ 2 p ( x i ) l(x_i) = - \log _{2} p(x_i) l(xi)=−log2p(xi)

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
H = − ∑ i = 1 n p ( x i ) log ⁡ 2 p ( x i ) H = - \sum _{i = 1}^{n} p(x_i) \log _{2} p(x_i) H=−i=1∑np(xi)log2p(xi)

计算给定数据集的香农熵

    from math import log

    def calcShannonEnt(dataSet):
        numEntries = len(dataSet)
        labelCount = {}

        ##为所有可能分类创建字典
        for featVec in dataSet:
            currentLabel = featVec[-1]
            if currentLabel not in labelCount.keys():
                labelCount[currentLabel] = 0
            labelCount[currentLabel] += 1
        shannonEnt = 0.0


        for key in labelCount:
            prob = float(labelCount[key]) / numEntries
            shannonEnt -= prob * log(prob, 2)       ##以2为底数求对数
        return shannonEnt
    ##简单鱼鉴定数据集
    def creatDataSet():
        dateSet = [[1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no']]

        labels = ['no surfacing', 'flippers']
        return dateSet, labels

    if __name__ == '__main__':
        dataSet, labels = creatDataSet()
        shannonEnt = calcShannonEnt(dataSet)
        print(shannonEnt)

3.1.2 划分数据集

按照给定特征划分数据集

    def splitDataSet(dataSet, axis, value):
        '''

        :param dataSet: 待划分的数据集
        :param axis: 划分数据集的特征
        :param value: 需要返回的特征的值
        :return:  结果数据集
        '''
        retDataSet = []     ##创建一个新的list对象
        for featVec in dataSet:
            if featVec[axis] == value:
                ##抽取
                reducedFeatVec = featVec[:axis] 
                reducedFeatVec.extend(featVec[axis + 1:])
                retDataSet.append(reducedFeatVec)
        return retDataSet
    
    if __name__ == '__main__':
        dataSet, labels = creatDataSet()
        retDataSet = splitDataSet(dataSet, 0, 1)
        print(retDataSet)

Python列表对象的方法,append()和extend()用于在列表末尾添加元素,但它们的用法和效果有所不同。

  1. append()方法:
    用法:list.append(element)
    参数:element,要添加的单个元素。
    功能:将指定的元素添加到列表的末尾。
    结果:在列表的末尾添加一个新的元素,并扩展列表的长度。
    特点:可以添加任意类型的元素,包括可迭代对象(如列表)。当添加可迭代对象时,整个对象作为单个元素添加到列表中。
  2. extend()方法:
    用法:list.extend(iterable)
    参数:iterable,一个可迭代对象,如列表、元组、字符串等。
    功能:将可迭代对象中的每个元素添加到列表的末尾。
    结果:不会创建一个新的列表对象,而是在原列表的末尾追加元素。
    特点:只能添加可迭代对象的元素值,而不是整个可迭代对象。当添加的是另一个列表时,extend()会将那个列表中的每个元素逐个添加到原列表中。
  3. 总结:
    append()适用于添加单个元素或可迭代对象,但整个可迭代对象被视为单个元素添加。
    extend()适用于添加可迭代对象的每个元素,而不是整个可迭代对象。
    这些方法都不返回新列表,而是直接修改原列表。

选择最好的数据集划分方式

    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = calcShannonEnt(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1

        for i in range(numFeatures):
            ## 创建唯一的分类标签列表
            featList = [example[i] for example in dataSet]
            uniqueVals = set(featList)
            newEntropy = 0.0
            ## 计算每种划分方式的信息熵
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            ## 计算最好的信息增益
            if (infoGain > bestInfoGain) :
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature

    if __name__ == '__main__':
        dataSet, labels = creatDataSet()
        feature = chooseBestFeatureToSplit(dataSet)
        print(feature)

3.1.3 递归构建决策树

其工作原理如下:

创建原始数据集,然后基于最好的属性值划分数据集合,由于特征值可能多余两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将向下递归到树分支的下一个节点,在这个节点上我们再次划分数据。因此我们可以采用递归的原则处理数据集。

递归结束的条件:

程序遍历完所有划分数据集的性质,或者每个分支下的所有实例都有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。

	import operator
	
	def majorityCnt(classList):
	    '''
	
	    :param classList: 分类名称列表
	    :return: 出现次数最多的分类标签
	    '''
	    classCount = {}
	    for vote in classList:
	        if vote not in classCount.keys(): classCount[vote] = 0
	        classCount[vote] += 1
	    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
	    return sortedClassCount[0][0]

创建树的函数代码:

	def createTree(dataSet, labels):
	    '''
	
	    :param dataSet: 数据集
	    :param labels: 标签列
	    :return:
	    '''
	    classList = [example[-1] for example in dataSet]
	    ## 类别完全相同则停止继续划分
	    if classList.count(classList[0]) == len(classList):
	        return classList[0]
	    ## 遍历完所有特征值时返回出现次数最多的
	    if len(dataSet[0]) == 1 :
	        return majorityCnt(classList)
	    bestFeat = chooseBestFeatureToSplit(dataSet)
	    bestFeatLabel = labels[bestFeat]
	    myTree = {bestFeatLabel: {}}
	    ##得到列表包含的所有属性值
	    del(labels[bestFeat])
	    featValues = [example[bestFeat] for example in dataSet]
	    uniqueVals = set(featValues)
	    for value in uniqueVals:
	        subLabels = labels[:]
	        myTree[bestFeatLabel][value]= createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
	    return myTree
	
	if __name__ == '__main__':
	    myDat, labels = creatDataSet()
	    myTree = createTree(myDat, labels)
	    print(myTree)

3.2 在Python中使用Matplotlib注解绘制树形图

3.2.1 Matplotlib注解

Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释。注解通常用于解释数据的内容。

使用文本注解绘制树节点

    import matplotlib
    matplotlib.use('TkAgg')
    import matplotlib.pyplot as plt
    ## 定义文本框的箭头格式
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    arrow_args = dict(arrowstyle="<-")

    def plotNode(nodeTxt, centerPt, parentPt, nodeType):

        '''
        绘制带箭头的注解
        :param nodeTxt:
        :param centerPt:
        :param parentPt:
        :param nodeType:
        :return:
        '''
        createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction',
                                va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

    def createPlot():
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        createPlot.ax1 = plt.subplot(111, frameon= False)
        plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
        plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)

        plt.show()


    if __name__ == '__main__':
        createPlot();

3.2.2 构造注解树

获得叶节点的数目和树的层数

    def getNumLeafs(myTree):
        numLeafs = 0;
        firstStr = list(myTree)[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            ## 测试节点的数据类型是否为字典
            if type(secondDict[key]).__name__=='dict':
                numLeafs += getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs

    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree)[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth

    def retrieveTree(i):
        listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                    {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                    ]
        return listOfTrees[i]


    if __name__ == '__main__':
        listOfTrees = retrieveTree(1)
        print(listOfTrees)
        myTree = retrieveTree(0)
        print(getNumLeafs(myTree))
        print(getTreeDepth(myTree))

plotTree函数

    def plotTree(myTree, parentPt, nodeTxt):
        ## 计算树的宽与高
        numLeafs = getNumLeafs(myTree)
        depth = getTreeDepth(myTree)
        firstStr = list(myTree)[0]
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
        ## 标记子节点属性值
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        ## 减少y偏移值
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                plotTree(secondDict[key], cntrPt, str(key))
            else:
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    
    ##更新后的createPlot
    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        # createPlot.ax1 = plt.subplot(111, frameon= False)
        # plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
        # plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()

    if __name__ == '__main__':
        myTree = retrieveTree(0)
        createPlot(myTree)

3.3 测试和存储分类器

3.3.1 测试算法:使用决策树执行分类

使用决策树的分类函数

    def classify(inputTree, featLabels, testVec):
        firstStr = list(inputTree)[0]
        secondDict = inputTree[firstStr]
        ## 将标签字符串转换为索引
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel

    if __name__ == '__main__':
        myDat, labels = creatDataSet()
        print(labels)
        myTree = treePlotter.retrieveTree(0)
        print(myTree)
        print(classify(myTree, labels, [1, 0]))
        print(classify(myTree, labels, [1, 1]))

3.3.2 使用算法:决策树的存储

使用pickle模块存储决策树:

    def storeTree(inputTree, filename):
        import pickle
        fw = open(filename, 'wb')
        pickle.dump(inputTree, fw)
        fw.close()

    def grabTree(filename):
        import pickle
        fr = open(filename, 'rb')
        return pickle.load(fr)

    if __name__ == '__main__':
        myDat, labels = creatDataSet()
        print(labels)
        myTree = treePlotter.retrieveTree(0)
        storeTree(myTree, './resource/classifierStorage.txt')
        classifyTree = grabTree('./resource/classifierStorage.txt')
        print(classifyTree)

3.4 示例:使用决策树预测隐形眼镜类型

    def use():
        fr = open('./resource/lenses.txt', 'r')
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
        lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
        lensesTree = createTree(lenses, lensesLabels)
        print(lensesTree)
        treePlotter.createPlot(lensesTree)

    if __name__ == '__main__':
        use()

决策树非常好地匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配(overfitting)。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要地叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中。

相关推荐
武子康8 分钟前
大数据-212 数据挖掘 机器学习理论 - 无监督学习算法 KMeans 基本原理 簇内误差平方和
大数据·人工智能·学习·算法·机器学习·数据挖掘
___Dream22 分钟前
【CTFN】基于耦合翻译融合网络的多模态情感分析的层次学习
人工智能·深度学习·机器学习·transformer·人机交互
passer__jw76737 分钟前
【LeetCode】【算法】283. 移动零
数据结构·算法·leetcode
Ocean☾43 分钟前
前端基础-html-注册界面
前端·算法·html
顶呱呱程序1 小时前
2-143 基于matlab-GUI的脉冲响应不变法实现音频滤波功能
算法·matlab·音视频·matlab-gui·音频滤波·脉冲响应不变法
爱吃生蚝的于勒1 小时前
深入学习指针(5)!!!!!!!!!!!!!!!
c语言·开发语言·数据结构·学习·计算机网络·算法
羊小猪~~1 小时前
数据结构C语言描述2(图文结合)--有头单链表,无头单链表(两种方法),链表反转、有序链表构建、排序等操作,考研可看
c语言·数据结构·c++·考研·算法·链表·visual studio
王哈哈^_^2 小时前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
星沁城2 小时前
240. 搜索二维矩阵 II
java·线性代数·算法·leetcode·矩阵
脉牛杂德2 小时前
多项式加法——C语言
数据结构·c++·算法