【机器学习】机器学习的基本分类-监督学习-梯度提升树(Gradient Boosting Decision Tree, GBDT)

梯度提升树是一种基于**梯度提升(Gradient Boosting)**框架的机器学习算法,通过构建多个决策树并利用每棵树拟合前一棵树的残差来逐步优化模型。


1. 核心思想

  • Boosting:通过逐步调整模型,使后续的模型重点学习前一阶段未能正确拟合的数据。
  • 梯度提升:将误差函数的负梯度作为残差,指导新一轮模型的训练。
与随机森林的区别
特性 随机森林 梯度提升树
基本思想 Bagging Boosting
树的训练方式 并行训练 顺序训练
树的类型 完全树 通常是浅树(弱学习器)
应用场景 抗过拟合、快速训练 高精度、复杂任务

2. 算法流程

  1. 输入

    • 数据集
    • 损失函数 ,如平方误差、对数似然等。
    • 弱学习器个数 T 和学习率 η。
  2. 初始化模型

    • 是一个常数,通常为目标变量的均值(回归)或类别概率的对数(分类)。
  3. 迭代训练每棵弱学习器(树)

    • 第 t 次迭代:
      1. 计算第 t 轮的负梯度(残差):

        残差反映当前模型未能拟合的部分。
      2. 构建决策树 拟合残差
      3. 计算最佳步长(叶节点输出值):
      4. 更新模型: 其中 η 是学习率,控制每棵树的贡献大小。
  4. 输出模型: 最终模型为:


3. 损失函数

GBDT 可灵活选择损失函数,以下是常用的几种:

  1. 平方误差(MSE,回归问题)

    • 负梯度:
  2. 对数似然(Log-Loss,二分类问题)

    • 负梯度:
  3. 指数损失(Adaboost)


4. GBDT 的优缺点

优点
  1. 灵活性:支持回归和分类任务,且损失函数可定制。
  2. 高精度:由于采用 Boosting 框架,能取得非常好的预测效果。
  3. 特征选择:内置特征重要性评估,帮助筛选关键特征。
  4. 处理缺失值:部分实现(如 XGBoost)可以自动处理缺失值。
缺点
  1. 训练时间长:由于弱学习器依次构建,训练过程较慢。
  2. 对参数敏感:需要调整学习率、树的数量、最大深度等参数。
  3. 不擅长高维稀疏数据:相比线性模型和神经网络,GBDT 在处理高维数据(如文本数据)时表现一般。

5. GBDT 的改进

  1. XGBoost

    • 增加正则化项,控制模型复杂度。
    • 支持并行化计算,加速训练。
    • 提供更高效的特征分裂方法。
  2. LightGBM

    • 提出叶子分裂(Leaf-Wise)策略。
    • 适合大规模数据和高维特征场景。
  3. CatBoost

    • 专门针对分类特征优化。
    • 避免目标泄露(Target Leakage)。

6. GBDT 的代码实现

以下是 GBDT 的分类问题实现:

python 复制代码
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 生成数据
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建 GBDT 模型
gbdt = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
gbdt.fit(X_train, y_train)

# 预测
y_pred = gbdt.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("分类准确率:", accuracy)

# 特征重要性
import matplotlib.pyplot as plt
import numpy as np

feature_importances = gbdt.feature_importances_
indices = np.argsort(feature_importances)[::-1]

plt.figure(figsize=(10, 6))
plt.title("Feature Importance")
plt.bar(range(X.shape[1]), feature_importances[indices], align="center")
plt.xticks(range(X.shape[1]), indices)
plt.show()

输出结果

bash 复制代码
分类准确率: 0.9366666666666666

7. 应用场景

  1. 回归问题:如预测房价、商品销量。
  2. 分类问题:如金融风险预测、垃圾邮件分类。
  3. 排序问题:如搜索引擎的结果排序。
  4. 时间序列问题:预测趋势或模式。

GBDT 是机器学习中的经典算法,尽管深度学习在许多领域占据主导地位,但在表格数据和中小规模数据集的应用中,GBDT 仍然是非常强大的工具。

相关推荐
zzywxc78710 分钟前
AI在金融、医疗、教育、制造业等领域的落地案例
人工智能·机器学习·金融·prompt·流程图
zstar-_19 分钟前
【论文阅读】REFRAG:一个提升RAG解码效率的新思路
人工智能
武文斌7723 分钟前
arm启动代码总结
arm开发·嵌入式硬件·学习
慧一居士38 分钟前
SpringBoot改造MCP服务器(StreamableHTTP)
人工智能
索迪迈科技43 分钟前
安防芯片 ISP 的白平衡统计数据对图像质量有哪些影响?
人工智能·计算机视觉·白平衡
AiTop1001 小时前
腾讯推出AI CLI工具CodeBuddy,国内首家同时支持插件、IDE和CLI三种形态的AI编程工具厂商
ide·人工智能·ai·aigc·ai编程
我怕是好1 小时前
学习stm32 蓝牙
stm32·嵌入式硬件·学习
索迪迈科技1 小时前
STM32F103C8T6开发板入门学习——点亮LED灯2
stm32·嵌入式硬件·学习
非门由也1 小时前
《sklearn机器学习——回归指标2》
机器学习·回归·sklearn
山楂树下懒猴子1 小时前
ChatAI项目-ChatGPT-SDK组件工程
人工智能·chatgpt·junit·https·log4j·intellij-idea·mybatis