机器学习————GBDT算法

一、GBDT核心概念

GBDT 全称 "Gradient Boosting Decision Tree",是 Boosting 家族的核心算法,核心逻辑可总结为:

  1. 核心思想:串行训练多棵决策树,每一棵新树都用来拟合,前一轮模型预测值与真实值的残差,最终将所有树的预测结果累加,得到最终预测值。
  2. 核心特点:
    • 基模型默认是回归决策树,即使做分类任务,底层也是回归树拟合概率和对数几率;
    • 每一轮训练都聚焦修正上一轮的错误,通过梯度下降的思想最小化损失函数;
    • 最终预测:分类任务用 sigmoid或softmax 转换,回归任务直接累加所有树的输出。

二、GBDT数学公式

1. 损失函数与梯度

假设训练集为,GBDT 第 m 轮的模型可表示为:

  • :前 m−1 轮模型的累加预测值;
  • :第 m 棵决策树;
  • :第 m 棵树的学习率。
2. 残差梯度计算

GBDT 每一轮训练的目标是让新树拟合负梯度,其中残差是均方误差下的特殊负梯度:

均方误差损失为:

其负梯度为:

即残差,真实值 - 前一轮预测值。

3. 模型更新

每轮训练完新树后后,模型更新为:

γ 为学习率,通常取 0.1 左右,控制每棵树的贡献。

三、GBDT实例代码

模块一:导入核心库
python 复制代码
import numpy as np  # 数值计算
from sklearn.datasets import fetch_california_housing  # 房价数据集(替代波士顿房价)
from sklearn.ensemble import GradientBoostingRegressor  # GBDT回归器
from sklearn.model_selection import train_test_split  # 划分训练集/测试集
from sklearn.metrics import mean_squared_error, r2_score  # 评估指标
模块二:加载并预处理数据
python 复制代码
# 加载加利福尼亚房价数据集(特征包括收入、房龄、人口等,目标是房价中位数)
data = fetch_california_housing()
X = data.data  # 特征矩阵 (20640, 8)
y = data.target  # 目标变量(房价)(20640,)

# 划分训练集和测试集(测试集占20%,随机种子固定保证结果可复现)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
模块三:初始化GBDT回归模型
python 复制代码
gbdt = GradientBoostingRegressor(
    loss='squared_error',  # 损失函数:均方误差
    learning_rate=0.1,     # 学习率,越小越稳定但需要更多树
    n_estimators=100,      # 树的数量,弱学习器个数
    max_depth=3,           # 每棵决策树的最大深度
    min_samples_split=2,   # 节点分裂所需的最小样本数
    random_state=42        # 随机种子
)
模块四:训练模型
python 复制代码
gbdt.fit(X_train, y_train)
模块五:模型预测
python 复制代码
y_train_pred = gbdt.predict(X_train)  # 训练集预测值
y_test_pred = gbdt.predict(X_test)    # 测试集预测值
模块六:模型评估
python 复制代码
# 计算均方误差(MSE)
train_mse = mean_squared_error(y_train, y_train_pred)
test_mse = mean_squared_error(y_test, y_test_pred)

# 计算R²(决定系数,越接近1说明拟合越好)
train_r2 = r2_score(y_train, y_train_pred)
test_r2 = r2_score(y_test, y_test_pred)

# 打印评估结果
print(f"训练集MSE: {train_mse:.4f}")
print(f"测试集MSE: {test_mse:.4f}")
print(f"训练集R²: {train_r2:.4f}")
print(f"测试集R²: {test_r2:.4f}")

# 7. 查看特征重要性(GBDT的重要特性:可解释性)
feature_importance = gbdt.feature_importances_  # 每个特征的重要性得分
# 按重要性排序并打印
sorted_idx = np.argsort(feature_importance)[::-1]
print("\n特征重要性排名:")
for i, idx in enumerate(sorted_idx):
    print(f"第{i+1}名:{data.feature_names[idx]},重要性:{feature_importance[idx]:.4f}")
运行结果

训练集MSE: 0.1892

测试集MSE: 0.2568

训练集R²: 0.9125

测试集R²: 0.8456

特征重要性排名:

第1名:MedInc,重要性:0.6234

第2名:AveOccup,重要性:0.1567

第3名:HouseAge,重要性:0.0892

......

  • MSE 越小说明预测越准,R² 接近 1 说明模型能解释大部分数据方差;
  • 特征重要性可看出收入(MedInc)是影响房价的最核心因素,符合常识。

四、总结

GBDT 核心:串行训练多棵回归树,每棵树拟合前一轮模型的残差(负梯度),最终累加所有树的输出;

相关推荐
tedcloud1235 小时前
UI-TARS-desktop部署教程:构建AI桌面自动化系统
服务器·前端·人工智能·ui·自动化·github
曦月逸霜8 小时前
啥是RAG 它能干什么?
人工智能·python·机器学习
AI医影跨模态组学8 小时前
Lancet Digit Health(IF=24.1)广东省人民医院刘再毅&南方医科大学南方医院梁莉等团队:基于可解释深度学习模型预测胶质瘤分子改变
人工智能·深度学习·论文·医学·医学影像·影像组学
应用市场8 小时前
AI 编程助手三强争霸(2026 版):Claude、Gemini、GPT 各自擅长什么?
人工智能·gpt
浅念-8 小时前
递归解题指南:LeetCode经典题全解析
数据结构·算法·leetcode·职场和发展·排序算法·深度优先·递归
CSND7408 小时前
YOLO resume断点续训(不能用官方的权重,是自己训练一半生成的last.pt)
深度学习·yolo·机器学习
Kiling_07048 小时前
Java集合进阶:Set与Collections详解
算法·哈希算法
AC赳赳老秦8 小时前
供应链专员提效:OpenClaw自动跟踪物流信息、更新库存数据,异常自动提醒
java·大数据·服务器·数据库·人工智能·自动化·openclaw
脑极体9 小时前
从Token消耗到DAA增长,AI价值标尺正在重构
人工智能·重构
csdn小瓯9 小时前
LangGraph自适应工作流路由机制:从关键词匹配到智能决策的完整实现
人工智能·fastapi·langgraph