基于Python的机器学习系列(17):梯度提升回归(Gradient Boosting Regression)

简介

梯度提升(Gradient Boosting)是一种强大的集成学习方法,类似于AdaBoost,但与其不同的是,梯度提升通过在每一步添加新的预测器来减少前一步预测器的残差。这种方法通过逐步改进模型,能够有效提高预测准确性。

梯度提升回归的工作原理

在梯度提升回归中,我们逐步添加预测器来修正模型的残差。以下是梯度提升的基本步骤:

  1. 初始化模型:选择一个初始预测器 h0(x),计算该预测器的预测值。
  2. 计算残差:计算每个样本的残差,残差是实际值与当前预测值之间的差异。
  3. 训练新预测器:用计算得到的残差作为目标,训练一个新的预测器 h1(x)。
  4. 更新模型:将新预测器的预测结果加到现有模型中。
  5. 重复步骤:重复上述步骤,逐步添加更多的预测器,以减少残差。

目标函数与残差

在回归问题中,我们希望通过添加新的预测器来最小化残差。具体来说,对于每个样本 (x(i),y(i)),我们计算预测器的残差:

我们希望新的预测器 h1(x)能够进一步减少这个残差:

通过这样的方式,我们可以不断改进模型的预测能力。

梯度提升回归的损失函数

在回归中,我们通常使用均方误差(MSE)作为损失函数:

我们的目标是通过每一步最小化残差,从而最小化整体损失函数。

代码示例

下面的代码示例展示了如何使用sklearn中的GradientBoostingRegressor实现梯度提升回归:

python 复制代码
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error

# 生成数据集
X, y = make_regression(n_samples=500, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建和训练模型
gbr = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
gbr.fit(X_train, y_train)

# 进行预测和评估
y_pred = gbr.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"均方误差: {mse:.2f}")

结语

与之前讨论的决策树、Bagging、随机森林相比,梯度提升回归通过逐步优化模型的残差来提升预测性能。决策树和Bagging方法通过集成多个模型来减少方差,而随机森林进一步通过随机特征选择来去相关性。梯度提升则通过序列化的方式不断改进模型,强调对残差的逐步修正。每种方法都有其独特的优势和适用场景,选择合适的模型可以显著提高预测的准确性。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关推荐
hummhumm21 分钟前
第 25 章 - Golang 项目结构
java·开发语言·前端·后端·python·elasticsearch·golang
杜小满25 分钟前
周志华深度森林deep forest(deep-forest)最新可安装教程,仅需在pycharm中完成,超简单安装教程
python·随机森林·pycharm·集成学习
Chef_Chen1 小时前
从0开始学习机器学习--Day33--机器学习阶段总结
人工智能·学习·机器学习
databook1 小时前
『玩转Streamlit』--布局与容器组件
python·机器学习·数据分析
肖永威2 小时前
CentOS环境上离线安装python3及相关包
linux·运维·机器学习·centos
nuclear20112 小时前
使用Python 在Excel中创建和取消数据分组 - 详解
python·excel数据分组·创建excel分组·excel分类汇总·excel嵌套分组·excel大纲级别·取消excel分组
Lucky小小吴2 小时前
有关django、python版本、sqlite3版本冲突问题
python·django·sqlite
GIS 数据栈3 小时前
每日一书 《基于ArcGIS的Python编程秘笈》
开发语言·python·arcgis
爱分享的码瑞哥3 小时前
Python爬虫中的IP封禁问题及其解决方案
爬虫·python·tcp/ip
傻啦嘿哟4 小时前
如何使用 Python 开发一个简单的文本数据转换为 Excel 工具
开发语言·python·excel