决策树的损失函数公式详细说明和例子说明

公式的详细说明

L α ( T ) = ∑ t = 1 ∣ T ∣ N t H t ( T ) + α ∣ T ∣ L_{\alpha}(T) = \sum_{t=1}^{|T|} N_t H_t(T) + \alpha |T| Lα(T)=t=1∑∣T∣NtHt(T)+α∣T∣

这是决策树的损失函数,它由两部分组成:

  1. ∑ t = 1 ∣ T ∣ N t H t ( T ) \sum_{t=1}^{|T|} N_t H_t(T) ∑t=1∣T∣NtHt(T) - 这一部分衡量了整个决策树的分类效果,通过每个叶节点的经验熵来评估。公式的含义如下:

    • H t ( T ) H_t(T) Ht(T) 是决策树 T T T 中第 t t t 个叶节点的经验熵(也称作分类不纯度)。经验熵是衡量某个叶节点上数据混乱程度的一个指标。举例来说,如果一个叶节点上的所有数据样本都属于同一类,则该节点的经验熵为 0,表示该节点非常纯。反之,如果该节点上的数据样本均匀分布在多个类别上,则经验熵较高,表示该节点不纯。
    • N t N_t Nt 是第 t t t 个叶节点的样本数量。为了评估树的整体分类效果,我们对每个叶节点的经验熵进行加权,权重就是该叶节点上的样本数量 N t N_t Nt。因此,包含大量数据的叶节点对总损失的影响更大。
  2. α ∣ T ∣ \alpha |T| α∣T∣ - 这一部分是正则化项,用于控制树的复杂度:

    • ∣ T ∣ |T| ∣T∣ 是树的叶节点数。叶节点越多,意味着树的分割越多,模型的复杂度越高。
    • α \alpha α 是正则化参数,它用来控制叶节点数带来的复杂度惩罚。如果 α \alpha α 值较小,则模型会允许更多的叶节点(树更复杂),如果 α \alpha α 值较大,模型则会倾向于保持较少的叶节点(树更简单)。

通过调整这个损失函数,算法的目标是找到一棵在经验熵和树的复杂度之间取得平衡的树。

例子说明

假设我们有一组数据用来分类,决策树的某个结构如下图所示(假设已经生成了一棵树):

复制代码
       根节点
      /      \
   节点1     节点2
  /   \      /   \
叶1   叶2   叶3   叶4
  • 假设决策树共有 4 个叶节点: 叶1 , 叶2 , 叶3 , 叶4 \text{叶1}, \text{叶2}, \text{叶3}, \text{叶4} 叶1,叶2,叶3,叶4,所以 ∣ T ∣ = 4 |T| = 4 ∣T∣=4。
  • 每个叶节点包含的样本数量分别为 N 1 = 30 N_1 = 30 N1=30, N 2 = 20 N_2 = 20 N2=20, N 3 = 25 N_3 = 25 N3=25, N 4 = 25 N_4 = 25 N4=25。
  • 每个叶节点的经验熵分别为 H 1 ( T ) = 0.3 H_1(T) = 0.3 H1(T)=0.3, H 2 ( T ) = 0.1 H_2(T) = 0.1 H2(T)=0.1, H 3 ( T ) = 0.2 H_3(T) = 0.2 H3(T)=0.2, H 4 ( T ) = 0.4 H_4(T) = 0.4 H4(T)=0.4。
  • 设正则化参数 α = 0.01 \alpha = 0.01 α=0.01。
计算第一项:经验熵之和

首先,计算加权经验熵之和:
∑ t = 1 ∣ T ∣ N t H t ( T ) = N 1 H 1 ( T ) + N 2 H 2 ( T ) + N 3 H 3 ( T ) + N 4 H 4 ( T ) \sum_{t=1}^{|T|} N_t H_t(T) = N_1 H_1(T) + N_2 H_2(T) + N_3 H_3(T) + N_4 H_4(T) t=1∑∣T∣NtHt(T)=N1H1(T)+N2H2(T)+N3H3(T)+N4H4(T)

代入已知数据:
∑ t = 1 4 N t H t ( T ) = 30 × 0.3 + 20 × 0.1 + 25 × 0.2 + 25 × 0.4 = 9 + 2 + 5 + 10 = 26 \sum_{t=1}^{4} N_t H_t(T) = 30 \times 0.3 + 20 \times 0.1 + 25 \times 0.2 + 25 \times 0.4 = 9 + 2 + 5 + 10 = 26 t=1∑4NtHt(T)=30×0.3+20×0.1+25×0.2+25×0.4=9+2+5+10=26

计算第二项:复杂度惩罚

接着,计算复杂度惩罚项:
α ∣ T ∣ = 0.01 × 4 = 0.04 \alpha |T| = 0.01 \times 4 = 0.04 α∣T∣=0.01×4=0.04

计算损失函数

将这两部分加在一起得到总的损失:
L α ( T ) = 26 + 0.04 = 26.04 L_{\alpha}(T) = 26 + 0.04 = 26.04 Lα(T)=26+0.04=26.04

解释:

  • 经验熵之和(26) 反映了决策树的分类效果。这个值越低,说明叶节点的分类越纯,树的分类效果越好。
  • 复杂度惩罚项(0.04) 反映了树的复杂度。值越高,说明树的叶节点越多,树越复杂。
  • 总损失(26.04) 是两者的综合。我们希望总损失尽可能小,以找到既能很好分类数据又不过于复杂的树。

通过调节 α \alpha α 的值,可以控制树的复杂度。较小的 α \alpha α 会让树倾向于复杂的结构,而较大的 α \alpha α 则会使得树倾向于保持简单的结构,以避免过拟合。

结论

这个公式帮助我们在训练决策树时,不仅关注分类的准确性,还通过正则化项控制树的复杂度,确保生成的模型具有良好的泛化能力,而不会过度复杂导致过拟合。

相关推荐
科学熊1 天前
将chm文件格式转为PDF格式文件
人工智能
数据法师1 天前
告别付费云端转写!Memo AI:一款部署在本地的无限次音视频转文字神器
人工智能·音视频
阿乔外贸日记1 天前
以色列电商市场现状:规模、机遇与挑战
大数据·人工智能·智能手机·云计算·汽车
-cywen-1 天前
扩散模型 2
人工智能
Ulyanov1 天前
《从质点到位姿:基于Python与PyVista的导弹制导控制全栈仿真》: 同台竞技——3-DOF与6-DOF模型的终极对决与误差分析
开发语言·python·算法·系统仿真·雷达电子对抗仿真
沪漂阿龙1 天前
面试题:集成学习是什么?Boosting、Bagging、AdaBoost、随机森林为什么有效,一文讲透
人工智能·机器学习·集成学习
财经资讯数据_灵砚智能1 天前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月12日
人工智能·python·信息可视化·自然语言处理·ai编程
Hesionberger1 天前
LeetCode98:验证二叉搜索树(多解)
java·开发语言·python·算法·leetcode·职场和发展
千寻girling1 天前
周日那天参加的力扣周赛... —— 10号
java·javascript·c++·python·算法·leetcode·职场和发展
ZHW_AI课题组1 天前
基于SVM的手写数字分类
机器学习·支持向量机·分类