决策树的损失函数公式详细说明和例子说明

公式的详细说明

L α ( T ) = ∑ t = 1 ∣ T ∣ N t H t ( T ) + α ∣ T ∣ L_{\alpha}(T) = \sum_{t=1}^{|T|} N_t H_t(T) + \alpha |T| Lα(T)=t=1∑∣T∣NtHt(T)+α∣T∣

这是决策树的损失函数,它由两部分组成:

  1. ∑ t = 1 ∣ T ∣ N t H t ( T ) \sum_{t=1}^{|T|} N_t H_t(T) ∑t=1∣T∣NtHt(T) - 这一部分衡量了整个决策树的分类效果,通过每个叶节点的经验熵来评估。公式的含义如下:

    • H t ( T ) H_t(T) Ht(T) 是决策树 T T T 中第 t t t 个叶节点的经验熵(也称作分类不纯度)。经验熵是衡量某个叶节点上数据混乱程度的一个指标。举例来说,如果一个叶节点上的所有数据样本都属于同一类,则该节点的经验熵为 0,表示该节点非常纯。反之,如果该节点上的数据样本均匀分布在多个类别上,则经验熵较高,表示该节点不纯。
    • N t N_t Nt 是第 t t t 个叶节点的样本数量。为了评估树的整体分类效果,我们对每个叶节点的经验熵进行加权,权重就是该叶节点上的样本数量 N t N_t Nt。因此,包含大量数据的叶节点对总损失的影响更大。
  2. α ∣ T ∣ \alpha |T| α∣T∣ - 这一部分是正则化项,用于控制树的复杂度:

    • ∣ T ∣ |T| ∣T∣ 是树的叶节点数。叶节点越多,意味着树的分割越多,模型的复杂度越高。
    • α \alpha α 是正则化参数,它用来控制叶节点数带来的复杂度惩罚。如果 α \alpha α 值较小,则模型会允许更多的叶节点(树更复杂),如果 α \alpha α 值较大,模型则会倾向于保持较少的叶节点(树更简单)。

通过调整这个损失函数,算法的目标是找到一棵在经验熵和树的复杂度之间取得平衡的树。

例子说明

假设我们有一组数据用来分类,决策树的某个结构如下图所示(假设已经生成了一棵树):

       根节点
      /      \
   节点1     节点2
  /   \      /   \
叶1   叶2   叶3   叶4
  • 假设决策树共有 4 个叶节点: 叶1 , 叶2 , 叶3 , 叶4 \text{叶1}, \text{叶2}, \text{叶3}, \text{叶4} 叶1,叶2,叶3,叶4,所以 ∣ T ∣ = 4 |T| = 4 ∣T∣=4。
  • 每个叶节点包含的样本数量分别为 N 1 = 30 N_1 = 30 N1=30, N 2 = 20 N_2 = 20 N2=20, N 3 = 25 N_3 = 25 N3=25, N 4 = 25 N_4 = 25 N4=25。
  • 每个叶节点的经验熵分别为 H 1 ( T ) = 0.3 H_1(T) = 0.3 H1(T)=0.3, H 2 ( T ) = 0.1 H_2(T) = 0.1 H2(T)=0.1, H 3 ( T ) = 0.2 H_3(T) = 0.2 H3(T)=0.2, H 4 ( T ) = 0.4 H_4(T) = 0.4 H4(T)=0.4。
  • 设正则化参数 α = 0.01 \alpha = 0.01 α=0.01。
计算第一项:经验熵之和

首先,计算加权经验熵之和:
∑ t = 1 ∣ T ∣ N t H t ( T ) = N 1 H 1 ( T ) + N 2 H 2 ( T ) + N 3 H 3 ( T ) + N 4 H 4 ( T ) \sum_{t=1}^{|T|} N_t H_t(T) = N_1 H_1(T) + N_2 H_2(T) + N_3 H_3(T) + N_4 H_4(T) t=1∑∣T∣NtHt(T)=N1H1(T)+N2H2(T)+N3H3(T)+N4H4(T)

代入已知数据:
∑ t = 1 4 N t H t ( T ) = 30 × 0.3 + 20 × 0.1 + 25 × 0.2 + 25 × 0.4 = 9 + 2 + 5 + 10 = 26 \sum_{t=1}^{4} N_t H_t(T) = 30 \times 0.3 + 20 \times 0.1 + 25 \times 0.2 + 25 \times 0.4 = 9 + 2 + 5 + 10 = 26 t=1∑4NtHt(T)=30×0.3+20×0.1+25×0.2+25×0.4=9+2+5+10=26

计算第二项:复杂度惩罚

接着,计算复杂度惩罚项:
α ∣ T ∣ = 0.01 × 4 = 0.04 \alpha |T| = 0.01 \times 4 = 0.04 α∣T∣=0.01×4=0.04

计算损失函数

将这两部分加在一起得到总的损失:
L α ( T ) = 26 + 0.04 = 26.04 L_{\alpha}(T) = 26 + 0.04 = 26.04 Lα(T)=26+0.04=26.04

解释:

  • 经验熵之和(26) 反映了决策树的分类效果。这个值越低,说明叶节点的分类越纯,树的分类效果越好。
  • 复杂度惩罚项(0.04) 反映了树的复杂度。值越高,说明树的叶节点越多,树越复杂。
  • 总损失(26.04) 是两者的综合。我们希望总损失尽可能小,以找到既能很好分类数据又不过于复杂的树。

通过调节 α \alpha α 的值,可以控制树的复杂度。较小的 α \alpha α 会让树倾向于复杂的结构,而较大的 α \alpha α 则会使得树倾向于保持简单的结构,以避免过拟合。

结论

这个公式帮助我们在训练决策树时,不仅关注分类的准确性,还通过正则化项控制树的复杂度,确保生成的模型具有良好的泛化能力,而不会过度复杂导致过拟合。

相关推荐
极客代码4 分钟前
【Python TensorFlow】入门到精通
开发语言·人工智能·python·深度学习·tensorflow
义小深6 分钟前
TensorFlow|咖啡豆识别
人工智能·python·tensorflow
passer__jw76711 分钟前
【LeetCode】【算法】283. 移动零
数据结构·算法·leetcode
Ocean☾18 分钟前
前端基础-html-注册界面
前端·算法·html
顶呱呱程序26 分钟前
2-143 基于matlab-GUI的脉冲响应不变法实现音频滤波功能
算法·matlab·音视频·matlab-gui·音频滤波·脉冲响应不变法
Tianyanxiao1 小时前
如何利用探商宝精准营销,抓住行业机遇——以AI技术与大数据推动企业信息精准筛选
大数据·人工智能·科技·数据分析·深度优先·零售
爱吃生蚝的于勒1 小时前
深入学习指针(5)!!!!!!!!!!!!!!!
c语言·开发语言·数据结构·学习·计算机网络·算法
羊小猪~~1 小时前
数据结构C语言描述2(图文结合)--有头单链表,无头单链表(两种方法),链表反转、有序链表构建、排序等操作,考研可看
c语言·数据结构·c++·考研·算法·链表·visual studio
撞南墙者1 小时前
OpenCV自学系列(1)——简介和GUI特征操作
人工智能·opencv·计算机视觉
OCR_wintone4211 小时前
易泊车牌识别相机,助力智慧工地建设
人工智能·数码相机·ocr