机器学习--线性回归

线性回归

引入

我们在高中的时候都学过线性回归,在这我们回顾一下

在高中的课程中,我们会被给得到一组 x m {x_m} xm 和一组 y m {y_m} ym,然后我们想用 h ( x ) = y ^ = a x + b h(x) = \hat{y} = ax + b h(x)=y^=ax+b 来拟合这组数据使得整体上 h ( x i ) ≈ y i h(x_i) \approx y_i h(xi)≈yi 也就是这样:

而在高中教材中 a a a 和 b b b 都给出的准确的计算式(其中 x ˉ \bar{x} xˉ 和 y ˉ \bar{y} yˉ 是 x m {x_m} xm 和 y m {y_m} ym 的平均数):

b ^ = ∑ i = 1 m ( x i − x ˉ ) ( y i − y ˉ ) ∑ i = 1 m ( x i − x ˉ ) 2 = ∑ i = 1 m x i y i − m x ˉ y ˉ ∑ i = 1 m x i 2 − m x ˉ 2 a = y ˉ − b ^ x ˉ \begin{aligned} \hat{b} = & \frac{\sum\limits_{i = 1}^m(x_i - \bar{x})(y_i - \bar{y})}{\sum\limits_{i = 1}^m(x_i-\bar{x})^2} = \frac{\sum\limits_{i = 1}^mx_iy_i - m\bar{x}\bar{y}}{\sum\limits_{i = 1}^mx_i^2 - m\bar{x}^2} \\ a = & \bar{y} - \hat{b}\bar{x} \end{aligned} b^=a=i=1∑m(xi−xˉ)2i=1∑m(xi−xˉ)(yi−yˉ)=i=1∑mxi2−mxˉ2i=1∑mxiyi−mxˉyˉyˉ−b^xˉ

而在机器学习中,我们希望计算机通过某种算法学习出一组 a , b a, b a,b,使得这条直线满足上述要求。

进入正题

在我们要解决的线性回归问题中,我们不仅不能使用上述已知的表达式解,我们还可以将维度升高。比如我们的数组 x m {x_m} xm 中的每个 x ( i ) x^{(i)} x(i) 不再是一个标量,而是一个 n n n 维的向量,同理我们的 w ( i ) w^{(i)} w(i) 也是一个 n n n 维的向量:

x ( i ) = [ x 1 ( i ) x 2 ( i ) ⋮ x n ( i ) ]              w = [ w 1 ( i ) w 2 ⋮ w n ] x^{(i)} = \begin{bmatrix} x^{(i)}_1 \\ x^{(i)}_2 \\ \vdots \\ x^{(i)}_n \end{bmatrix} \;\;\;\;\;\; w = \begin{bmatrix} w^{(i)}_1 \\ w_2 \\\vdots \\ w_n \end{bmatrix} x(i)= x1(i)x2(i)⋮xn(i) w= w1(i)w2⋮wn

然后我们就要用我们构造出来的 h ( x ( i ) ) h(x^{(i)}) h(x(i)) 来拟合 y ( i ) y^{(i)} y(i):

y ^ = h ( x ( i ) ) = ∑ j = 1 n w j x j ( i ) + b = w T x ( i ) + b \hat{y} = h(x^{(i)}) = \sum_{j = 1}^n w_jx^{(i)}_j + b = w^Tx^{(i)} + b y^=h(x(i))=j=1∑nwjxj(i)+b=wTx(i)+b

在这里,我们定义 J ( w , b ) = 1 2 m ∑ i = 1 m ( h ( x ( i ) ) − y ( i ) ) 2 J(w, b) = \frac 1{2m}\sum\limits_{i = 1}^m(h(x^{(i)}) - y^{(i)})^2 J(w,b)=2m1i=1∑m(h(x(i))−y(i))2 叫做损失函数 c o s t    f u n c t i o n cost\;function costfunction,这个损失函数描述的是预测值 y ^ \hat{y} y^ 与真实值 y y y 的差异程度,上述表达式很好的符合这个定义。

并且显然的,我们希望这个 J ( w , b ) J(w, b) J(w,b) 越小越好。当 J ( w , b ) J(w, b) J(w,b) 取到最小值时,我们认为 h ( x ) h(x) h(x) 对 y y y 的拟合达到最完美。

梯度下降---最小化 J ( w , b ) J(w, b) J(w,b)

我们觉得写成 h ( x ) = w T x + b h(x) = w^Tx + b h(x)=wTx+b 太丑了,所以我们考虑令 x 0 ( i ) = 1 x^{(i)}_0 = 1 x0(i)=1,并新增 w 0 w_0 w0 使得每个 x ( i ) x^{(i)} x(i) 和 w w w 都变成一个 n + 1 n + 1 n+1 维的向量:

x ( i ) = [ x 0 ( i ) x 1 ( i ) x 2 ( i ) ⋮ x n ( i ) ]              w = [ w 0 w 1 ( i ) w 2 ⋮ w n ] x^{(i)} = \begin{bmatrix} x^{(i)}_0 \\ x^{(i)}_1 \\ x^{(i)}_2 \\ \vdots \\ x^{(i)}_n \end{bmatrix} \;\;\;\;\;\; w = \begin{bmatrix} w_0 \\ w^{(i)}_1 \\ w_2 \\\vdots \\ w_n \end{bmatrix} x(i)= x0(i)x1(i)x2(i)⋮xn(i) w= w0w1(i)w2⋮wn

于是 h ( x ) = w T x h(x) = w^Tx h(x)=wTx 这样就好看多了qwq,并且此时 J ( w ) = 1 2 m ∑ i = 1 m ( h ( x ( i ) ) − y ( i ) ) 2 J(w) = \frac 1{2m}\sum\limits_{i = 1}^m(h(x^{(i)}) - y^{(i)})^2 J(w)=2m1i=1∑m(h(x(i))−y(i))2

我们想做的就是一下描述的事情:

  1. 开始时对 w w w 随机赋值
  2. 持续改变 w w w 的值,使得 J ( w ) J(w) J(w) 的值减小,直到我们最后走到期望的最小值

然后这里我们介绍一种名为梯度下降 g r a d i e n t    d e s c e n t gradient \; descent gradientdescent 的算法来执行第二部,感性的理解就是在某一点找到 "下山" 的方向,并沿着这个方向走一步,并且重复这个动作知道走到 "山脚下"。

具体来说就是每次计算 J ( w ) J(w) J(w) 对每个 w i w_i wi 的偏导数,然后执行:

f o r    i    r a n g e    f r o m    0    t o    n {                  w j : = w j − α ∂ ∂ w j J ( w ) } \begin{aligned} &for \; i \; range \; from \; 0 \; to \; n \{ \\ &\;\;\;\;\;\;\;\; w_j := w_j - \alpha \frac{\partial}{\partial w_j}J(w) \\ &\} \end{aligned} forirangefrom0ton{wj:=wj−α∂wj∂J(w)}

其中 α \alpha α 是 步长 l e a r n i n g    r a t e learning \; rate learningrate,表示这一步的大小。

这样看着可能不太直观,我们举个二维的例子更清晰的展示一下:

这里,这个红色的曲线就是我们的 J ( w ) J(w) J(w)。我们在起点对 J ( w ) J(w) J(w) 求导,得到了蓝色的切线的斜率,这个斜率显然是个正数,于是我们沿着负方向走一步,这一步的大小取决于 α \alpha α。

相关推荐
cici158742 分钟前
基于正交匹配追踪(OMP)算法的信号稀疏分解MATLAB实现
数据库·算法·matlab
sa1002715 分钟前
基于Python的京东评论爬虫
开发语言·爬虫·python
Jeremy爱编码18 分钟前
leetcode热题组合总和
算法·leetcode·职场和发展
努力学算法的蒟蒻28 分钟前
day57(1.8)——leetcode面试经典150
算法·leetcode·面试
言之。29 分钟前
大模型 API 中的 Token Log Probabilities(logprobs)
人工智能·算法·机器学习
Cigaretter735 分钟前
Day 38 早停策略和模型权重的保存
python·深度学习·机器学习
自然数e39 分钟前
c++多线程【多线程常见使用以及几个多线程数据结构实现】
数据结构·c++·算法·多线程
黛色正浓42 分钟前
leetCode-热题100-普通数组合集(JavaScript)
java·数据结构·算法
元亓亓亓1 小时前
LeetCode热题100--5. 最长回文子串--中等
linux·算法·leetcode
sunywz1 小时前
【JVM】(2)java类加载机制
java·jvm·python