【Python机器学习】树回归——树剪枝

如果一棵树节点过多,表明该模型可能对数据进行了过拟合。

通过降低决策树的复杂度来避免过拟合的过程称为剪枝 。提过提前终止条件,实际上就是在进行一种所谓的预剪枝 ;另一种形式的剪枝需要使用测试集和训练集,称作后剪枝

预剪枝

树构建算法其实对输入的参数tolS和tolN非常敏感,如果使用其他值将不容易达到这么好的效果。为了说明这一点,可以尝试:

python 复制代码
print(createTree(myDat,ops=(0,1)))

与之前的只包含两个节点的树相比,这里构建的树过于臃肿,它甚至为数据集中的每个样本都分配了一个叶节点。

用新数据构建新的树:

这里构建的树都有很多叶节点,产生这个现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平方值,或许也能得到仅有两个叶节点组成的数:

python 复制代码
print(createTree(myDat,ops=(10000,4)))

然而,通过不断修改条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要什么样的结果。这正是机器学习所关注的内容,计算机应该可以给出总体的概貌。

后剪枝

使用后剪枝的方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。

函数prune()的伪代码如下:

基于已有的树切分测试数据:

如果存在任一子树是一棵树,则在该子集递归剪枝过程

计算将当前两个叶节点合并后的误差

计算不合并的误差

如果合并会降低误差的话,就将叶节点合并

实际代码实现:

python 复制代码
def isTree(obj):
    return (type(obj).__name__=='dict')

def getMean(tree):
    if isTree(tree['right']):
        tree['right']=getMean(tree['right'])
    if isTree(tree['left']):
        tree['left']=getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

def prune(tree,testData):
    if shape(testData)[0]==0:
        return getMean(tree)
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
    if isTree(tree['left']):
        tree['left']=prune(tree['left'],lSet)
    if isTree(tree['right']):
        tree['right']=prune(tree['right'],rSet)
    if not isTree(tree['left']) and not isTree(tree['left']):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        errorNoMerge=sum(power(lSet[:,-1]-tree['left'],2))+sum(power(rSet[:,-1]-tree['right'],2))
        treeMean=(tree['left']+tree['right'])/2.0
        errorMerge=sum(power(testData[:,-1]-treeMean,2))
        if errorMerge<errorNoMerge:
            print('merging')
            return treeMean
        else:
            return tree
    else:
        return tree

上述代码中,包含3个函数:isTree()、getMean()、prune()。

其中isTree()用于测试输入变量是否为一棵树,返回布类型结果。

getMean()是一个递归函数,它从上往下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值),在prune()函数中调用该函数时应明确这一点。

prune()是主函数,它有两个参数:待剪枝的树与剪枝所需的测试数据testData。prune()函数首先需要确认测试集是否为空。一旦非空,则反复递归调用函数prune()对测试数据进行切分。因为树是由其他数据集(训练集)生成的,所以测试集上会有一些样本与原数据集样本的取值范围不同。一旦出现这种情况,假设发生了过拟合,对树进行剪枝。

接下来要检查某个分枝到底是子树还是节点。如果是子树,就调用prune()来对该子树进行剪枝。在对左右两个分支完成剪枝之后,还需要检查它们是否仍然还是子树。如果两个分支已经不再是子树,那么就可以进行合并。具体做法是对合并前后的误差进行比较,如果合并后的误差比不合并的误差小就进行合并操作,反之则不合并直接返回。

运行测试:

python 复制代码
myDat=loadDataSet('test/ex2.txt')
myDat=mat(myDat)
myTree=createTree(myDat,ops=(0,1))
myDatTest=loadDataSet('test/ex2test.txt')
myMat2Test=mat(myDatTest)
print(prune(myTree,myMat2Test))

从结果上来看,大量的节点已经被剪枝掉了,但是没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般来说,为了寻求最佳模型可以同时使用两种剪枝技术。

相关推荐
The_Ticker7 分钟前
CFD平台如何接入实时行情源
java·大数据·数据库·人工智能·算法·区块链·软件工程
Elastic 中国社区官方博客13 分钟前
Elasticsearch 开放推理 API 增加了对 IBM watsonx.ai Slate 嵌入模型的支持
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
jwolf214 分钟前
摸一下elasticsearch8的AI能力:语义搜索/vector向量搜索案例
人工智能·搜索引擎
有Li23 分钟前
跨视角差异-依赖网络用于体积医学图像分割|文献速递-生成式模型与transformer在医学影像中的应用
人工智能·计算机视觉
傻啦嘿哟25 分钟前
如何使用 Python 开发一个简单的文本数据转换为 Excel 工具
开发语言·python·excel
B站计算机毕业设计超人32 分钟前
计算机毕业设计SparkStreaming+Kafka旅游推荐系统 旅游景点客流量预测 旅游可视化 旅游大数据 Hive数据仓库 机器学习 深度学习
大数据·数据仓库·hadoop·python·kafka·课程设计·数据可视化
新加坡内哥谈技术43 分钟前
Mistral推出“Le Chat”,对标ChatGPT
人工智能·chatgpt
爪哇学长43 分钟前
双指针算法详解:原理、应用场景及代码示例
java·数据结构·算法
Dola_Pan1 小时前
C语言:数组转换指针的时机
c语言·开发语言·算法
GOTXX1 小时前
基于Opencv的图像处理软件
图像处理·人工智能·深度学习·opencv·卷积神经网络