机器学习入门(二)线性回归

在上一篇文章 机器学习入门(一)什么是机器学习 中,提到过,机器学习一般分为监督学习、无监督学习、半监督学习以及强化学习四种。这篇文章将介绍监督学习中的线性回归算法。

单因子的线性回归

回归分析 是指根据数据,确定两种或两种以上变量间相互依赖的定量关系。而线性回归顾名思义就是变量之间呈线性的关系。这里以房价和房子的大小的关系为例,一般房子越大,其房价就越贵,如下图所示:

因此我们可以假设房价和房子的关系为 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ i = β ^ 0 x + β ^ 1 \hat{y}_i = \hat{\beta}_0x + \hat{\beta}_1 </math>y^i=β^0x+β^1,可以看到现在的线性回归就是我们初中学习的二元一次方程。而我们的目的就是求出该函数的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 0 \hat{\beta}_0 </math>β^0 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 1 \hat{\beta}_1 </math>β^1 的值。但是,现实中数据都不是精确的,它可能会受到其他因素的影响,比如说,房价除了面积外,还受其房子年龄、周边人口等。这样就导致了,现实世界中的数据是一个个的散点图,我们无法使用一条直线完全拟合图中的数据点。

如上图所示,我们画了三条线来拟合数据,但是这三条线中哪一条更好呢?从图中可以看到,貌似是红色的直线更加拟合,但是这种目测的方式非常不准确。这时我们就可以使用损失函数来评判哪一条直线更加拟合。

损失函数

损失函数是衡量回归模型误差的函数,也就是我们要的"直线"的评价标准。这个函数的值越小,说明直线越能拟合我们的数据。

在回归算法中,一般使用残差平方和来作为损失函数,如下图所示。

残差平方和的公式为: <math xmlns="http://www.w3.org/1998/Math/MathML"> RSS = ∑ i = 1 n ( y i − y ^ i ) 2 \text{RSS} = \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 </math>RSS=∑i=1n(yi−y^i)2

把公式展开后,可以发现它其实是一个二次方程。如下图所示:

由于 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ i = β ^ 0 x + β ^ 1 \hat{y}_i = \hat{\beta}_0x + \hat{\beta}_1 </math>y^i=β^0x+β^1 中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 0 \hat{\beta}_0 </math>β^0 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 1 \hat{\beta}_1 </math>β^1 都是未知变量,因此其画出来的是三维图形:

从图中可以看出,其最低点就是该损失函数的最小值。根据大学的微积分知识,当导数为0时,就是该函数的极小值点。因此我们分别对 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 0 \hat{\beta}_0 </math>β^0和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 1 \hat{\beta}_1 </math>β^1求偏导并令其为0:

这样就可以求出 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 0 \hat{\beta}_0 </math>β^0和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 1 \hat{\beta}_1 </math>β^1的值了。

上文中这种求极小值的方法叫做梯度下降法,关于梯度下降更详细的解释可以看 Gradient-Descent

scikit-learn

scikit-learn 是 python 中专门针对机器学习应用而发展起来的一款开源框架(算法库),开源实现数据预处理、分类、回归、降维、模型选择等常用的机器学习算法。但是不支持python以外的语言,也不支持深度学习和强化学习。

代替手动计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 0 \hat{\beta}_0 </math>β^0和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ^ 1 \hat{\beta}_1 </math>β^1的值,我们可以直接使用 scikit-learn 库。代码示例如下:

python 复制代码
from sklearn.linear_model import LinearRegression

# 创建一个 LinearRegression 类的实例,命名为 lr_model
lr_model = LinearRegression()
# 使用 fit 方法对线性回归模型进行训练
lr_model.fit(x, y)
# 斜率(系数)
a = lr_model.coef_
# 获取线性回归模型的截距
b = lr_model.intercept_

把获取到的直线绘制出来,其效果如下图所示:

多因子的线性回归

上面考虑的只是房子大小单个因子对房价的影响,但是现实世界不可能只有一个影响因素,而是多个。假设房价由两个因素影响,分别是房子大小和房子年龄。那么假设房价与房子大小和房子年龄的关系为: <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ i = β ^ 0 x 0 + β ^ 1 x 1 + β ^ 2 \hat{y}_i = \hat{\beta}_0x_0 + \hat{\beta}_1x_1 + \hat{\beta}_2 </math>y^i=β^0x0+β^1x1+β^2 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0 表示房子大小, <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 x_1 </math>x1表示房子年龄。

其对应的损失函数仍为: <math xmlns="http://www.w3.org/1998/Math/MathML"> RSS = ∑ i = 1 n ( y i − y ^ i ) 2 \text{RSS} = \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 </math>RSS=∑i=1n(yi−y^i)2 ,区别是 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ i = β ^ 0 x 0 + β ^ 1 x 1 + β ^ 2 \hat{y}_i = \hat{\beta}_0x_0 + \hat{\beta}_1x_1 + \hat{\beta}_2 </math>y^i=β^0x0+β^1x1+β^2。

同样,对于多因子的线性回归,我们可以使用scikit-learn库来计算。代码示例如下:

python 复制代码
from sklearn.metrics import mean_squared_error,r2_score
# 使用线性回归模型
from sklearn.linear_model import LinearRegression
lr_model_multi = LinearRegression()
# 一样使用 fit 方法,区别是 多因子的线性回归 传入的 x_multi 是多维数组
lr_model_multi.fit(x_multi, y)
# 输出数组,比如: [1. 1.](每个特征一个系数)
print("斜率(系数):", model_multi.coef_)  
print("截距:", model_multi.intercept_)    

参考

相关推荐
追逐☞3 小时前
机器学习(7)——K均值聚类
机器学习·均值算法·聚类
追逐☞4 小时前
机器学习(9)——随机森林
人工智能·随机森林·机器学习
云天徽上6 小时前
【数据可视化-28】2017-2025 年每月产品零售价数据可视化分析
机器学习·信息可视化·数据挖掘·数据分析·零售
硅谷秋水6 小时前
CoT-Drive:利用 LLM 和思维链提示实现自动驾驶的高效运动预测
人工智能·机器学习·语言模型·自动驾驶
IT古董7 小时前
【漫话机器学习系列】214.停用词(Stop Words)
人工智能·机器学习
云天徽上8 小时前
【数据可视化-27】全球网络安全威胁数据可视化分析(2015-2024)
人工智能·安全·web安全·机器学习·信息可视化·数据分析
硅谷秋水8 小时前
ORION:通过视觉-语言指令动作生成的一个整体端到端自动驾驶框架
人工智能·深度学习·机器学习·计算机视觉·语言模型·自动驾驶
小墙程序员8 小时前
机器学习入门(一)什么是机器学习
机器学习
豆芽8198 小时前
强化学习(Reinforcement Learning, RL)和深度学习(Deep Learning, DL)
人工智能·深度学习·机器学习·强化学习