这张图是 GBDT 算法中**构建第一个弱学习器(CART 回归树)**的完整计算过程,我用通俗的例子和数值拆解来一步步解释:
一、先明确初始状态
我们有 10 个样本,每个样本的目标值 是真实标签(比如学生的考试分数),初始时模型的预测值是所有目标值的平均值(因为平方损失下,均值是最优初始预测)。
| 样本序号 (x) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
|---|---|---|---|---|---|---|---|---|---|---|
| 目标值(真实分) | 5.56 | 5.70 | 5.91 | 6.40 | 6.80 | 7.05 | 8.90 | 8.70 | 9.00 | 9.05 |
| 初始预测值(均值) | 7.31 | 7.31 | 7.31 | 7.31 | 7.31 | 7.31 | 7.31 | 7.31 | 7.31 | 7.31 |
| 负梯度(残差 = 真实值 - 预测值) | -1.75 | -1.61 | -1.40 | -0.91 | -0.51 | -0.26 | 1.59 | 1.39 | 1.69 | 1.74 |
- 初始预测值
7.31是所有目标值的平均值:\\frac{5.56+5.70+5.91+6.40+6.80+7.05+8.90+8.70+9.00+9.05}{10} = 7.31
- 负梯度(残差)反映了初始预测的错误:负的残差表示"预测值偏高",正的残差表示"预测值偏低"。
二、构建第一个弱学习器:找最优切分点
我们的目标是训练一个 CART 回归树,让它拟合这些残差(负梯度)。CART 树的核心是找到一个切分点,把样本分成两组,让两组内的残差波动最小(平方损失最小)。
1. 切分点的选择
我们遍历所有可能的切分点(图里是 1.5, 2.5, ..., 9.5,对应样本序号的中间值),对每个切分点计算平方损失:
-
切分点 1.5:把第 1 个样本和第 2-10 个样本分成两组。
- 左子树(第 1 个样本):残差
-1.75→ 子树均值-1.75 - 右子树(第 2-10 个样本):残差
-1.61, -1.40, -0.91, -0.51, -0.26, 1.59, 1.39, 1.69, 1.74→ 子树均值0.19 - 平方损失:左子树损失为
0(只有一个样本,无波动),右子树损失为15.72→ 总损失15.72
- 左子树(第 1 个样本):残差
-
切分点 6.5:把第 1-6 个样本和第 7-10 个样本分成两组。
- 左子树(第 1-6 个样本):残差
-1.75, -1.61, -1.40, -0.91, -0.51, -0.26→ 子树均值-1.07 - 右子树(第 7-10 个样本):残差
1.59, 1.39, 1.69, 1.74→ 子树均值1.60 - 平方损失:左子树损失
1.85+ 右子树损失0.07→ 总损失1.93(所有切分点中最小)
- 左子树(第 1-6 个样本):残差
2. 最优切分点的结论
对比所有切分点的平方损失,6.5 对应的损失最小(1.93),因此我们选择 6.5 作为第一个弱学习器的切分点,构建出一棵简单的二叉树:
- 左子叶(
x ≤ 6.5):输出-1.07(前 6 个样本的残差均值) - 右子叶(
x > 6.5):输出1.60(后 4 个样本的残差均值)
三、这个弱学习器的作用
这个决策树的输出是残差的修正值,它告诉我们:
- 对于前 6 个样本(残差为负,预测值偏高),需要在初始预测值
7.31上减去 1.07 (即7.31 - 1.07 = 6.24),让预测更接近真实值。 - 对于后 4 个样本(残差为正,预测值偏低),需要在初始预测值
7.31上加上 1.60 (即7.31 + 1.60 = 8.91),让预测更接近真实值。
叠加这个修正后,新的预测值会比初始预测更准确,这就是 GBDT"迭代修正错误"的核心逻辑。