监督学习 - 梯度提升回归(Gradient Boosting Regression)

什么是机器学习

梯度提升回归 (Gradient Boosting Regression)是一种集成学习方法,用于解决回归问题。它通过迭代地训练一系列弱学习器(通常是决策树)来逐步提升模型的性能。梯度提升回归的基本思想是通过拟合前一轮模型的残差(实际值与预测值之差)来构建下一轮模型,从而逐步减小模型对训练数据的预测误差。

以下是梯度提升回归的主要步骤:

  1. 初始化: 初始模型可以是一个简单的模型,比如均值模型。这个模型将用于第一轮训练。
  2. 迭代训练: 对于每一轮迭代,都会训练一个新的弱学习器(通常是决策树),该学习器将拟合前一轮模型的残差。新模型的预测结果将与前一轮模型的预测结果相加,从而逐步改善模型的性能。
  3. 残差计算: 在每一轮迭代中,计算实际值与当前模型的预测值之间的残差。残差表示模型尚未能够正确拟合的部分。
  4. 学习率: 通过引入学习率(learning rate)来控制每一轮模型的权重。学习率是一个小于 1 的参数,它乘以每一轮模型的预测结果,用于缓慢地逼近真实的目标值。
  5. 停止条件: 迭代可以在达到一定的轮数或者当模型的性能满足一定条件时停止。

在实际应用中,可以使用梯度提升回归的库,如Scikit-Learn中的GradientBoostingRegressor类,来实现梯度提升回归。

以下是一个简单的Python代码示例:

python 复制代码
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as plt

# 创建示例数据集
np.random.seed(42)
X = np.sort(5 * np.random.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建梯度提升回归模型
gb_regressor = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)

# 在训练集上训练模型
gb_regressor.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = gb_regressor.predict(X_test)

# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
print(f"均方误差(MSE): {mse}")

# 可视化结果
plt.figure(figsize=(8, 6))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_pred, color="cornflowerblue", label="prediction")
plt.xlabel("data")
plt.ylabel("target")
plt.title("Gradient Boosting Regression")
plt.legend()
plt.show()

在这个例子中,GradientBoostingRegressor 类的关键参数包括 n_estimators(迭代次数)、learning_rate(学习率)、max_depth(树的最大深度)等。这些参数可以根据实际问题进行调整。

相关推荐
企销客CRM2 小时前
CRM管理软件的数据可视化功能使用技巧:让数据驱动决策
信息可视化·数据挖掘·数据分析·用户运营
武子康4 小时前
大数据-276 Spark MLib - 基础介绍 机器学习算法 Bagging和Boosting区别 GBDT梯度提升树
大数据·人工智能·算法·机器学习·语言模型·spark-ml·boosting
武子康4 小时前
大数据-277 Spark MLib - 基础介绍 机器学习算法 Gradient Boosting GBDT算法原理 高效实现
大数据·人工智能·算法·机器学习·ai·spark-ml·boosting
人大博士的交易之路14 小时前
今日行情明日机会——20250606
大数据·数学建模·数据挖掘·数据分析·涨停回马枪
产品何同学15 小时前
数据分析后台设计指南:实战案例解析与5大设计要点总结
数据挖掘·数据分析·产品经理·墨刀·原型设计·后台管理系统·数据分析后台
lilye6617 小时前
精益数据分析(95/126):Socialight的定价转型启示——B2B商业模式的价格策略与利润优化
人工智能·数据挖掘·数据分析
电商API_180079052471 天前
构建高效可靠的电商 API:设计原则与实践指南
运维·服务器·爬虫·数据挖掘·网络爬虫
拓端研究室TRL1 天前
PySpark、Plotly全球重大地震数据挖掘交互式分析及动态可视化研究
人工智能·plotly·数据挖掘
思通数科多模态大模型1 天前
重构城市应急指挥布控策略 ——无人机智能视频监控的破局之道
人工智能·深度学习·安全·重构·数据挖掘·音视频·无人机
十三画者1 天前
【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习
python·机器学习·数据挖掘·数据分析·r语言·数据可视化