【Python机器学习】树回归——树剪枝

如果一棵树节点过多,表明该模型可能对数据进行了过拟合。

通过降低决策树的复杂度来避免过拟合的过程称为剪枝 。提过提前终止条件,实际上就是在进行一种所谓的预剪枝 ;另一种形式的剪枝需要使用测试集和训练集,称作后剪枝

预剪枝

树构建算法其实对输入的参数tolS和tolN非常敏感,如果使用其他值将不容易达到这么好的效果。为了说明这一点,可以尝试:

python 复制代码
print(createTree(myDat,ops=(0,1)))

与之前的只包含两个节点的树相比,这里构建的树过于臃肿,它甚至为数据集中的每个样本都分配了一个叶节点。

用新数据构建新的树:

这里构建的树都有很多叶节点,产生这个现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平方值,或许也能得到仅有两个叶节点组成的数:

python 复制代码
print(createTree(myDat,ops=(10000,4)))

然而,通过不断修改条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要什么样的结果。这正是机器学习所关注的内容,计算机应该可以给出总体的概貌。

后剪枝

使用后剪枝的方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。

函数prune()的伪代码如下:

基于已有的树切分测试数据:

如果存在任一子树是一棵树,则在该子集递归剪枝过程

计算将当前两个叶节点合并后的误差

计算不合并的误差

如果合并会降低误差的话,就将叶节点合并

实际代码实现:

python 复制代码
def isTree(obj):
    return (type(obj).__name__=='dict')

def getMean(tree):
    if isTree(tree['right']):
        tree['right']=getMean(tree['right'])
    if isTree(tree['left']):
        tree['left']=getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

def prune(tree,testData):
    if shape(testData)[0]==0:
        return getMean(tree)
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
    if isTree(tree['left']):
        tree['left']=prune(tree['left'],lSet)
    if isTree(tree['right']):
        tree['right']=prune(tree['right'],rSet)
    if not isTree(tree['left']) and not isTree(tree['left']):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        errorNoMerge=sum(power(lSet[:,-1]-tree['left'],2))+sum(power(rSet[:,-1]-tree['right'],2))
        treeMean=(tree['left']+tree['right'])/2.0
        errorMerge=sum(power(testData[:,-1]-treeMean,2))
        if errorMerge<errorNoMerge:
            print('merging')
            return treeMean
        else:
            return tree
    else:
        return tree

上述代码中,包含3个函数:isTree()、getMean()、prune()。

其中isTree()用于测试输入变量是否为一棵树,返回布类型结果。

getMean()是一个递归函数,它从上往下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值),在prune()函数中调用该函数时应明确这一点。

prune()是主函数,它有两个参数:待剪枝的树与剪枝所需的测试数据testData。prune()函数首先需要确认测试集是否为空。一旦非空,则反复递归调用函数prune()对测试数据进行切分。因为树是由其他数据集(训练集)生成的,所以测试集上会有一些样本与原数据集样本的取值范围不同。一旦出现这种情况,假设发生了过拟合,对树进行剪枝。

接下来要检查某个分枝到底是子树还是节点。如果是子树,就调用prune()来对该子树进行剪枝。在对左右两个分支完成剪枝之后,还需要检查它们是否仍然还是子树。如果两个分支已经不再是子树,那么就可以进行合并。具体做法是对合并前后的误差进行比较,如果合并后的误差比不合并的误差小就进行合并操作,反之则不合并直接返回。

运行测试:

python 复制代码
myDat=loadDataSet('test/ex2.txt')
myDat=mat(myDat)
myTree=createTree(myDat,ops=(0,1))
myDatTest=loadDataSet('test/ex2test.txt')
myMat2Test=mat(myDatTest)
print(prune(myTree,myMat2Test))

从结果上来看,大量的节点已经被剪枝掉了,但是没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般来说,为了寻求最佳模型可以同时使用两种剪枝技术。

相关推荐
DoraBigHead1 小时前
小哆啦解题记——异位词界的社交网络
算法
木头左2 小时前
逻辑回归的Python实现与优化
python·算法·逻辑回归
quant_19863 小时前
R语言如何接入实时行情接口
开发语言·经验分享·笔记·python·websocket·金融·r语言
小牛头#4 小时前
clickhouse 各个引擎适用的场景
大数据·clickhouse·机器学习
杨小扩5 小时前
第4章:实战项目一 打造你的第一个AI知识库问答机器人 (RAG)
人工智能·机器人
whaosoft-1435 小时前
51c~目标检测~合集4
人工智能
雪兽软件5 小时前
2025 年网络安全与人工智能发展趋势
人工智能·安全·web安全
lifallen6 小时前
Paimon LSM Tree Compaction 策略
java·大数据·数据结构·数据库·算法·lsm-tree
元宇宙时间6 小时前
全球发展币GDEV:从中国出发,走向全球的数字发展合作蓝图
大数据·人工智能·去中心化·区块链
小黄人20257 小时前
自动驾驶安全技术的演进与NVIDIA的创新实践
人工智能·安全·自动驾驶