机器学习--线性回归

线性回归

引入

  我们在高中的时候都学过线性回归,在这我们回顾一下

  在高中的课程中,我们会被给得到一组 x m {x_m} xm 和一组 y m {y_m} ym,然后我们想用 h ( x ) = y ^ = a x + b h(x) = \hat{y} = ax + b h(x)=y^=ax+b 来拟合这组数据使得整体上 h ( x i ) ≈ y i h(x_i) \approx y_i h(xi)≈yi 也就是这样:

  而在高中教材中 a a a 和 b b b 都给出的准确的计算式(其中 x ˉ \bar{x} xˉ 和 y ˉ \bar{y} yˉ 是 x m {x_m} xm 和 y m {y_m} ym 的平均数):

b ^ = ∑ i = 1 m ( x i − x ˉ ) ( y i − y ˉ ) ∑ i = 1 m ( x i − x ˉ ) 2 = ∑ i = 1 m x i y i − m x ˉ y ˉ ∑ i = 1 m x i 2 − m x ˉ 2 a = y ˉ − b ^ x ˉ \begin{aligned} \hat{b} = & \frac{\sum\limits_{i = 1}^m(x_i - \bar{x})(y_i - \bar{y})}{\sum\limits_{i = 1}^m(x_i-\bar{x})^2} = \frac{\sum\limits_{i = 1}^mx_iy_i - m\bar{x}\bar{y}}{\sum\limits_{i = 1}^mx_i^2 - m\bar{x}^2} \\ a = & \bar{y} - \hat{b}\bar{x} \end{aligned} b^=a=i=1∑m(xi−xˉ)2i=1∑m(xi−xˉ)(yi−yˉ)=i=1∑mxi2−mxˉ2i=1∑mxiyi−mxˉyˉyˉ−b^xˉ

  而在机器学习中,我们希望计算机通过某种算法学习出一组 a , b a, b a,b,使得这条直线满足上述要求。

进入正题

  在我们要解决的线性回归问题中,我们不仅不能使用上述已知的表达式解,我们还可以将维度升高。比如我们的数组 x m {x_m} xm 中的每个 x ( i ) x^{(i)} x(i) 不再是一个标量,而是一个 n n n 维的向量,同理我们的 w ( i ) w^{(i)} w(i) 也是一个 n n n 维的向量:

x ( i ) = x 1 ( i ) x 2 ( i ) ⋮ x n ( i )              w = w 1 ( i ) w 2 ⋮ w n x^{(i)} = \begin{bmatrix} x^{(i)}_1 \\ x^{(i)}_2 \\ \vdots \\ x^{(i)}_n \end{bmatrix} \;\;\;\;\;\; w = \begin{bmatrix} w^{(i)}_1 \\ w_2 \\\vdots \\ w_n \end{bmatrix} x(i)= x1(i)x2(i)⋮xn(i) w= w1(i)w2⋮wn

  然后我们就要用我们构造出来的 h ( x ( i ) ) h(x^{(i)}) h(x(i)) 来拟合 y ( i ) y^{(i)} y(i):

y ^ = h ( x ( i ) ) = ∑ j = 1 n w j x j ( i ) + b = w T x ( i ) + b \hat{y} = h(x^{(i)}) = \sum_{j = 1}^n w_jx^{(i)}_j + b = w^Tx^{(i)} + b y^=h(x(i))=j=1∑nwjxj(i)+b=wTx(i)+b

  在这里,我们定义 J ( w , b ) = 1 2 m ∑ i = 1 m ( h ( x ( i ) ) − y ( i ) ) 2 J(w, b) = \frac 1{2m}\sum\limits_{i = 1}^m(h(x^{(i)}) - y^{(i)})^2 J(w,b)=2m1i=1∑m(h(x(i))−y(i))2 叫做损失函数 c o s t    f u n c t i o n cost\;function costfunction,这个损失函数描述的是预测值 y ^ \hat{y} y^ 与真实值 y y y 的差异程度,上述表达式很好的符合这个定义。

  并且显然的,我们希望这个 J ( w , b ) J(w, b) J(w,b) 越小越好。当 J ( w , b ) J(w, b) J(w,b) 取到最小值时,我们认为 h ( x ) h(x) h(x) 对 y y y 的拟合达到最完美。

梯度下降---最小化 J ( w , b ) J(w, b) J(w,b)

  我们觉得写成 h ( x ) = w T x + b h(x) = w^Tx + b h(x)=wTx+b 太丑了,所以我们考虑令 x 0 ( i ) = 1 x^{(i)}_0 = 1 x0(i)=1,并新增 w 0 w_0 w0 使得每个 x ( i ) x^{(i)} x(i) 和 w w w 都变成一个 n + 1 n + 1 n+1 维的向量:

x ( i ) = x 0 ( i ) x 1 ( i ) x 2 ( i ) ⋮ x n ( i )              w = w 0 w 1 ( i ) w 2 ⋮ w n x^{(i)} = \begin{bmatrix} x^{(i)}_0 \\ x^{(i)}_1 \\ x^{(i)}_2 \\ \vdots \\ x^{(i)}_n \end{bmatrix} \;\;\;\;\;\; w = \begin{bmatrix} w_0 \\ w^{(i)}_1 \\ w_2 \\\vdots \\ w_n \end{bmatrix} x(i)= x0(i)x1(i)x2(i)⋮xn(i) w= w0w1(i)w2⋮wn

  于是 h ( x ) = w T x h(x) = w^Tx h(x)=wTx 这样就好看多了qwq,并且此时 J ( w ) = 1 2 m ∑ i = 1 m ( h ( x ( i ) ) − y ( i ) ) 2 J(w) = \frac 1{2m}\sum\limits_{i = 1}^m(h(x^{(i)}) - y^{(i)})^2 J(w)=2m1i=1∑m(h(x(i))−y(i))2

  我们想做的就是一下描述的事情:

  1. 开始时对 w w w 随机赋值
  2. 持续改变 w w w 的值,使得 J ( w ) J(w) J(w) 的值减小,直到我们最后走到期望的最小值

  然后这里我们介绍一种名为梯度下降 g r a d i e n t    d e s c e n t gradient \; descent gradientdescent 的算法来执行第二部,感性的理解就是在某一点找到 "下山" 的方向,并沿着这个方向走一步,并且重复这个动作知道走到 "山脚下"。

  具体来说就是每次计算 J ( w ) J(w) J(w) 对每个 w i w_i wi 的偏导数,然后执行:

f o r    i    r a n g e    f r o m    0    t o    n {                  w j : = w j − α ∂ ∂ w j J ( w ) } \begin{aligned} &for \; i \; range \; from \; 0 \; to \; n \{ \\ &\;\;\;\;\;\;\;\; w_j := w_j - \alpha \frac{\partial}{\partial w_j}J(w) \\ &\} \end{aligned} forirangefrom0ton{wj:=wj−α∂wj∂J(w)}

  其中 α \alpha α 是 步长 l e a r n i n g    r a t e learning \; rate learningrate,表示这一步的大小。

  这样看着可能不太直观,我们举个二维的例子更清晰的展示一下:

  这里,这个红色的曲线就是我们的 J ( w ) J(w) J(w)。我们在起点对 J ( w ) J(w) J(w) 求导,得到了蓝色的切线的斜率,这个斜率显然是个正数,于是我们沿着负方向走一步,这一步的大小取决于 α \alpha α。

相关推荐
地平线开发者4 小时前
J6B vio scenario sample
算法
阿里云大数据AI技术9 小时前
光轮智能 × 阿里云:共建 Physical AI 云上数据、评测与持续学习基础设施
人工智能·机器学习
SelectDB10 小时前
Apache Doris Python UDF:让 SQL 直接调用 Python 生态,支撑 Agent 时代复杂业务逻辑
大数据·数据库·python
BothSavage16 小时前
Trae远程开发中DeepSeek自定义模型4054错误的排查与修复
算法
小林ixn16 小时前
从暴力到KMP:一道题彻底搞懂字符串匹配的前世今生
算法
烬羽18 小时前
字符串算法入门:从反转字符串到回文判断,面试不再慌
算法·面试
荣码18 小时前
GraphRAG:普通RAG只能回答"点"的问题,我踩了4个坑才搞懂
java·python
金銀銅鐵1 天前
[Python] 基于欧几里得算法,实现分数约分计算器
python·数学
Lyn_Li1 天前
Kaggle Top 5 | 198只股票、200条数据的金融预测——BattleFin高分方案从零复现
python·kaggle·比赛复盘·金融预测