机器学习——构建决策树

第1关:返回分类次数最多的分类名称

python 复制代码
import operator

def majorityCnt(classList):
    classCount = {}
    for i in classList:
        if i not in classCount:
            classCount[i] = 0
        classCount[i] += 1

    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

第2关:创建树函数

python 复制代码
from ex03_lib import majorityCnt,splitDataSet,chooseBestFeatureToSplit

def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]  #获取数据集的所有类别
    #### 请补充完整代码 ####
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    #######################
    return myTree

第3关:获取叶子节点数目

python 复制代码
def getNumLeafs(myTree):
    numLeafs = 0
    if type(myTree).__name__ == 'dict':
        fi = list(myTree.keys())[0]
        se = myTree[fi]
        for i in se.keys():
            if type(se[i]).__name__ == 'dict':
                numLeafs += getNumLeafs(se[i])
            else:
                numLeafs += 1
    else:
        numLeafs += 1
    return numLeafs

第4关:获取树的层数

python 复制代码
def getTreeDepth(myTree):
    maxDepth = 0
    #### 请补充完整代码 ####
    fi = list(myTree.keys())[0]
    se = myTree[fi]
    for i in se.keys():
        if type(se[i]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(se[i])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    #######################
    return maxDepth

第5关:注解树节点

python 复制代码
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
#定义决策树决策结果的特征,以字典的形式定义  
#下面的字典定义也可写作 decisionNode={boxstyle:'sawtooth',fc:'0.8'}  
#boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细  
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    #annotate是关于一个数据点的文本  
    #nodeTxt为要显示的文本,centerPt为文本的中心点,parentPt为指向文本的点 
    #### 请补充完整代码 ####
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
    #######################
def createPlot():     
    fig = plt.figure(1,facecolor='white') # 定义一个画布,背景为白色
    fig.clf() # 把画布清空
    #createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图,
    #111表示figure中的图有1行1列,即1个,最后的1代表第一个图 
    #frameon表示是否绘制坐标轴矩形 
    #### 请补充完整代码 ####
    createPlot.ax1 = plt.subplot(111,frameon=False) 
    plotNode('a decision node',(0.2,0.2),(0.6,0.8),decisionNode) 
    plotNode('a leaf node',(0.6,0.1),(0.8,0.8),leafNode) 
    plt.show()
    #######################

第6关:绘制树形图

python 复制代码
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from ex03_lib import plotNode,getNumLeafs,getTreeDepth

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)                      #当前树的叶子数
    depth = getTreeDepth(myTree)                         #没有用到这个变量
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    #cntrPt是文本中心点,parentPt指向文本中心点 
    #### 请补充完整代码 ####
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)                 #画分支上的键
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD      #从上往下画
    for key in secondDict.keys():
    #如果是字典则是一个判断(内部)结点
        if type(secondDict[key]).__name__=='dict': 
            plotTree(secondDict[key],cntrPt,str(key))       
        else:   #打印叶子结点
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    #######################

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    #### 请补充完整代码 ####
    axprops = dict(xticks=[], yticks=[])                   #定义横纵坐标轴  
    createPlot.ax1 = plt.subplot(111, frameon=False) 
    plotTree.totalW = float(getNumLeafs(inTree))       #全局变量宽度 = 叶子数
    plotTree.totalD = float(getTreeDepth(inTree))      #全局变量高度 = 深度
    #图形的大小是0-1 ,0-1
    plotTree.xOff = -0.5/plotTree.totalW;  
    #例如绘制3个叶子结点,坐标应为1/3,2/3,3/3
    #但这样会使整个图形偏右因此初始的,将x值向左移一点。
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    #######################
    plt.show()
相关推荐
新缸中之脑5 分钟前
Llama 3.2 安卓手机安装教程
前端·人工智能·算法
人工智障调包侠6 分钟前
基于深度学习多层感知机进行手机价格预测
人工智能·python·深度学习·机器学习·数据分析
开始King1 小时前
Tensorflow2.0
人工智能·tensorflow
Elastic 中国社区官方博客1 小时前
Elasticsearch 开放推理 API 增加了对 Google AI Studio 的支持
大数据·数据库·人工智能·elasticsearch·搜索引擎
infominer1 小时前
RAGFlow 0.12 版本功能导读
人工智能·开源·aigc·ai-native
涩即是Null1 小时前
如何构建LSTM神经网络模型
人工智能·rnn·深度学习·神经网络·lstm
本本的小橙子1 小时前
第十四周:机器学习
人工智能·机器学习
励志成为美貌才华为一体的女子2 小时前
《大规模语言模型从理论到实践》第一轮学习--第四章分布式训练
人工智能·分布式·语言模型
学步_技术2 小时前
自动驾驶系列—自动驾驶背后的数据通道:通信总线技术详解与应用场景分析
人工智能·机器学习·自动驾驶·通信总线
winds~2 小时前
自动驾驶-问题笔记-待解决
人工智能·笔记·自动驾驶