sklearn基础--『监督学习』之线性回归

线性回归 是一种用于连续型分布预测的机器学习算法。

其基本思想是通过拟合一个线性函数来最小化样本数据和预测函数之间的误差。

1. 概述

常见的线性回归模型就是: <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = w 0 + w 1 x 1 + w 2 x 2 + . . . + w n x n f(x) = w_0+w_1x_1+w_2x_2+...+w_nx_n </math>f(x)=w0+w1x1+w2x2+...+wnxn这样的一个函数。

其中

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ( w 1 , w 2 , . . . w n ) (w_1,w_2,...w_n) </math>(w1,w2,...wn)是模型的系数向量
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> w 0 w_0 </math>w0是截距
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x 1 , x 2 , . . . , x n ) (x_1, x_2,...,x_n) </math>(x1,x2,...,xn)是样本数据(n是样本数据的维度

简单来说,线性回归模型的训练就是通过样本数据来确定系数向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( w 1 , w 2 , . . . w n ) (w_1,w_2,...w_n) </math>(w1,w2,...wn)和截距 <math xmlns="http://www.w3.org/1998/Math/MathML"> w 0 w_0 </math>w0的具体数值。

然后可以使用模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) f(x) </math>f(x)来预测新的样本数据。

2. 创建样本数据

首先,用scikit-learn中的自带的函数,就可以创建出适用于线性回归场景的样本数据。

python 复制代码
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1)

X, y = make_regression(n_samples=100, n_features=1, noise=10)
ax.scatter(X[:, 0], y, marker="o")
ax.set_title("样本数据")

plt.show()

通过 make_regression 函数可以帮助我们创建任意的回归样本数据。

具体使用可以参考之前的文章: sklearn基础--『数据加载』之样本生成器

3. 模型训练

训练线性回归模型,一般使用最小二乘法 ,而scikit-learnlinear_model模块中,

已经封装好了最小二乘法的训练算法。

首先,根据上面的样本数据,划分训练集和测试集。

python 复制代码
from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

上面的代码按照9:1的比例划分了训练集测试集

然后,用基于最小二乘法的线性模型来训练数据。

python 复制代码
from sklearn.linear_model import LinearRegression

# 初始化最小二乘法线性模型
reg = LinearRegression()
# 训练模型
reg.fit(X_train, y_train)

print("模型的系数向量:", reg.coef_)
print("模型的系数截距:", reg.intercept_)

# 运行结果:
模型的系数向量: [99.59241352]
模型的系数截距: 0.6889080586801999

reg.coef_就相当于前面的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( w 1 , w 2 , . . . w n ) (w_1,w_2,...w_n) </math>(w1,w2,...wn)
reg.intercept_就相当于前面的 <math xmlns="http://www.w3.org/1998/Math/MathML"> w 0 w_0 </math>w0

这里的样本数据为了方便绘图,只有一个维度。

最后,我们把线性模型训练集测试集都绘制出来看看效果。

python 复制代码
fig, ax = plt.subplots(1, 1)

# 训练集
ax.scatter(X_train[:, 0], y_train, marker="o", c="g")
# 测试集
ax.scatter(X_test[:, 0], y_test, marker="*", c="r")

# 线性模型
reg_x = np.array([-3, 3])
reg_y = reg.coef_ * reg_x + reg.intercept_
ax.plot(reg_x, reg_y, color="b")

plt.show()

上图中,蓝色的直线是在训练集上训练出来的线性模型

绿色的圆点是训练集 ;红色的五角星是测试集

从图中可以看出,训练的效果还不错。

总结

线性回归是一种常见的预测模型,可以用在

  1. 预测连续值:比如根据房屋面积和价格的关系预测房价,根据油耗和车辆重量之间的关系预测油耗等等。
  2. 判断因果关系:比如药物剂量和血压之间的关系,或者产品价格和销售量之间的关系

它的主要优势有简单易理解 (模型简单直观,易于理解和解释),
易于实施 (各种编程语言中都有现成的库和函数可以方便地实现和应用),

还有稳定性高(对于训练数据的小变化相对稳定,能够提供较为准确的结果)。

不过,线性回归模型需要基于一些假设 ,如误差项的独立性、同方差性等,如果这些假设不满足,模型可能无法得到准确的结果;

其次,对于一些非线性关系 的数据表现可能不佳;

而且,对异常值敏感,如果数据中存在异常值,线性回归的结果可能会受到较大影响,模型训练前最好先清洗和处理异常值。

相关推荐
B站计算机毕业设计之家3 小时前
大数据python招聘数据分析预测系统 招聘数据平台 +爬虫+可视化 +django框架+vue框架 大数据技术✅
大数据·爬虫·python·机器学习·数据挖掘·数据分析
落羽的落羽4 小时前
【C++】现代C++的新特性constexpr,及其在C++14、C++17、C++20中的进化
linux·c++·人工智能·学习·机器学习·c++20·c++40周年
云雾J视界5 小时前
AI驱动半导体良率提升:基于机器学习的晶圆缺陷分类系统搭建
人工智能·python·机器学习·智能制造·数据驱动·晶圆缺陷分类
极客学术工坊8 小时前
2023年第二十届五一数学建模竞赛-A题 无人机定点投放问题-基于抛体运动的无人机定点投放问题研究
人工智能·机器学习·数学建模·启发式算法
Theodore_10229 小时前
深度学习(9)导数与计算图
人工智能·深度学习·机器学习·矩阵·线性回归
极客学术工坊13 小时前
2022年第十二届MathorCup高校数学建模挑战赛-D题 移动通信网络站址规划和区域聚类问题
机器学习·数学建模·启发式算法·聚类
领航猿1号16 小时前
Pytorch 内存布局优化:Contiguous Memory
人工智能·pytorch·深度学习·机器学习
hakuii18 小时前
SVD分解后的各个矩阵的深层理解
人工智能·机器学习·矩阵
这张生成的图像能检测吗18 小时前
(论文速读)基于图像堆栈的低频超宽带SAR叶簇隐蔽目标变化检测
图像处理·人工智能·深度学习·机器学习·信号处理·雷达·变化检测
Blossom.11819 小时前
大模型在边缘计算中的部署挑战与优化策略
人工智能·python·算法·机器学习·边缘计算·pygame·tornado