结合rpart包的决策树介绍

决策树与CART算法

决策树是一种基于树状结构的监督学习算法。它通过从根节点开始递归地对特征进行划分,构建出一棵树来进行决策。决策树的构建过程需要解决的重要问题有三个:如何选择自变量、如何选择分割点、确定停止划分的条件。解决这些问题是希望随着划分过程的不断进行,决策树的分支节点所包含的样本尽可能属于同一类别,即节点的"纯度"(purity)越来越高。传统决策树算法包括ID3、C4.5和CART,主要的区别就是决策树生成的划分过程中选择的目标函数不同,ID3使用的是信息增益,C4.5使用信息增益率,CART使用的是Gini系数。相比于ID3和C4.5算法,CART算法处理连续型特征能力强、不受特征取值个数影响、非参数化、可剪枝,在解决各种类型的问题时更加灵活和有效。

rpart包提供了实现CART算法的功能,并允许在构建树的过程中进行剪枝,避免过拟合问题。使用rpart包,可以根据给定的训练数据集自动生成决策树模型,并使用该模型进行预测。本文会首先对CART算法原理进行简单介绍,再对rpart包中的主要函数进行介绍,最后利用实例看如何使用rpart包构建决策树模型。

CART模型介绍

CART树的构建过程是递归的,它通过反复选择最佳的特征进行节点分裂,直到满足停止条件为止。在每次分裂时,CART算法会选择最佳的特征和最佳的切分点,以最小化切分后的不纯度(分类问题)或者最小化切分后的均方误差(回归问题),CART使用的是基尼指数来衡量数据的不纯度。

基尼指数

假设有K个类别,第k个类别的概率为 p k p_{k} pk,数据集D的纯度可以用基尼值来度量:

G i n i ( D ) = ∑ k = 1 K p k ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini(D)=\sum_{k=1}^Kp_{k}(1-p_{k})=1-\sum_{k=1}^Kp_{k}^2 Gini(D)=k=1∑Kpk(1−pk)=1−k=1∑Kpk2

直观来说,Gini(D)反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。因而,基尼值越小,数据集D的纯度越高。

属性a的基尼指数定义为:

G i n i _ i n d e x ( D , a ) = ∑ v = 1 V ∣ D v ∣ ∣ D ∣ G i n i ( D v ) Gini\index(D,a)=\sum{v=1}^V\frac{\left| D_{v} \right|}{\left| D\right|}Gini(D^{v}) Gini_index(D,a)=v=1∑V∣D∣∣Dv∣Gini(Dv)

于是在候选属性集合A中,选择使划分后的基尼指数最小的属性作为最优划分属性,即

a ∗ = a r g m i n a ∈ A G i n i _ i n d e x ( D , a ) a_{*}= \mathop{argmin}\limits_{a\in A}\quad Gini\_index(D,a) a∗=a∈AargminGini_index(D,a)

过拟合处理

由于决策树算法在学习的过程中为了尽可能的正确的分类训练样本,不停地对结点进行划分,因此这会导致整棵树的分支过多,也就容易导致了过拟合。剪枝是决策树学习算法对付"过拟合"的主要手段,决策树剪枝的基本策略有"预剪枝"和"后剪枝"。

  • 预剪枝(pre-pruning):在构建决策树过程时,提前停止。预剪枝就是在构造决策树的过程中,先对每个结点在划分前进行估计,若果当前结点的划分不能带来决策树模型泛华性能的提升,则不对当前结点进行划分并且将当前结点标记为叶结点。
  • 后剪枝(post-pruning):决策树构建完毕后,然后开始剪枝。先把整颗决策树构造完毕,然后自底向上的对非叶结点进行考察,若将该结点对应的子树换为叶结点能够带来泛华性能的提升,则把该子树替换为叶结点。

CART分类树算法具体流程

CART分类树算法有剪枝算法流程,算法输入训练集D,基尼系数的阈值,样本个数阈值,输出的是决策树T。算法从根节点开始,用训练集递归建立CART分类树。

  1. 对于当前节点的数据集为 D,如果样本个数小于阈值或没有特征,则返回决策子树,当前节点停止递归。
  2. 计算样本集 D 的基尼系数,如果基尼系数小于阈值,则返回决策树子树,当前节点停止递归。
  3. 计算当前节点现有的各个特征的各个特征值对数据集 D 的基尼系数,对于离散值和连续值的处理方法和基尼系数的计算见第二节。缺失值的处理方法和 C4.5 算法里描述的相同。
  4. 在计算出来的各个特征的各个特征值对数据集 D 的基尼系数中,选择基尼系数最小的特征A和对应的特征值a。根据这个最优特征和最优特征值,把数据集划分成两部分D1和 D2,同时建立当前节点的左右节点,做节点的数据集 D 为 D1,右节点的数据集 D 为 D2。
  5. 对左右的子节点递归的调用 1-4 步,生成决策树。

对生成的决策树做预测的时候,假如测试集里的样本 A 落到了某个叶子节点,而节点里有多个训练样本。则对于 A 的类别预测采用的是这个叶子节点里概率最大的类别。

主要函数介绍

rpart包提供的函数主要用于递归分区和剪枝,其中主要函数包括:

  • rpart():递归地构建一棵决策树。
  • printcp():打印交叉验证结果,显示在不同复杂度下测试误差率和复杂度参数的关系。
  • prune():根据不同的剪枝方法,来选择最优的剪枝点,并返回剪枝后的决策树。
  • predict():使用训练好的树模型对新数据进行预测。

其他函数为用于输出上述主要函数的结果和决策树绘制等辅助函数,此处不做详细说明。

函数rpart()

rpart()函数根据给定的训练数据集和参数来递归地构建决策树模型,函数的基本形式如下:

r 复制代码
rpart(formula, data, weights, subset, na.action = na.rpart, method, model = FALSE, x = FALSE, y = TRUE, parms, control, cost, ...)
markdown 复制代码
## rpart()函数参数介绍

- **formula**: 回归方程的形式,例如 `y ~ x1 + x2 + x3`
- **data**: 数据框形式的数据,包含前面的 `formula` 方程中的数据
- **weights**: 可选大小写权重
- **subset**: 可选表达式,表示在拟合中只应使用数据行的子集
- **na.action**: 缺失数据的处理办法,默认操作删除所有缺少因变量的观测值,但保留自变量缺失的观测值
- **method**: 利用树的末端数据类型来选择相应的变量分割方法,它影响了决策树中拆分变量和拆分点的选择方式

本参数有四种模型类型来指定用于构建决策树的特定类型:

  • "class": 用于分类问题,响应变量是离散型。
  • "anova": 用于回归问题,使用F检验选择最佳的连续型变量和拆分点。
  • "exp": 用于回归问题,适用于指数分布响应变量,使用最小二乘法拟合模型。
  • "poisson": 用于回归问题,适用于泊松分布响应变量,使用最小二乘法拟合模型。
复制代码

x : 在结果中保留 x 矩阵的副本
y : 在结果中保留因变量的副本。如果缺少并且提供了模型,则默认为 FALSE
parms : 用来设置三个参数:先验概率、损失矩阵、分类纯度的度量方法
control: 对树进行一些设置,控制每个节点上的最小样本量、交叉验证次数、复杂性度量即 cp 值等对树的一些设置

```r
 rpart.control(minsplit=20, minbucket=round(minsplit/3), cp=0.01, maxcompete=4, maxsurrogate=5, usesurrogate=2, xval=10, surrogatestyle=0, maxdepth=30, ...)

rpart.control中参数介绍:
- minsplit: 在进行拆分之前,节点中的观察数必须达到最小要求。
- minbucket:叶子节点最小样本数
- cp(complexity parameter): 指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度
- maxcompete:指定保留在输出中的竞争拆分数量
- maxsurrogate:指定保留在输出中的替代拆分数量
- usesurrogate:指定是否使用替代拆分
- maxdepth:树的深度
- xval: 交叉验证次数
  • cost: 损失矩阵,在剪枝的时候,叶子节点的加权误差与父节点的误差进行比较,考虑损失矩阵的时候,从将"减少-误差"调整为"减少-损失"

输出信息

rpart()函数的输出信息包含了决策树模型的相关信息,可以通过使用print()函数或者summary()函数查看,具体包含:

  • node: 节点的编号,每个节点都有唯一的编号,编号从1开始,一直到叶子节点的数量为止
  • split: 该节点选用的切分变量和切分点,即表示样本划分的规则
  • n: 该节点的样本数量
  • loss: 该节点的损失函数。通常是基于选定的损失函数衡量预测误差的指标
  • yval: 该节点的预测值 (或类别)。对于回归模型,预测值为该节点所有样本目标变量的平均值;对于分类模型,预测值为该节点样本中数量最多的类别
  • yprob: 各个类别的概率。它只在决策树用于分类问题时有意义,表示在该节点下,各个类别的比例
  • *: 代表该节点是叶子节点,不再分裂

函数prune()

prune()函数对决策树进行剪枝,函数的基本形式如下:

r 复制代码
prune(tree, cp, ...)

prune()函数参数介绍

  • tree: 一个回归树对象,常是rpart()的结果对象
  • cp: 复杂性参数(complex parameter),指剪枝采用的阈值。指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度,用来节省剪枝浪费的不必要的时间。

输出信息

prune()函数的输出信息包含了剪枝后的决策树模型的相关信息,同样可通过使用print()函数或summary()函数查看,参数含义同rpart()函数输出参数,此处不再进行赘述。

predict()

利用拟合好的决策树对象来预测响应向量,其基本形式为:

r 复制代码
predict(object, newdata, type = c("vector", "prob", "class", "matrix"), na.action = na.pass, ...)

predict()函数参数介绍

  • object: 类"rpart"的拟合模型对象,即是拟合好的决策树

  • newdata: 需要进行预测的值的数据框,是对于树模型定义的每一个变量的取值。该变量可以是数值型,分类型或逻辑型,并且必须与一开始用于训练模型时的变量/因子相同

  • type: 预测结果的类型,可选参数如下,默认"default",意味着返回与此前拟合树模型的类型相同的结果。

    type中可选参数:
    - vector: 返回连续型或类别型变量的预测值/分类
    - prob: 返回类别型变量的预测概率
    - class: 返回每个新观察的分类预测
    - matrix: 返回每个新观察的所有分类预测和概率
    
  • na.action: 指定处理缺失值的方式,默认是na.pass,即不处理缺失值。

输出信息

predict()函数返回的值,即是其type参数所指定的,默认返回与之前拟合树模型的类型相同的结果,一般来说,连续型预测值是新观察变量的数值预测,类别型预测是对新观察变量的分类预测。概率值表示每个类别的相对可能性。

实战案例

此部分通过具体的示例演示如何使用rpart包中函数构建决策树,分别介绍了分类树和回归树两种类型模型的应用案例,并在构建模型时使用了不同的剪枝方法。

分类树实例

此分类树实例利用rpart包中自带的kyphosis数据集,它包括81个接受脊柱矫正手术的儿童的数据,预测变量是Kyphosis(术后是否存在后凸(一种变形)的因素,为分类数据),预测变量是Age(年龄(以月记))、Number(涉及的椎骨数目)和Start(第一椎体手术数)。为了防止过拟合,该分类树有预剪枝操作。R语言代码操作如下:

r 复制代码
# 载入数据
data(kyphosis)
#分类树构建
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
             parms = list(prior = c(0.65, 0.35), split = "gini"),
             control = rpart.control(cp = 0.05))
# 查看决策树的具体信息
print(fit)
rpart.plot(fit)
# 预测
newdata <- data.frame(Age = 10, Number = 1, Start = 3)
predict(fit, newdata = newdata)

在上面的分类树构建的过程中,parms参数指定两个类别Kyphosis="absent"和Kyphosis="present"的先验概率,设定Kyphosis="absent"和Kyphosis="present"的先验概率分别为0.65和0.35;split参数指定使用Gini系数作为衡量节点分裂能力的指标,control参数指定了树的复杂程度,如果初始cp值为0,则没有预剪枝操作。

接着使用print函数查看拟合树模型的细节信息,split表示样本划分的规则,n代表样本大小,loss为分类错误的代价,yval为分类结果,yprob为两类的百分比,*表示叶子节点。结果如下:

可以看到最终剪枝为3个叶节点。每个叶节点显示患病情况、基尼值及数量占比。尽管rpart()函数能够清晰的输出决策过程,但使用rpart.plot函数绘制树状图,可以更直观地了解树模型的结构和分类规则,也有助于通俗解释树模型的结果,展示如下:

最后使用predict函数对新数据进行预测。新数据包括年龄(Age)为10岁、手术次数(Number)为1次、手术位置(Start)为3级的患者的基本信息。

函数返回值为该患者所属类别未患有/患有Kyphosis的预测概率,分别约为0.8161和0.1838,因而预测结果为未患有Kyphosis病变。

回归树实例

此决策树实例选择使用ISLR包中的Hitters数据集,它包括263个专业棒球运动员的各类信息,预测变量是home runs(平均本垒)和years played(职业经验),响应变量为运动员的Salary.首先构建大的初始回归树,为了让树足够大使用了较小的cp值,意味着只要回归模型总体R方增加就继续产生新的分支,后为避免过拟合问题,寻找最优cp值,cp的最佳值是致xerror最低的记录值,利用该值对决策树进行剪枝处理,最后利用剪枝后的决策树对新数据进行预测。

初始回归树的构建结果如上图所示,能够发现叶节点的数量较多,可能存在过拟合问题,利用决策树生成过程中交叉验证数据的观察结果的误差,找出最优对应的最优cp值,进行剪枝处理,结果如下。由于此部分侧决策树内容较多,便不妨决策树的输出结果,仅呈现图形。

最终剪枝结果如上图,每个终端节点显示运动员的薪资及原始数据中属于该节点的观察值的占比。

最后使用最终的剪枝树预测新的运动员薪资,基于职业经验和平均本垒(home runs).比如,某运动员有7年职业经验,平均本垒为4,则预测薪资约为502k.此部分R语言代码如下:

r 复制代码
library(rpart)   
library(rpart.plot) #决策树更好看
#回归树构建【后剪枝】
library(ISLR)  
#初始回归树构建
tree <- rpart(Salary ~ Years + HmRun, data=Hitters, control=rpart.control(cp=.0001))
print(tree)
printcp(tree) 
rpart.plot(tree)
#剪枝
best<- tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"]
print(best)
pruned_tree <- prune(tree, cp=best)
rpart.plot(pruned_tree)
# 预测
new <- data.frame(Years=7, HmRun=4)
predict(pruned_tree, newdata=new)
相关推荐
浪九天1 小时前
人工智能直通车系列14【机器学习基础】(逻辑回归原理逻辑回归模型实现)
人工智能·深度学习·神经网络·机器学习·自然语言处理
紫雾凌寒4 小时前
计算机视觉应用|自动驾驶的感知革命:多传感器融合架构的技术演进与落地实践
人工智能·机器学习·计算机视觉·架构·自动驾驶·多传感器融合·waymo
安忘4 小时前
LeetCode 热题 -189. 轮转数组
算法·leetcode·职场和发展
Y1nhl4 小时前
力扣hot100_二叉树(4)_python版本
开发语言·pytorch·python·算法·leetcode·机器学习
龚大龙4 小时前
机器学习(李宏毅)——Auto-Encoder
人工智能·机器学习
曼诺尔雷迪亚兹5 小时前
2025年四川烟草工业计算机岗位备考详细内容
数据结构·数据库·计算机网络·算法
蜡笔小新..5 小时前
某些网站访问很卡 or 力扣网站经常进不去(2025/3/10)
算法·leetcode·职场和发展
IT猿手6 小时前
2025最新群智能优化算法:基于RRT的优化器(RRT-based Optimizer,RRTO)求解23个经典函数测试集,MATLAB
开发语言·人工智能·算法·机器学习·matlab
刘大猫266 小时前
五、MyBatis的增删改查模板(参数形式包括:String、对象、集合、数组、Map)
人工智能·算法·智能合约
修己xj6 小时前
算法系列之深度/广度优先搜索解决水桶分水的最优解及全部解
算法