机器学习————GBDT算法

一、GBDT核心概念

GBDT 全称 "Gradient Boosting Decision Tree",是 Boosting 家族的核心算法,核心逻辑可总结为:

  1. 核心思想:串行训练多棵决策树,每一棵新树都用来拟合,前一轮模型预测值与真实值的残差,最终将所有树的预测结果累加,得到最终预测值。
  2. 核心特点:
    • 基模型默认是回归决策树,即使做分类任务,底层也是回归树拟合概率和对数几率;
    • 每一轮训练都聚焦修正上一轮的错误,通过梯度下降的思想最小化损失函数;
    • 最终预测:分类任务用 sigmoid或softmax 转换,回归任务直接累加所有树的输出。

二、GBDT数学公式

1. 损失函数与梯度

假设训练集为,GBDT 第 m 轮的模型可表示为:

  • :前 m−1 轮模型的累加预测值;
  • :第 m 棵决策树;
  • :第 m 棵树的学习率。
2. 残差梯度计算

GBDT 每一轮训练的目标是让新树拟合负梯度,其中残差是均方误差下的特殊负梯度:

均方误差损失为:

其负梯度为:

即残差,真实值 - 前一轮预测值。

3. 模型更新

每轮训练完新树后后,模型更新为:

γ 为学习率,通常取 0.1 左右,控制每棵树的贡献。

三、GBDT实例代码

模块一:导入核心库
python 复制代码
import numpy as np  # 数值计算
from sklearn.datasets import fetch_california_housing  # 房价数据集(替代波士顿房价)
from sklearn.ensemble import GradientBoostingRegressor  # GBDT回归器
from sklearn.model_selection import train_test_split  # 划分训练集/测试集
from sklearn.metrics import mean_squared_error, r2_score  # 评估指标
模块二:加载并预处理数据
python 复制代码
# 加载加利福尼亚房价数据集(特征包括收入、房龄、人口等,目标是房价中位数)
data = fetch_california_housing()
X = data.data  # 特征矩阵 (20640, 8)
y = data.target  # 目标变量(房价)(20640,)

# 划分训练集和测试集(测试集占20%,随机种子固定保证结果可复现)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
模块三:初始化GBDT回归模型
python 复制代码
gbdt = GradientBoostingRegressor(
    loss='squared_error',  # 损失函数:均方误差
    learning_rate=0.1,     # 学习率,越小越稳定但需要更多树
    n_estimators=100,      # 树的数量,弱学习器个数
    max_depth=3,           # 每棵决策树的最大深度
    min_samples_split=2,   # 节点分裂所需的最小样本数
    random_state=42        # 随机种子
)
模块四:训练模型
python 复制代码
gbdt.fit(X_train, y_train)
模块五:模型预测
python 复制代码
y_train_pred = gbdt.predict(X_train)  # 训练集预测值
y_test_pred = gbdt.predict(X_test)    # 测试集预测值
模块六:模型评估
python 复制代码
# 计算均方误差(MSE)
train_mse = mean_squared_error(y_train, y_train_pred)
test_mse = mean_squared_error(y_test, y_test_pred)

# 计算R²(决定系数,越接近1说明拟合越好)
train_r2 = r2_score(y_train, y_train_pred)
test_r2 = r2_score(y_test, y_test_pred)

# 打印评估结果
print(f"训练集MSE: {train_mse:.4f}")
print(f"测试集MSE: {test_mse:.4f}")
print(f"训练集R²: {train_r2:.4f}")
print(f"测试集R²: {test_r2:.4f}")

# 7. 查看特征重要性(GBDT的重要特性:可解释性)
feature_importance = gbdt.feature_importances_  # 每个特征的重要性得分
# 按重要性排序并打印
sorted_idx = np.argsort(feature_importance)[::-1]
print("\n特征重要性排名:")
for i, idx in enumerate(sorted_idx):
    print(f"第{i+1}名:{data.feature_names[idx]},重要性:{feature_importance[idx]:.4f}")
运行结果

训练集MSE: 0.1892

测试集MSE: 0.2568

训练集R²: 0.9125

测试集R²: 0.8456

特征重要性排名:

第1名:MedInc,重要性:0.6234

第2名:AveOccup,重要性:0.1567

第3名:HouseAge,重要性:0.0892

......

  • MSE 越小说明预测越准,R² 接近 1 说明模型能解释大部分数据方差;
  • 特征重要性可看出收入(MedInc)是影响房价的最核心因素,符合常识。

四、总结

GBDT 核心:串行训练多棵回归树,每棵树拟合前一轮模型的残差(负梯度),最终累加所有树的输出;

相关推荐
狐璃同学2 小时前
数据结构(1)三要素
数据结构·算法
列星随旋2 小时前
拓扑排序(Kahn算法)
算法
墨染天姬2 小时前
【AI】DeepSeek开源cuda算子库TileKernels
人工智能·开源
Hello!!!!!!2 小时前
C++基础(六)——数组与字符串
c++·算法
Agent手记2 小时前
多系统集成破局:企业级智能体打通异构系统的完整解决方案 | 2026全链路落地实操
人工智能·ai
sunneo2 小时前
从“生成视频”到“生成表演”:米哈游LPM 1.0如何重新定义数字角色的“灵魂”
人工智能·ai作画·aigc·ai编程·游戏美术
云烟成雨TD2 小时前
Spring AI Alibaba 1.x 系列【36】FlowAgent 和 BaseAgent 抽象类
java·人工智能·spring
山半仙xs2 小时前
基于卡尔曼滤波的人脸跟踪
人工智能·python·算法·计算机视觉
谷歌开发者2 小时前
Build with AI 深圳场|在大湾区科技浪潮中预见 AI 未来
人工智能·科技
谁似人间西林客2 小时前
工业互联网如何驱动工艺智能?拆解高精度制造的三大技术支柱
人工智能·制造