机器学习之 AdaBoost(Adaptive Boosting)

机器学习之 AdaBoost(Adaptive Boosting)


AdaBoost(Adaptive Boosting)是一种集成学习方法,主要用于提升弱分类器的性能,通过组合多个弱分类器(例如决策树)来构建一个强分类器。

AdaBoost在1995年由Yoav Freund和Robert Schapire提出,并且在机器学习领域得到了广泛的应用。

基本介绍

AdaBoost的基本思想是迭代地训练一系列弱分类器,每个弱分类器针对训练数据集进行训练,并根据上一个弱分类器的结果对数据进行加权。在每一轮迭代中,AdaBoost都会关注上一轮分类错误的样本,尝试通过调整权重使得这些样本在下一轮分类中得到更好的处理。最终,将这些弱分类器组合成一个强分类器,其性能优于单个弱分类器。

核心原理

  1. 初始化样本权重: 将每个样本的权重初始化为相等值。
  2. 迭代训练: 在每一轮迭代中,根据当前样本权重训练一个弱分类器。
  3. 计算错误率: 使用训练好的弱分类器对样本进行分类,并计算分类错误的样本的权重之和。
  4. 更新样本权重: 对于分类错误的样本,增加其权重,使其在下一轮迭代中更受关注;而对于分类正确的样本,减小其权重,降低其在下一轮迭代中的影响。
  5. 组合弱分类器: 根据每个弱分类器的分类准确性(权重),组合这些弱分类器得到一个强分类器。

公式推理


这样,分类错误率较低的弱分类器在最终的强分类器中会被赋予更大的权重,从而更大程度上影响最终的分类结果。

代码案例

使用多项式数据来演示AdaBoost回归器的效果,并且绘制了两个不同数量基学习器的情况下的回归结果。

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import AdaBoostRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline

# 生成多项式数据
np.random.seed(42)
X = np.random.rand(500) * 10
y = np.sin(X) + np.random.normal(0, 0.5, 500)

# 将数据转换成二维数组
X = X[:, np.newaxis]

# 定义基学习器
base_regressor = DecisionTreeRegressor(max_depth=4)

# 定义AdaBoost回归器
regr_1 = AdaBoostRegressor(base_regressor, n_estimators=10, random_state=42)
regr_2 = AdaBoostRegressor(base_regressor, n_estimators=100, random_state=42)

# 训练模型
regr_1.fit(X, y)
regr_2.fit(X, y)

# 生成预测结果
X_test = np.linspace(0, 10, 500)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)

# 绘制结果
plt.figure(figsize=(10, 6))
plt.scatter(X, y, c='b', label='Training samples')
plt.plot(X_test, y_1, c='r', label='n_estimators=10', linewidth=2)
plt.plot(X_test, y_2, c='g', label='n_estimators=100', linewidth=2)
plt.xlabel('Data')
plt.ylabel('Target')
plt.title('AdaBoost Regression with Polynomial Features')
plt.legend()
plt.show()

首先,生成了一些多项式数据,然后使用了决策树回归器作为基学习器,并且分别使用了10个和100个基学习器的AdaBoost回归器进行拟合。

最后,我们绘制出了原始数据以及两种情况下的拟合结果。

树回归器作为基学习器,并且分别使用了10个和100个基学习器的AdaBoost回归器进行拟合。

最后,我们绘制出了原始数据以及两种情况下的拟合结果。

相关推荐
就爱敲代码17 分钟前
怎么理解ES6 Proxy
1024程序员节
憧憬一下17 分钟前
input子系统的框架和重要数据结构详解
arm开发·嵌入式·c/c++·1024程序员节·linux驱动开发
三日看尽长安花26 分钟前
【Tableau】
1024程序员节
萧鼎44 分钟前
【Python】高效数据处理:使用Dask处理大规模数据
开发语言·python
sswithyou1 小时前
Linux的调度算法
1024程序员节
武子康1 小时前
大数据-187 Elasticsearch - ELK 家族 Logstash Filter 插件 使用详解
大数据·数据结构·elk·elasticsearch·搜索引擎·全文检索·1024程序员节
互联网杂货铺1 小时前
Python测试框架—pytest详解
自动化测试·软件测试·python·测试工具·测试用例·pytest·1024程序员节
yyfhq1 小时前
dcgan
深度学习·机器学习·生成对抗网络
Ellie陈1 小时前
Java已死,大模型才是未来?
java·开发语言·前端·后端·python
菜鸟的人工智能之路1 小时前
ROC 曲线:医学研究中的得力助手
python·数据分析·健康医疗