在神经网络的训练过程中,误差反向传播法是一种非常重要的算法。它通过计算损失函数对每个参数的梯度,从而更新参数以最小化损失函数。而在这个过程中,链式法则起到了至关重要的作用。本Day将深入探讨神经网络和复合函数的关系、单变量与多变量函数的链式法则。
神经网络和复合函数
-
神经网络本质上是一个复杂的复合函数。每一层神经网络都可以看作是一个函数,整个网络则是由这些函数复合而成的。例如,一个简单的三层神经网络可以表示为:
y = f 3 ( f 2 ( f 1 ( x ) ) ) y = f_3(f_2(f_1(x))) y=f3(f2(f1(x)))- 其中, f 1 f_1 f1、 f 2 f_2 f2 和 f 3 f_3 f3 分别代表神经网络的第一层、第二层和第三层的函数。这种复合函数的结构使得神经网络能够学习并表示复杂的非线性关系。
-
在神经网络中,激活函数是构成复合函数的关键部分。例如,常用的Sigmoid激活函数可以表示为:
σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+e−x1-
当我们将这个激活函数应用到神经网络的某一层时,假设该层的输入是线性组合 W x + b Wx + b Wx+b,那么该层的输出就是:
y = σ ( W x + b ) y = \sigma(Wx + b) y=σ(Wx+b)- 这里, y y y是关于 x x x的一个复合函数,它首先通过线性变换 W x + b Wx + b Wx+b,然后应用Sigmoid函数。
-
单变量函数的链式法则
-
链式法则是微积分中用于计算复合函数导数的基本法则。对于单变量函数,链式法则可以表示为:
d y d x = d y d u ⋅ d u d x \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} dxdy=dudy⋅dxdu- 其中, y = f ( u ) y = f(u) y=f(u) 且 u = g ( x ) u = g(x) u=g(x)。这意味着,要计算 y y y 关于 x x x 的导数,我们需要先计算 y y y 关于 u u u 的导数,然后乘以 u u u 关于 x x x 的导数。
示例
假设有一个复合函数 y = ( x 2 + 1 ) 3 y = (x^2 + 1)^3 y=(x2+1)3,我们可以将其分解为 u = x 2 + 1 u = x^2 + 1 u=x2+1 和 y = u 3 y = u^3 y=u3。根据链式法则,有:
d y d x = d y d u ⋅ d u d x = 3 u 2 ⋅ 2 x = 3 ( x 2 + 1 ) 2 ⋅ 2 x \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} = 3u^2 \cdot 2x = 3(x^2 + 1)^2 \cdot 2x dxdy=dudy⋅dxdu=3u2⋅2x=3(x2+1)2⋅2x
三个函数的复合函数的链式法则。与两个变量的情形一样,可以像分数一样进行计算。
- 当 y 为 u 的函数,u 为 v 的函数,v 为 x 的函数时,有
d y d x = d y d u ⋅ d u d x \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} dxdy=dudy⋅dxdu
多变量函数的链式法则
- 对于多变量函数,链式法则同样适用,但形式会稍微复杂一些。假设 z z z 是 u u u 和 v v v 的函数,而 u u u 和 v v v 又是 x x x 和 y y y 的函数,则 z z z 关于 x x x 的偏导数可以表示为:
∂ z ∂ x = ∂ z ∂ u ⋅ ∂ u ∂ x + ∂ z ∂ v ⋅ ∂ v ∂ x \frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \cdot \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \cdot \frac{\partial v}{\partial x} ∂x∂z=∂u∂z⋅∂x∂u+∂v∂z⋅∂x∂v
示例
考虑一个复合函数 z = sin ( x 2 + y 2 ) z = \sin(x^2 + y^2) z=sin(x2+y2),我们可以将其分解为 u = x 2 u = x^2 u=x2、 v = y 2 v = y^2 v=y2 和 z = sin ( u + v ) z = \sin(u + v) z=sin(u+v)。根据链式法则,有:
∂ z ∂ x = ∂ z ∂ u ⋅ ∂ u ∂ x + ∂ z ∂ v ⋅ ∂ v ∂ x = cos ( u + v ) ⋅ 2 x + 0 = 2 x cos ( x 2 + y 2 ) \frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \cdot \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \cdot \frac{\partial v}{\partial x} = \cos(u + v) \cdot 2x + 0 = 2x\cos(x^2 + y^2) ∂x∂z=∂u∂z⋅∂x∂u+∂v∂z⋅∂x∂v=cos(u+v)⋅2x+0=2xcos(x2+y2)
与上式 一样,下式也成立。
∂ z ∂ y = ∂ z ∂ u ⋅ ∂ u ∂ y + ∂ z ∂ v ⋅ ∂ v ∂ y \frac{\partial z}{\partial y} = \frac{\partial z}{\partial u} \cdot \frac{\partial u}{\partial y} + \frac{\partial z}{\partial v} \cdot \frac{\partial v}{\partial y} ∂y∂z=∂u∂z⋅∂y∂u+∂v∂z⋅∂y∂v
链式法则的应用
下一Day开始深入解释梯度的概念与计算,这里先"预热"一下。
-
在神经网络的训练过程中,误差反向传播法通过链式法则计算损失函数对每一层权重的梯度,从而更新权重以最小化损失函数。具体来说,从输出层开始,逐层计算每个节点的误差对下层节点的梯度,并将这些梯度反向传播回网络各层。
-
在误差反向传播法中,我们需要计算损失函数关于每个参数的梯度。假设我们有一个包含多个层的神经网络,每一层都有一个损失函数 L i L_i Li。那么,总的损失函数可以表示为:
L = L n ( . . . ( L 2 ( L 1 ( x ) ) ) L = L_n(...(L_2(L_1(x))) L=Ln(...(L2(L1(x)))- 其中, L i L_i Li 表示第 i i i 层的损失函数。根据链式法则,我们可以计算出损失函数关于每个参数的梯度:
-
∂ L ∂ θ i = ∂ L ∂ L n ⋅ ∂ L n ∂ θ i + ∂ L ∂ L n − 1 ⋅ ∂ L n − 1 ∂ θ i + . . . + ∂ L ∂ L 1 ⋅ ∂ L 1 ∂ θ i \frac{\partial L}{\partial \theta_i} = \frac{\partial L}{\partial L_n} \cdot \frac{\partial L_n}{\partial \theta_i} + \frac{\partial L}{\partial L_{n-1}} \cdot \frac{\partial L_{n-1}}{\partial \theta_i} + ... + \frac{\partial L}{\partial L_1} \cdot \frac{\partial L_1}{\partial \theta_i} ∂θi∂L=∂Ln∂L⋅∂θi∂Ln+∂Ln−1∂L⋅∂θi∂Ln−1+...+∂L1∂L⋅∂θi∂L1
其中, θ i \theta_i θi 表示第 i i i 层的参数。
- 通过上述方式,可以逐层计算损失函数关于每个参数的梯度,从而实现参数的更新。
示例
假设我们有一个简单的神经网络层,其输出为 a = σ ( w x + b ) a = \sigma(wx + b) a=σ(wx+b),其中 σ \sigma σ 是激活函数, w w w 是权重, x x x 是输入, b b b 是偏置。我们需要计算损失函数 L L L 对权重 w w w 的梯度 ∂ L ∂ w \frac{\partial L}{\partial w} ∂w∂L。根据链式法则,有:
∂ L ∂ w = ∂ L ∂ a ⋅ ∂ a ∂ z ⋅ ∂ z ∂ w \frac{\partial L}{\partial w} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial w} ∂w∂L=∂a∂L⋅∂z∂a⋅∂w∂z
- 其中, z = w x + b z = wx + b z=wx+b 且 a = σ ( z ) a = \sigma(z) a=σ(z)。进一步计算得:
∂ L ∂ w = ∂ L ∂ a ⋅ σ ′ ( z ) ⋅ x = σ ( z ) \frac{\partial L}{\partial w} = \frac{\partial L}{\partial a} \cdot \sigma'(z) \cdot x = \sigma(z) ∂w∂L=∂a∂L⋅σ′(z)⋅x=σ(z)