【Python机器学习】决策树的构造——递归构建决策树

我们可以采用递归的原则处理数据集,递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。

我们可以设置算法可以划分的最大分组数目。像是其他决策树算法,比如C4.5和CART,这些算法在运行时并不总是在每次划分分组时都会消耗特征。由于特征数目并不是在每次划分数据分组时都减少,因此这些算法在实际使用时可能引起一定的问题。

目前我们并不需要考虑这些问题,只要在算法开始运行前计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定该叶子节点的分类。

python 复制代码
import operator

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote]=classCount[vote]+1
    sortedClassCount=sorted(classCount.iteritem(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

上述的函数使用分类名称的列表,然后创建键值为classList中唯一值的数据字典,字典对象存储了classList中每个类标签出现的频率,最后利用operator操作键值排序字典,并范围出现次数最多的分类名称。

python 复制代码
def createTree(dataSet,labels):
    classList=[example[-1]for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    featVaues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featVaues)
    for value in uniqueVals:
        subLabels=labels[:]
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

函数createTree()使用两个输入参数:数据集和标签列表。标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但是为了给出数据明确的含义,我们将它作为一个输入参数提供。此外,对数据集的要求这里依然需要满足。

上述代码首先创建了一个名为classList的列表变量,其中包含了数据集的所有类标签。递归函数的第一个停止条件是所有的类标签完全相同,则直接返回该类标签。递归函数的第二个停止条件是使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。由于第二个条件无法简单地返回唯一的类标签,这里使用majorityCnt函数挑选出现次数最多的类别座位返回值。

下一步程序开始创建树,这里使用Python的字典类型存储树的信息,当然也可以声明特殊的数据类型存储树,但这里没有必要。字典变量myTree存储了树的所有信息,这对于之后绘制树状图非常重要。当前数据集选取的最好特征存储在变量bestFeat中,得到列表包含的所有属性值。

最后代码遍历当前选择特征包含的所有属性值,在每个数据及划分上递归调用函数createTree(),得到的返回值将被插入到字典变量myTree中,因此函数终止执行时,字典中将会嵌套很多代表叶子节点信息的字典数据。

测试代码:

python 复制代码
myDat,labels=createDataSet()
print(createTree(myDat,labels))
print(myDat)
相关推荐
肖永威4 分钟前
CentOS环境上离线安装python3及相关包
linux·运维·机器学习·centos
baiduopenmap7 分钟前
百度世界2024精选公开课:基于地图智能体的导航出行AI应用创新实践
前端·人工智能·百度地图
小任同学Alex10 分钟前
浦语提示词工程实践(LangGPT版,服务器上部署internlm2-chat-1_8b,踩坑很多才完成的详细教程,)
人工智能·自然语言处理·大模型
新加坡内哥谈技术16 分钟前
微软 Ignite 2024 大会
人工智能
nuclear201129 分钟前
使用Python 在Excel中创建和取消数据分组 - 详解
python·excel数据分组·创建excel分组·excel分类汇总·excel嵌套分组·excel大纲级别·取消excel分组
江瀚视野43 分钟前
Q3净利增长超预期,文心大模型调用量大增,百度未来如何分析?
人工智能
Lucky小小吴44 分钟前
有关django、python版本、sqlite3版本冲突问题
python·django·sqlite
带多刺的玫瑰1 小时前
Leecode刷题C语言之统计不是特殊数字的数字数量
java·c语言·算法
爱敲代码的憨仔1 小时前
《线性代数的本质》
线性代数·算法·决策树
陪学1 小时前
百度遭初创企业指控抄袭,维权还是碰瓷?
人工智能·百度·面试·职场和发展·产品运营