【Python机器学习】树回归——连续和离散型特征的树的构建

在树的构建过程中,需要解决多种类型数据的存储问题。这里将使用一部字典来存储树的存储结构,该字典将包含以下4种元素:

1、待切分的特征;

2、待切分的特征值;

3、右子树。当不再需要切分的时候,也可以是单个值;

4、左子树,与右子树类似。

CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一棵子树或者单个值。字典还包含特征和特征值这两个键,它们给出切分算法所有的特征和特征值。还可以用面向对象的编程模式来构建这个数据结构。

例如,用下面的Python代码来建立树节点:

python 复制代码
class treeNode():
    def __init__(self,feat,val,right,left):
        featureToSplitOn=feat
        valueOfSplit=val
        rightBranch=right
        leftBranch=left

Python具有足够的灵活性,可以直接使用字典来存储结构而无须再自定义一个类,从而有效地减少代码量。Python不是一种强类型编码语言,因此接下来会看到,树的每个分枝还可以再包含其他树、数值型数据甚至是向量。

对于回归树和模型树,有一些构建时可以共用的代码。

构建树(createTree())函数的伪代码大致如下:

找到最佳的待切分特征:

如果该节点不能再切分,将该节点存为叶节点

执行二元切分

在右子树调用createTree()方法

在左子树调用createTree()方法

CART算法的实现代码:

python 复制代码
from numpy import *

def loadDataSet(fileName):
    dataMat=[]
    fr=open(fileName)
    for line in fr.readlines():
        curLine=line.strip().split('\t')
        #将每行映射成浮点数
        fltLine=map(float,curLine)
        dataMat.append(fltLine)
    return dataMat

def regLeaf(dataSet):#returns the value used for each leaf
    return mean(dataSet[:,-1])
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    #if all the target variables are the same value: quit and return value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:,featIndex]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don't do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex,bestValue#returns the best feature to split on
                              #and the value used for that split

def binSplitDataSet(dataSet,feature,value):
    mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:][0]
    mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0],:][0]
    return mat0,mat1

def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=[1,4]):
    feat,val=chooseBestSplit(dataSet,leafType,errType,ops)
    if feat == None:
        return val
    retTree={}
    retTree['spInd']=feat
    retTree['spVal'] = val
    lSet,rSet=binSplitDataSet(dataSet,feat,val)
    retTree['left']=createTree(lSet,leafType,errType,ops)
    retTree['right']=createTree(rSet,leafType,errType,ops)
    return retTree

上述代码主要包含3个函数,第一个是loadDataSet(),该函数主要作用是把数据存放在一起,该函数读取第一个以tab键为分隔符的文件,然后将每行的内容保存成一组浮点数。

第二个函数是binSplitDataSet(),该函数有3个参数:数据集合、待切分的特征和该特征的某个值。在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回。

最后一个函数是树构建函数createTree(),它有4个参数:数据集和其他3个可选参数。这些可选参数决定了树的类型:leafTypr给出建立叶节点的函数,errType代表误差计算函数;而ops是一个包含树构建所需其他参数的元组。createTree()是一个递归函数,它首先尝试将数据集分成两个部分,切分由函数chooseBestSplit()完成。如果满足停止条件,chooseBestSplit()将返回None和某类模型的值。如果构建的是回归树,该模型是一个常数。如果是模型时,其模型是一个线性方程。后面可以看到停止条件的作用方式。如果不满足停止条件,chooseBestSplit()将创建一个新的Python字典并将数据集分成两份,在这两份数据集上将分别继续递归调用createTree()函数。

测试代码效果:

python 复制代码
testMat=mat(eye(4))
print(testMat)
mat0,mat1=binSplitDataSet(testMat,1,0.5)
print(mat0)
print(mat1)
相关推荐
lucky_lyovo2 小时前
自然语言处理NLP---预训练模型与 BERT
人工智能·自然语言处理·bert
fantasy_arch2 小时前
pytorch例子计算两张图相似度
人工智能·pytorch·python
七七&5562 小时前
2024年08月13日 Go生态洞察:Go 1.23 发布与全面深度解读
开发语言·网络·golang
java坤坤2 小时前
GoLand 项目从 0 到 1:第八天 ——GORM 命名策略陷阱与 Go 项目启动慢问题攻坚
开发语言·后端·golang
元清加油2 小时前
【Golang】:函数和包
服务器·开发语言·网络·后端·网络协议·golang
健康平安的活着3 小时前
java之 junit4单元测试Mockito的使用
java·开发语言·单元测试
No0d1es3 小时前
电子学会青少年软件编程(C/C++)5级等级考试真题试卷(2024年6月)
c语言·c++·算法·青少年编程·电子学会·五级
AndrewHZ3 小时前
【3D重建技术】如何基于遥感图像和DEM等数据进行城市级高精度三维重建?
图像处理·人工智能·深度学习·3d·dem·遥感图像·3d重建
飞哥数智坊3 小时前
Coze实战第18讲:Coze+计划任务,我终于实现了企微资讯简报的定时推送
人工智能·coze·trae
WBluuue4 小时前
数学建模:智能优化算法
python·机器学习·数学建模·爬山算法·启发式算法·聚类·模拟退火算法