机器学习————GBDT算法

一、GBDT核心概念

GBDT 全称 "Gradient Boosting Decision Tree",是 Boosting 家族的核心算法,核心逻辑可总结为:

  1. 核心思想:串行训练多棵决策树,每一棵新树都用来拟合,前一轮模型预测值与真实值的残差,最终将所有树的预测结果累加,得到最终预测值。
  2. 核心特点:
    • 基模型默认是回归决策树,即使做分类任务,底层也是回归树拟合概率和对数几率;
    • 每一轮训练都聚焦修正上一轮的错误,通过梯度下降的思想最小化损失函数;
    • 最终预测:分类任务用 sigmoid或softmax 转换,回归任务直接累加所有树的输出。

二、GBDT数学公式

1. 损失函数与梯度

假设训练集为,GBDT 第 m 轮的模型可表示为:

  • :前 m−1 轮模型的累加预测值;
  • :第 m 棵决策树;
  • :第 m 棵树的学习率。
2. 残差梯度计算

GBDT 每一轮训练的目标是让新树拟合负梯度,其中残差是均方误差下的特殊负梯度:

均方误差损失为:

其负梯度为:

即残差,真实值 - 前一轮预测值。

3. 模型更新

每轮训练完新树后后,模型更新为:

γ 为学习率,通常取 0.1 左右,控制每棵树的贡献。

三、GBDT实例代码

模块一:导入核心库
python 复制代码
import numpy as np  # 数值计算
from sklearn.datasets import fetch_california_housing  # 房价数据集(替代波士顿房价)
from sklearn.ensemble import GradientBoostingRegressor  # GBDT回归器
from sklearn.model_selection import train_test_split  # 划分训练集/测试集
from sklearn.metrics import mean_squared_error, r2_score  # 评估指标
模块二:加载并预处理数据
python 复制代码
# 加载加利福尼亚房价数据集(特征包括收入、房龄、人口等,目标是房价中位数)
data = fetch_california_housing()
X = data.data  # 特征矩阵 (20640, 8)
y = data.target  # 目标变量(房价)(20640,)

# 划分训练集和测试集(测试集占20%,随机种子固定保证结果可复现)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
模块三:初始化GBDT回归模型
python 复制代码
gbdt = GradientBoostingRegressor(
    loss='squared_error',  # 损失函数:均方误差
    learning_rate=0.1,     # 学习率,越小越稳定但需要更多树
    n_estimators=100,      # 树的数量,弱学习器个数
    max_depth=3,           # 每棵决策树的最大深度
    min_samples_split=2,   # 节点分裂所需的最小样本数
    random_state=42        # 随机种子
)
模块四:训练模型
python 复制代码
gbdt.fit(X_train, y_train)
模块五:模型预测
python 复制代码
y_train_pred = gbdt.predict(X_train)  # 训练集预测值
y_test_pred = gbdt.predict(X_test)    # 测试集预测值
模块六:模型评估
python 复制代码
# 计算均方误差(MSE)
train_mse = mean_squared_error(y_train, y_train_pred)
test_mse = mean_squared_error(y_test, y_test_pred)

# 计算R²(决定系数,越接近1说明拟合越好)
train_r2 = r2_score(y_train, y_train_pred)
test_r2 = r2_score(y_test, y_test_pred)

# 打印评估结果
print(f"训练集MSE: {train_mse:.4f}")
print(f"测试集MSE: {test_mse:.4f}")
print(f"训练集R²: {train_r2:.4f}")
print(f"测试集R²: {test_r2:.4f}")

# 7. 查看特征重要性(GBDT的重要特性:可解释性)
feature_importance = gbdt.feature_importances_  # 每个特征的重要性得分
# 按重要性排序并打印
sorted_idx = np.argsort(feature_importance)[::-1]
print("\n特征重要性排名:")
for i, idx in enumerate(sorted_idx):
    print(f"第{i+1}名:{data.feature_names[idx]},重要性:{feature_importance[idx]:.4f}")
运行结果

训练集MSE: 0.1892

测试集MSE: 0.2568

训练集R²: 0.9125

测试集R²: 0.8456

特征重要性排名:

第1名:MedInc,重要性:0.6234

第2名:AveOccup,重要性:0.1567

第3名:HouseAge,重要性:0.0892

......

  • MSE 越小说明预测越准,R² 接近 1 说明模型能解释大部分数据方差;
  • 特征重要性可看出收入(MedInc)是影响房价的最核心因素,符合常识。

四、总结

GBDT 核心:串行训练多棵回归树,每棵树拟合前一轮模型的残差(负梯度),最终累加所有树的输出;

相关推荐
To_OC9 小时前
从一次栈溢出报错说起,我把递归彻底扒明白了
javascript·算法·程序员
火山引擎开发者社区10 小时前
火山AgentPlan/CodingPlan同步上线GLM-5.2
人工智能
冬奇Lab11 小时前
Skill 系列(05):Skill 工作流串联——4 种模式实测,并发加速 1.5x
人工智能·开源
冬奇Lab12 小时前
每日一个开源项目(第141篇):hiring-agent - HackerRank 开源了他们的简历评分系统,你的简历能得几分?
人工智能·面试·开源
甲维斯12 小时前
又升级咯!坦克大战2026,科技与复古并存!
前端·人工智能·游戏开发
姗姗来迟了14 小时前
用React Hook封装AI对话状态
人工智能
Goodbye14 小时前
从 Token 到 Embedding:LLM 核心基础深度解析
javascript·人工智能
阿瑞IT14 小时前
AI Agent 在甘特计划变更场景中的动态响应工程实践
人工智能