机器学习:决策树——ID3算法、C4.5算法、CART算法

决策树是一种常用于分类和回归问题的机器学习模型。它通过一系列的"决策"来对数据进行分类或预测。在决策树中,每个内部节点表示一个特征的测试,每个分支代表特征测试的结果,而每个叶节点则表示分类结果或回归值。

决策树工作原理

  • 根节点:从根节点开始,决策树将数据集分割为多个子集。

  • 选择特征:在每个节点处,选择一个特征进行划分。选择标准通常是信息增益(ID3算法)、增益率或基尼指数(CART算法)。

  • 分割数据:根据选定的特征,将数据分割成不同的子集。

  • 递归构建树:对每个子集递归执行相同的过程,直到满足停止条件(例如,节点纯度达到最大、树的深度达到限制或节点数过少)。

  • 叶节点:最终的叶节点给出预测值,对于分类问题 ,通常是该节点中样本最多的类别;对于回归问题,通常是该节点中样本的平均值。

决策树的优化

1. 剪枝

  • 预剪枝:在树构建过程中提前停止分裂,以避免生成过于复杂的树。例如,通过设定最小样本数、最大树深度或者最大叶节点数等参数来控制树的生长。

  • 后剪枝:树生成后,检查每个节点的分裂是否真的能提高模型的预测性能。如果没有,可以通过剪枝去除不重要的分支。剪枝的目标是减少模型的复杂度,防止过拟合。

2. 最大化信息增益或基尼指数

  • 在构建决策树时,可以选择不同的标准来选择最佳分裂节点。最常见的标准是信息增益(率)和基尼指数。优化这些指标的计算,能够帮助树做出更加有效的分裂。

3. 调整超参数

  • 最大深度:限制树的最大深度,避免树的过度生长,降低过拟合风险。

  • 最小样本分裂数:控制每个内部节点最少需要的样本数,如果样本数小于该值,就不再分裂。

  • 最小样本叶节点数:控制每个叶子节点最少需要的样本数,保证树不会过度拟合。

  • 最大特征数:控制每个节点分裂时使用的特征数。通过随机选择部分特征,有助于避免过拟合,并增加模型的泛化能力。

4. 集成学习

  • 随机森林:通过生成多个决策树,并对每棵树的预测结果进行投票,减少单一决策树的偏差和方差。

  • 梯度提升决策树 (GBDT):通过逐步修正之前树的错误,构建多个弱决策树,将它们组合成一个强大的模型,能够有效提高预测性能。

5. 特征选择和工程

  • 选择合适的特征,有助于减少噪声和提高决策树的效率。特征选择方法如L1正则化可以帮助筛选出重要特征,减少不必要的计算。

  • 通过合适的特征工程,例如对数转换、标准化等,确保数据的尺度一致,提升模型的稳定性和准确性。

6. 处理类别不平衡问题

  • 在处理类别不平衡的数据时,决策树可能偏向于多数类。可以通过调整类别权重、过采样或欠采样等方法来处理类别不平衡问题。

7. 样本权重调整

  • 对于某些重要样本,给它们更大的权重,有助于提高模型对这些样本的关注度,尤其在类别不平衡的情况下特别有用。

ID3决策树

ID3是一种经典的决策树算法,它用于分类 任务。ID3的核心思想是通过选择信息增益最大的特征来递归地构建决策树。是信息论中的一个概念,用来衡量一个数据集的不确定性或混乱程度。熵越大,意味着数据集中各类别的分布越不均匀,系统的混乱程度越高。反之,熵越小,表示数据集中各类别的分布更加集中,不确定性较小。

基本原理

1、**熵的计算公式:**假设有一个数据集D,它包含n个类别的样本,各类别的概率为 p1,p2,...,pn,则熵的计算公式为:

注:H(D)为正数,因为概率pi的对数为负数,再加一个负号就变为正。

2、信息增益的计算公式: 信息增益是通过某个特征A来划分数据集D后,信息不确定性减少量。 设D的熵为 H(D),而根据特征A划分后得到的子集的熵分别为 H(D1),H(D2),...,H(Dk),则根据特征A划分的信息增益的计算公式为:

即,信息增益 = 总熵 - sum ( Di在D的占比 * 子集的熵 )

其中,∣Di∣是子集Di的样本数,∣D∣ 是数据集D的样本数。

3、计算所有特征的信息增益

4、选择信息增益最大的特征: 信息增益越大,表示选择该特征进行划分会使得数据集更加"纯净",即++减少了更多的不确定性++,适合作为当前树节点的划分特征。

ID3决策树的局限性:

优点:

  • 简单易理解:ID3算法简单直观,容易实现。

  • 分类效率高:适用于特征数较小的数据集。

缺点:

  • 容易过拟合:由于ID3倾向于选择信息增益最大的特征,有时会造成过拟合,特别是当数据中有很多噪声时。

  • 倾向于选择取值多的特征 :ID3在选择特征时,倾向于选择取值多的特征,这可能导致不合理的划分。举个例子 :在 ID3 算法中,如果我们选择"身份证号"特征进行划分,数据集几乎可以被划分到纯净(即每个子集中只有一个样本),使得信息增益很高。然而,这种划分并没有真正的分类效果,因为"身份证号码"实际上不具有区分类别的意义,而只是过度划分了数据集。为了解决这一问题,C4.5算法提出了使用增益率来代替信息增益。

  • 不能处理连续特征:ID3算法只能处理离散的特征,对于连续特征需要离散化处理。

C4.5算法

C4.5在ID3基础上改进,使用信息增益率 进行分裂,能够处理连续特征和缺失值,还增加了后剪枝 机制。C4.5和CART都是决策树算法,但CART可以用于分类和回归问题,而C4.5只能用于分类。 C4.5算法的特点:

1、以增益率作为特征选择标准

  • 信息增益率:信息增益和固有值的比值,用于减少对多取值特征的偏向。
  • 固有值: 衡量特征 A 的取值数量和分布。
  • 优先选择增益率最大的特征作为划分点。

2、能处理缺失值

C4.5 在处理缺失值时,采取样本加权 的方法,即在计算信息增益增益率时,赋予缺失值样本一个权重。这个权重表示缺失值样本属于每个分支的可能性,权重的大小依赖于该特征在其他样本中的(概率)分布。

举个例子:

假设有数据集D,A有两种取值(a1和a2),样本4为缺失值样本:

1、计算特征 A 的概率分布(计算概率时不考虑缺失值样本),并加权分配:

  • a1出现 2 次(样本 1 和 3),所以 P(a1)=2/4=0.5,则样本 4 的分配权重0.5到 a1 分支

  • a2出现 2 次(样本 2 和 5),所以 P(a2)=2/4=0.5,则样本 4 的分配权重0.5到 a1 分支

2、计算分支的样本数和熵

  • 对于 a1 分支:

    • 原始样本中有 2 个样本(样本 1 和 3),加上样本 4 的权重(0.5),所以总样本数|Da1|为: 2+0.5=2.5。

    • 计算该分支的熵 H(Da1)。(根据类别标签 "Yes" 和 "No"),

  • 对于 a2 分支:

    • 原始样本中有 2 个样本(样本 2 和 5),加上样本 4 的权重(0.5),所以总样本数|Da2|为: 2+0.5=2.5。

    • 计算该分支的熵 H(Da2)。(根据类别标签 "Yes" 和 "No")

**注:**根据分配到a1的0.5权重,分子∣Da1,Yes|也要加0.5,但是∣Da1,No|不用加0.5,因为缺失样本4的类别为Yes,而不是No。这里只以a1分支为例,其他分支同理。

3、计算信息增益和增益率:

其中,Di为Dai。

4、总的来说,这些公式都变为带权公式而已。(过程不过多赘述)

3、能处理连续特征

  • 排序 :对于某个连续特征,将所有样本的特征值按++升序排列++。

  • 确定候选分裂点:计算相邻样本之间的均值作为候选分裂点。例如,对于排序后的连续特征值序列 x1,x2,...,xn​,候选分裂点为(x1+x2)/2,... 。一共n-1个候选分裂点。

  • 计算信息增益率:对于每个候选分裂点,C4.5 将候选分裂点作为划分依据(分为小于或等于分裂点的子集、大于分裂点的子集),并计算此划分下的信息增益率。

  • 选择最佳分裂点:选择信息增益率最大的候选分裂点,作为该连续特征的最优分裂点。信息增益率越大,分裂效果越好。

  • 生成决策节点:根据该分裂点将数据划分为左右两个子集,并继续对子集递归应用 C4.5 算法,直至满足停止条件(如纯度较高或数据量不足)。

  • 生成树结构:以递归方式将分裂节点连接起来,最终形成一棵决策树。

4、剪枝操作

C4.5 使用 后剪枝 方法,即先生成一个完全的决策树,然后评估每个节点的剪枝效果。如果剪枝后的树在验证集(或测试集)上的误差率更低或相当,剪枝就被认为是有效的。如果剪枝后误差率增加,那么不进行剪枝。这样,C4.5 通过对比剪枝前后的误差率来判断剪枝操作的有效性。

后剪枝的过程:

  • N为节点的样本总数, e 为节点样本被错误分类,则误差率 E 计算公式为:

其中,0.5 是一种平滑参数,用于在数据样本量较小时防止误差率过小。

  • 叶节点误差率:直接使用上述公式计算叶节点的误差率

  • 子树误差率 :对子树中的所有叶节点误差进行加权平均,得到子树误差率,也是剪枝前的误差率。

  • **剪枝后的误差率:**即当前特征节点(非叶节点)的误差率,也是根据误差率 E 公式计算。

  • 比较剪枝前、后的误差率:若 剪枝后的误差率 < 剪枝前的误差率,则认为剪枝是有效果的,执行剪枝操作。否者不执行剪枝操作。

  • **重复剪枝:**直到剪枝后不能再有效降低误差率。

C4.5算法的优缺点

C4.5 的优点:

  1. 处理连续值:与 ID3 不同,C4.5 能够处理连续属性。

  2. 处理缺失值:C4.5 可以处理缺失的数据,它使用了概率估计方法来分配缺失值的权重,避免了丢弃包含缺失值的数据。

  3. 剪枝功能避免过拟合 :C4.5 采用了后剪枝技术,通过修剪已经生成的决策树来++避免过拟合++,提高了模型的泛化能力。

  4. 支持多类别分类:与 ID3 类似,C4.5 能够处理多类别的分类问题。

  5. 决策树生成效果好:C4.5 相对于 ID3,在生成树的过程中考虑了信息增益比,可以减少过拟合的风险。

  6. 模型易解释

C4.5 的缺点:

  1. 计算开销较大:C4.5 需要遍历所有属性的所有可能划分点,计算每个属性的增益比,特别是在数据量较大时,这会导致较高的计算成本。

  2. 对噪声敏感:尽管 C4.5 在一定程度上能减少过拟合,但它仍然可能对数据中的噪声敏感,尤其在数据非常复杂或存在许多异常值的情况下,决策树可能会变得过于复杂。

  3. 生成的树可能过大:尽管采用了剪枝技术,生成的决策树有时仍然可能较大,影响模型的可解释性。

  4. 处理不平衡数据的能力有限:C4.5 在面对类别不平衡时可能表现不佳,因为它的划分标准侧重于信息增益比,对于类别较少的类别可能无法进行有效划分。

  5. 只能用于分类

CART算法

CART(Classification and Regression Trees)算法是一个基于树形结构的监督学习算法,用于解决分类和回归问题。其核心思想是通过递归地将数据集划分为更小的子集,从而构建一个二叉树,树的每个叶子节点代表最终的分类标签或回归预测值。

总体流程

  • 选择特征和分割点:每个节点选择一个特征及其最优分割点,将数据划分为两个子集。

  • 递归分裂:对每个子集继续执行步骤1,直到满足停止条件(如最大树深度或最小样本数)。

  • 决策节点与叶子节点:节点的分裂基于特征和分裂点(阈值),最终将数据划分为叶子节点,叶子节点提供预测结果。

  • 剪枝(可选):通过后剪枝或预剪枝来避免过拟合。

CART算法的关键部分

(一)分类问题中的分割标准:基尼指数。假设某个节点包含 m 个类别,类别 i 的样本比例为 pi,那么该节点的基尼指数 Gini(t)为:

基尼指数的特点:

  • 当节点完全纯净时(所有样本属于同一类别),基尼指数为0。

  • 当节点的类别均匀分布时,基尼指数最大,表示该节点的"杂乱"程度最高。

1.1 选择一个特征及其分裂点

对于每个特征 X(如年龄、收入等),CART会尝试所有可能的++分裂点++。然后,基于每个分裂点,我们将数据集D分为两个子集:左子集 D_left 和右子集 D_right。

注: 对于数值型特征,分裂点选择参考上文C4.5算法------处理连续特征部分。

1.2 计算加权基尼指数(表示分裂后两个子集的纯度)

1.3 选择最小加权基尼指数,作为特征X的最优分裂点,也作为特征X的分裂效果

1.4 计算其他特征的最小加权基尼指数,并选择所有特征中加权基尼指数最小的特征作为当前节点的分裂点。

1.5 同理,继续这样分裂下去,直到分裂出一颗完整的决策树。

**(二)回归问题中的分割标准:**均方误差(MSE)。MSE衡量的是数据的目标值(通常是一个连续变量)在当前节点的均匀性。MSE越小,表示数据点的目标值分布越集中,节点的纯度越高。

选择当前节点的分裂特征的当方法与使用基尼指数类似。

(三)关于"最佳分割点"的总结

CART选择最佳分割点的过程本质上是一个贪心算法,在每个节点选择最优的特征和分割点,但不考虑未来可能的分割。这个过程的关键步骤可以总结为:

  • 对于每个特征,遍历所有可能的分割点。
  • 对于每个分割点,根据该分割点将数据划分为两部分,并计算子集的纯度。
  • 选择使得纯度最小的分割点(基于基尼指数或均方误差)。

通过这种方式,CART算法能够在每个节点做出最优的分割,逐步构建决策树,直到满足停止条件(如树的深度、节点样本数、纯度等)

(四)举个简单的CART回归例子。

假设有这样一个数据集:

4.1 根节点分裂

假设选择面积作为根节点的分裂特征,并使用一个分裂点(阈值)进行分裂,比如"面积 ≤ 70":

  • 左子树:面积 ≤ 70(A, B, C)
  • 右子树:面积 > 70(D, E)

4.2 左子树分裂

选择使用另一个特征"房龄"进行分裂。假设选择"房龄 ≤ 7"作为分裂条件:

  • 左子树(继续分裂):房龄 ≤ 7(A, B)
  • 右子树:房龄 > 7(C)

4.3 右子树分裂

假设选择"房龄 ≤ 5"作为分裂条件,数据集会被再次分裂:

  • 左子树:房龄 ≤ 5(D)
  • 右子树:房龄 > 5(E)

4.4 生成决策树

4.5 预测

假设有一个测试样本:面积为40,房龄为6。根据决策树,最终会到达叶子节点(A,B)。。因为例子是一个回归任务,该叶子节点的预测值就是A和B的目标值的平均,即(100+120)/2 = 110 万。所以测试样本的预测值为110万。

CART的剪枝操作

CART算法包括两种剪枝方式:预剪枝和后剪枝。预剪枝在树构建过程中进行,而后剪枝在树构建完毕后对其进行优化。CART主要使用后剪枝,通过代价复杂度剪枝方法,生成一个平衡复杂度和误差的、更简洁、泛化能力更强的树。

剪枝过程:

  • 从底部向上剪枝:从树的叶节点开始逐层向上,每次合并最小化代价复杂度的分支。

  • 计算剪枝后的代价复杂度:对每个非叶节点,将该节点的子树替换为叶节点,计算替换后整棵树的代价复杂度。若剪枝后 的代价复杂度小于剪枝前 ,则执行剪枝操作;否则保留原分支。

  • 选择最优子树:逐步合并分支,直到找到使代价复杂度最小的子树为止。

  • 选择最优 α 值:通过调整 α,可以找到一个能平衡复杂度和误差的最佳树,可以通过交叉验证独立验证集来确定最优的 α 值。

CART算法的优缺点

优点

  1. 模型易解释

  2. 能处理分类和回归任务

  3. 无需大量数据预处理:相比于一些其他机器学习算法,CART对数据的预处理要求较低。它不要求特征标准化或归一化,并且能够处理缺失数据(通过缺失值处理机制,如分配默认值或插补)。

  4. 能够处理数值型和类别型特征:CART能够处理数值型特征(通过划分数值区间)和类别型特征(通过类别的组合划分)。这种灵活性使得它能够处理多种类型的数据。

  5. 能够自动处理特征选择:在每个节点分裂时,CART自动选择最能有效分割数据的特征。因此,用户不需要进行显式的特征选择或特征工程,这减少了模型开发的工作量。

  6. 能处理非线性关系:决策树通过一系列简单的条件判断进行决策,因此它能够捕捉数据中的非线性关系,而不像线性模型那样仅能处理线性关系。

  7. 不易受到异常值影响:由于CART算法是基于划分数据的,它通常对异常值具有一定的鲁棒性。异常值不会对树的构建产生过大的影响,除非异常值非常突出。

缺点

  1. 容易过拟合 :CART算法容易发生过拟合,特别是在数据集较小或者树的深度过大时。决策树可能会过分复杂地拟合训练数据中的噪声,从而在测试数据上表现差。

  2. 对数据噪声敏感 :尽管决策树对单个异常值具有一定的鲁棒性,但在存在大量噪声数据时,决策树可能会构建非常复杂的模型来拟合这些噪声,从而降低模型的泛化能力。

  3. 可能产生不稳定的模型

    • 对于相似的数据集,CART算法可能会生成不同的树结构,导致模型的稳定性较差。这种不稳定性可以通过集成方法(如随机森林)来缓解。
  4. 无法捕捉复杂的关系 :虽然CART可以捕捉非线性关系,但它只能通过简单的划分来表示数据的关系。因此,对于一些非常复杂的非线性关系,单棵 决策树可能不足以拟合得非常好。

  5. 决策边界较为硬性:CART通过逐步划分特征空间来生成决策边界,这种边界通常是直线的(即通过一系列轴对齐的平面划分数据)。这意味着对于一些数据分布较为复杂的问题(如多维曲线的分布),决策边界可能无法很好地适应数据的结构。

  6. 对不平衡数据敏感 :如果数据集中的类别不平衡,CART可能会偏向于多数类。例如,在二分类任务中,如果一类样本占大多数,CART可能会倾向于预测该类,从而影响模型的性能。可以通过使用类别权重调整分裂标准来缓解这一问题。

  7. 计算复杂度较高:尽管CART在构建决策树时通常表现出较好的计算效率,但在大数据集或高维数据集上,遍历所有特征和可能的分割点会导致计算复杂度较高,尤其是在特征较多的情况下。

  8. 不适合高维数据 :当数据的维度非常高时,CART可能会遇到维度灾难。在高维数据中,决策树的分裂点选择变得更加困难,可能导致模型无法有效捕捉数据的结构。

  9. **处理缺失值能力有限:**CART算法的缺失值处理较为简单,一般是通过删除缺失值的样本或用统计方法(如均值、众数)进行填充。它在分裂过程中会通过概率分配来处理缺失值,但整体处理策略相对较为简单。C4.5的缺失值处理方法比CART更加精细和复杂,因此在面对缺失值较多的数据时,C4.5可能会表现得更为稳定和鲁棒。

文章整理的内容来自多方渠道。如有错误,欢迎大噶指正!

制作不易,如果文章对大噶有帮助,点个小赞 鼓励一下叭!

相关推荐
言之。9 分钟前
【K-Means】
算法·机器学习·kmeans
hummhumm41 分钟前
第 10 章 - Go语言字符串操作
java·后端·python·sql·算法·golang·database
Jeffrey_oWang1 小时前
软间隔支持向量机
算法·机器学习·支持向量机
算法歌者2 小时前
[算法]入门1.矩阵转置
算法
用户8134411823612 小时前
分布式训练
算法
林开落L2 小时前
前缀和算法习题篇(上)
c++·算法·leetcode
远望清一色2 小时前
基于MATLAB边缘检测博文
开发语言·算法·matlab
tyler_download2 小时前
手撸 chatgpt 大模型:简述 LLM 的架构,算法和训练流程
算法·chatgpt
封步宇AIGC2 小时前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据
人工智能·python·机器学习·数据挖掘
m0_523674212 小时前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘