什么是梯度
-
要明白什么是梯度下降法,我们需要先知道梯度是什么。
-
梯度的定义
梯度是一个向量 ,表示多元函数在某一点处变化最快的方向 。对于多元函数 f(x1,x2,...,xn) ,其梯度记作 ∇f 或 grad f,定义为所有偏导数组成的向量:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ f = ( ∂ f ∂ x 1 , ∂ f ∂ x 2 , ⋅ ⋅ ⋅ , ∂ f ∂ x n ) ∇f=(\frac{∂f}{∂x_1},\frac{∂f}{∂x_2},···,\frac{∂f}{∂x_n}) </math>∇f=(∂x1∂f,∂x2∂f,⋅⋅⋅,∂xn∂f)- 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ f ∂ x i \frac{∂f}{∂x_i} </math>∂xi∂f是函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f对 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi的偏导数
-
梯度的几何意义
-
方向 :梯度指向函数在该点处局部上升最快的方向。
-
大小:梯度的大小表示函数在该方向的变化率。
-
-
举例说明
多元函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x , y , z ) = x 2 + y 3 + 2 z f(x,y,z)=x^2+y^3+2z </math>f(x,y,z)=x2+y3+2z
- 该函数的梯度为∇ <math xmlns="http://www.w3.org/1998/Math/MathML"> f f </math>f=( <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 x , 3 y , 2 2x, 3y, 2 </math>2x,3y,2)
- 在点(1,1,1)处,梯度为(2,3,2),指向函数上升最快的方向(右上方向)。
- 梯度下降法会沿(-2,-3,-2)方向更新参数,逐步逼近最小值点。
梯度下降法
-
定义:
**梯度下降(Gradient Descent)**是机器学习和优化领域中用于最小化目标函数(通常是损失函数)的核心迭代算法。其核心思想是通过计算函数的梯度来找到函数值下降最快的方向,并沿该方向调整参数以逐步逼近最小值。目的就是为了寻找函数的最小值点。
-
为什么需要使用梯度下降
训练神经网络本质上是一个优化问题 :我们需要找到一组参数(权重和偏置),使得神经网络的输出与真实值之间的误差(即损失函数)最小化。由于神经网络可能有数百万甚至数十亿个参数,损失函数是这些参数的复杂非线性函数。解析解(如直接求导数为零的方程)在数学上不可行。而梯度下降是通过通过迭代局部优化,逐步逼近最优解,避免直接求解不可行的全局优化问题,从而使神经网络的误差最小。
-
数学实现
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t + 1 = x t − η ∇ f ( x t ) x_{t+1}=x_t-η∇f(x_t) </math>xt+1=xt−η∇f(xt)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> x t + 1 x_{t+1} </math>xt+1是本次梯度下降后的位置。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt是本此梯度下降前的位置。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η η </math>η是学习率(步长),控制更新幅度。
- ∇ <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x t ) f(x_t) </math>f(xt)是当前参数处的梯度。
-
终止条件
- 梯度接近0(达到局部最小值)。
- 迭代次数到达预设上限。
- 函数值变化小于阈值。
-
具体步骤
**1、初始化参数:**随机初始化模型的参数(例如权重和偏置)。
2、计算梯度损失:使用当前参数计算损失函数关于这些参数的梯度。
3、更新参数:将每个参数沿着梯度的方向反方向移动一步,步长由学习率控制。
4、重复迭代:重复计算梯度和更新参数,直到满足终止条件。
梯度下降法的实例演示
-
举例说明
-
一维函数梯度下降
如我们举例一个一维函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = x 2 f(x)=x^2 </math>f(x)=x2该函数的梯度也就是它的导数即 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ′ ( x ) = 2 x f'(x)=2x </math>f′(x)=2x。
-
首先,我们需要确定起始点为 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0=1,学习率选择为 <math xmlns="http://www.w3.org/1998/Math/MathML"> η η </math>η=0.2
-
那么整个迭代过程如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 迭代 1 : x 1 = x 0 − 0.2 ∗ f ′ ( x 0 ) = 1 − 0.2 ∗ 2 ∗ 1 = 0.6 迭代 2 : x 2 = x 1 − 0.2 ∗ f ′ ( x 1 ) = 0.6 − 0.2 ∗ 2 ∗ 0.6 = 0.4 迭代 3 : x 3 = x 2 − 0.2 ∗ f ′ ( x 2 ) = 0.4 − 0.2 ∗ 2 ∗ 0.4 = 0.2 ⋅ ⋅ ⋅ 迭代 17 : x 17 = x 16 − 0.2 ∗ f ′ ( x 16 ) = 0.0003 − 0.2 ∗ 2 ∗ 0.0003 = 0.0002 迭代 18 : x 18 = x 17 − 0.2 ∗ f ′ ( x 17 ) = 0.0002 − 0.2 ∗ 2 ∗ 0.0002 = 0.0001 迭代 1: x1 = x0 - 0.2 * f'(x0) = 1 - 0.2 * 2 * 1 = 0.6\\ 迭代 2: x2 = x1 - 0.2 * f'(x1) = 0.6 - 0.2 * 2 * 0.6 = 0.4\\ 迭代 3: x3 = x2 - 0.2 * f'(x2) = 0.4 - 0.2 * 2 * 0.4 = 0.2\\ ···\\ 迭代 17: x17 = x16 - 0.2 * f'(x16) = 0.0003 - 0.2 * 2 * 0.0003 = 0.0002\\ 迭代 18: x18 = x17 - 0.2 * f'(x17) = 0.0002 - 0.2 * 2 * 0.0002 = 0.0001 </math>迭代1:x1=x0−0.2∗f′(x0)=1−0.2∗2∗1=0.6迭代2:x2=x1−0.2∗f′(x1)=0.6−0.2∗2∗0.6=0.4迭代3:x3=x2−0.2∗f′(x2)=0.4−0.2∗2∗0.4=0.2⋅⋅⋅迭代17:x17=x16−0.2∗f′(x16)=0.0003−0.2∗2∗0.0003=0.0002迭代18:x18=x17−0.2∗f′(x17)=0.0002−0.2∗2∗0.0002=0.0001
-
可见在数次迭代后,已经接近函数的最小值了。
-
代码实现以上过程:
pythonimport numpy as np import matplotlib.pyplot as plt # 定义函数和它的导数 def f(x): return x ** 2 def df(x): return 2 * x # 梯度下降参数 x = 1.0 # 初始点 lr = 0.2 # 学习率 max_iter = 20 # 最大迭代次数 # 存储每次迭代的x和f(x)值用于绘图 x_history = [] f_history = [] # 梯度下降过程 for i in range(max_iter): x_history.append(x) f_history.append(f(x)) # 计算梯度并更新x grad = df(x) x_ = x x = x - lr * grad print(f"迭代 {i + 1}: x{i + 1} = x{i} - {lr:.1f} * f'(x{i}) = {x_:.1g} - {lr:.1g} * 2 * {x_:.1g} = {x:.1g}") # 绘制函数曲线和梯度下降过程 x_vals = np.linspace(-1.5, 1.5, 100) plt.plot(x_vals, f(x_vals), label='f(x) = x²') plt.scatter(x_history, f_history, c='r', label='Gradient Descent Steps') plt.plot(x_history, f_history, 'r--') # 添加箭头显示下降方向 for i in range(len(x_history) - 1): plt.annotate('', xy=(x_history[i + 1], f_history[i + 1]), xytext=(x_history[i], f_history[i]), arrowprops=dict(arrowstyle='->', color='green')) plt.xlabel('x') plt.ylabel('f(x)') plt.title('Gradient Descent for f(x) = x²') plt.legend() plt.grid(True) plt.show()
-
图示
-
-
多维函数梯度下降
如我们举例一个二维函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x , y ) = x 2 + y 2 f(x,y)=x^2+y^2 </math>f(x,y)=x2+y2,该函数的梯度为∇ <math xmlns="http://www.w3.org/1998/Math/MathML"> f = ( 2 x , 2 y ) f=(2x,2y) </math>f=(2x,2y)
-
首先,我们确定起始点为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x 0 , y 0 ) = ( 1 , 3 ) (x_0,y_0)=(1,3) </math>(x0,y0)=(1,3),学习率为 <math xmlns="http://www.w3.org/1998/Math/MathML"> η η </math>η=0.1
-
那么整个迭代过程如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 迭代 1 : ( x 1 , y 1 ) = ( x 0 , y 0 ) − 0.1 ∗ ∇ f = ( 1.0 , 3.0 ) − 0.1 ∗ ( 2.0 , 6.0 ) = ( 0.8 , 2.4 ) 迭代 2 : ( x 2 , y 2 ) = ( x 1 , y 1 ) − 0.1 ∗ ∇ f = ( 0.8 , 2.4 ) − 0.1 ∗ ( 1.6 , 4.8 ) = ( 0.64 , 1.92 ) 迭代 3 : ( x 3 , y 3 ) = ( x 2 , y 2 ) − 0.1 ∗ ∇ f = ( 0.64 , 1.92 ) − 0.1 ∗ ( 1.28 , 3.84 ) = ( 0.512 , 1.536 ) ⋅ ⋅ ⋅ 迭代 10 : ( x 10 , y 10 ) = ( x 9 , y 9 ) − 0.1 ∗ ∇ f = ( 0.134217728 , 0.402653184 ) − 0.1 ∗ ( 0.268435456 , 0.805306368 ) = ( 0.1073741824 , 0.3221225472 ) 迭代 1: (x1, y1) = (x0, y0) - 0.1 * ∇f = (1.0, 3.0) - 0.1 * (2.0, 6.0) = (0.8, 2.4)\\ 迭代 2: (x2, y2) = (x1, y1) - 0.1 * ∇f = (0.8, 2.4) - 0.1 * (1.6, 4.8) = (0.64, 1.92)\\ 迭代 3: (x3, y3) = (x2, y2) - 0.1 * ∇f = (0.64, 1.92) - 0.1 * (1.28, 3.84) = (0.512, 1.536)\\ ···\\ 迭代 10: (x10, y10) = (x9, y9) - 0.1 * ∇f = (0.134217728, 0.402653184) - 0.1 * (0.268435456, 0.805306368) = (0.1073741824, 0.3221225472) </math>迭代1:(x1,y1)=(x0,y0)−0.1∗∇f=(1.0,3.0)−0.1∗(2.0,6.0)=(0.8,2.4)迭代2:(x2,y2)=(x1,y1)−0.1∗∇f=(0.8,2.4)−0.1∗(1.6,4.8)=(0.64,1.92)迭代3:(x3,y3)=(x2,y2)−0.1∗∇f=(0.64,1.92)−0.1∗(1.28,3.84)=(0.512,1.536)⋅⋅⋅迭代10:(x10,y10)=(x9,y9)−0.1∗∇f=(0.134217728,0.402653184)−0.1∗(0.268435456,0.805306368)=(0.1073741824,0.3221225472)
-
经过多次迭代,数据已经十分接近最小值(0,0)了。
-
图示
-
-