交叉熵损失函数为代表的两层神经网络的反向传播量化求导计算公式

反向传播(back propagation,BP)算法也称误差逆传播,是神经网络训练的核心算法。我们通常说的 BP 神经网络是指应用反向传播算法进行训练的神经网络模型。反向传播算法的工作机制究竟是怎样的呢?我们以一个两层(即单隐层)网络为例,也就是图 8-5 中的网络结构,给出反向传播的基本推导过程。

假设输入层为 x x x ,有 m m m个训练样本,输入层与隐藏层之间的权重和偏置分别为 w 1 w_1 w1 和 b 1 b_1 b1,线性加权计算结果为: z 1 = w 1 x + b 1 z_1 = w_1 x + b_1 z1=w1x+b1,采用 Sigmoid 激活函数,激活输出为: a 1 = σ ( z 1 ) a_1 = \sigma(z_1) a1=σ(z1)

而隐藏层到输出层的权重和偏置分别为 w 2 w_2 w2 和 b 2 b_2 b2,线性加权计算结果为: z 2 = w 2 x + b 2 z_2 = w_2 x + b_2 z2=w2x+b2,激活输出为: a 2 = σ ( z 2 ) a_2 = \sigma(z_2) a2=σ(z2)。所以,这个两层网络的前向计算过程是为: x → z 1 → a 1 → z 2 → a 2 x → z_1 → a_1 → z_2→a_2 x→z1→a1→z2→a2

直观而言,反向传播就是将前向计算过程反过来,但必须是梯度计算的方向反过来,假设这里采用如下交叉熵损失函数:
L ( y , a ) = − ( y log ⁡ a + ( 1 − y ) log ⁡ ( 1 − a ) ) (8-11) L(y, a) = -(y \log a + (1 - y) \log (1 - a)) \tag{8-11} L(y,a)=−(yloga+(1−y)log(1−a))(8-11)

反向传播是基于梯度下降策略的,主要是从目标参数的负梯度方向更新参数,所以基于损失函数对前向计算过程中各个变量进行梯度计算是关键。将前向计算过程反过来,基于损失函数的梯度计算顺序就是 d a 2 → d z 2 → d w 2 → d b 2 → d a 1 → d z 1 → d w 1 → d b 1 da_2→ dz_2 → dw_2 → db_2→da_1→ dz_1→ dw_1 → db_1 da2→dz2→dw2→db2→da1→dz1→dw1→db1

首先,计算损失函数 L ( y , a 2 ) L(y, a_2) L(y,a2) 关于 a 2 a_2 a2 的导数 d a 2 da_2 da2,影响输出 a 2 a_2 a2 的是谁呢?由前向传播可知, a 2 a_2 a2 是由 z 2 z_2 z2 经激活函数激活后计算而来的,所以计算损失函数关于 z 2 z_2 z2 的导数 d z 2 dz_2 dz2,必须经过 a 2 a_2 a2 进行复合函数求导,即微积分中常说的链式求导法则。然后继续往前推导,影响 z 2 z_2 z2 的又是哪些变量呢?由前向计算可知, z 2 = w 2 x + b 2 z_2 = w_2x + b_2 z2=w2x+b2,影响 z 2 z_2 z2 的有 w 2 w_2 w2, a 1 a_1 a1 和 b 2 b_2 b2,继续按照链式求导法则进行求导即可。最终以交叉熵损失函数为代表的两层神经网络的反向传播量化求导计算公式如下:

∂ L ∂ a 2 = d d a 2 L ( a 2 , y ) = ( − y log ⁡ a 2 − ( 1 − y ) log ⁡ ( 1 − a 2 ) ) ′ = − y a 2 + 1 − y 1 − a 2 (8-12) \frac{\partial L}{\partial a_2} = \frac{d}{da_2}L(a_2, y) = (-y\log a_2 - (1-y)\log(1-a_2))' = -\frac{y}{a_2}+ \frac{1-y}{1-a_2} \tag{8-12} ∂a2∂L=da2dL(a2,y)=(−yloga2−(1−y)log(1−a2))′=−a2y+1−a21−y(8-12)

∂ L ∂ Z 2 = ∂ L ∂ a 2 ∂ a 2 ∂ Z 2 = a 2 − y (8-13) \frac{\partial L}{\partial Z_2} = \frac{\partial L}{\partial a_2}\frac{\partial a_2}{\partial Z_2} = a_2 - y \tag{8-13} ∂Z2∂L=∂a2∂L∂Z2∂a2=a2−y(8-13)

∂ L ∂ w 2 = ∂ L ∂ a 2 ∂ a 2 ∂ Z 2 ∂ Z 2 ∂ w 2 = 1 m ∂ L ∂ Z 2 a 1 = 1 m ( a 2 − y ) a 1 (8-14) \frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial a_2}\frac{\partial a_2}{\partial Z_2}\frac{\partial Z_2}{\partial w_2} = \frac{1}{m}\frac{\partial L}{\partial Z_2} a_1= \frac{1}{m}(a_2 - y)a_1 \tag{8-14} ∂w2∂L=∂a2∂L∂Z2∂a2∂w2∂Z2=m1∂Z2∂La1=m1(a2−y)a1(8-14)

∂ L ∂ b 2 = ∂ L ∂ a 2 ∂ a 2 ∂ Z 2 ∂ Z 2 ∂ b 2 = ∂ L ∂ Z 2 = a 2 − y (8-15) \frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial a_2}\frac{\partial a_2}{\partial Z_2}\frac{\partial Z_2}{\partial b_2} = \frac{\partial L}{\partial Z_2} = a_2 - y \tag{8-15} ∂b2∂L=∂a2∂L∂Z2∂a2∂b2∂Z2=∂Z2∂L=a2−y(8-15)

∂ L ∂ a 1 = ∂ L ∂ a 2 ∂ a 2 ∂ Z 2 ∂ Z 2 ∂ a 1 = ( a 2 − y ) w 2 (8-16) \frac{\partial L}{\partial a_1} = \frac{\partial L}{\partial a_2}\frac{\partial a_2}{\partial Z_2}\frac{\partial Z_2}{\partial a_1} = (a_2 - y)w_2 \tag{8-16} ∂a1∂L=∂a2∂L∂Z2∂a2∂a1∂Z2=(a2−y)w2(8-16)

∂ L ∂ Z 1 = ∂ L ∂ a 2 ∂ a 2 ∂ Z 2 ∂ Z 2 ∂ a 1 ∂ a 1 ∂ Z 1 = ( a 2 − y ) w 2 σ ′ ( Z 1 ) (8-17) \frac{\partial L}{\partial Z_1} = \frac{\partial L}{\partial a_2}\frac{\partial a_2}{\partial Z_2}\frac{\partial Z_2}{\partial a_1} \frac{\partial a_1}{\partial Z_1} = (a_2 - y)w_2\sigma'(Z_1) \tag{8-17} ∂Z1∂L=∂a2∂L∂Z2∂a2∂a1∂Z2∂Z1∂a1=(a2−y)w2σ′(Z1)(8-17)

∂ L ∂ w 1 = ∂ L ∂ a 2 ∂ a 2 ∂ Z 2 ∂ Z 2 ∂ a 1 ∂ a 1 ∂ Z 1 ∂ Z 1 ∂ w 1 = ( a 2 − y ) w 2 σ ′ ( Z 1 ) x (8-18) \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial a_2}\frac{\partial a_2}{\partial Z_2}\frac{\partial Z_2}{\partial a_1} \frac{\partial a_1}{\partial Z_1}\frac{\partial Z_1}{\partial w_1} = (a_2 - y)w_2\sigma'(Z_1)x \tag{8-18} ∂w1∂L=∂a2∂L∂Z2∂a2∂a1∂Z2∂Z1∂a1∂w1∂Z1=(a2−y)w2σ′(Z1)x(8-18)

∂ L ∂ b 1 = ∂ L ∂ a 2 ∂ a 2 ∂ Z 2 ∂ Z 2 ∂ a 1 ∂ a 1 ∂ Z 1 ∂ Z 1 ∂ b 1 = ( a 2 − y ) w 2 σ ′ ( Z 1 ) (8-19) \frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial a_2}\frac{\partial a_2}{\partial Z_2}\frac{\partial Z_2}{\partial a_1} \frac{\partial a_1}{\partial Z_1}\frac{\partial Z_1}{\partial b_1} = (a_2 - y)w_2\sigma'(Z_1) \tag{8-19} ∂b1∂L=∂a2∂L∂Z2∂a2∂a1∂Z2∂Z1∂a1∂b1∂Z1=(a2−y)w2σ′(Z1)(8-19)


以上公式具体的推导过程:

公式8-13:损失函数对输出层激活值Z2的导数

公式8-14:损失函数对输出层权重w2的梯度

公式8-15:损失函数对输出层偏置b2的梯度

公式8-16:损失函数对隐藏层激活值a1的梯度

公式8-17:损失函数对隐藏层加权输入Z1的导数

公式8-18:损失函数对隐藏层权重w1的梯度

公式8-19:损失函数对隐藏层偏置b1的梯度

相关推荐
武昌库里写JAVA6 分钟前
一文读懂Redis6的--bigkeys选项源码以及redis-bigkey-online项目介绍
c语言·开发语言·数据结构·算法·二维数组
禊月初三14 分钟前
LeetCode 4.寻找两个中序数组的中位数
c++·算法·leetcode
自不量力的A同学19 分钟前
微软发布「AI Shell」
人工智能·microsoft
学习使我飞升33 分钟前
spf算法、三类LSA、区间防环路机制/规则、虚连接
服务器·网络·算法·智能路由器
一点一木37 分钟前
AI与数据集:从零基础到全面应用的深度解析(超详细教程)
人工智能·python·tensorflow
花生糖@1 小时前
OpenCV图像基础处理:通道分离与灰度转换
人工智能·python·opencv·计算机视觉
2zcode1 小时前
基于YOLOv8深度学习的智慧农业棉花采摘状态检测与语音提醒系统(PyQt5界面+数据集+训练代码)
人工智能·深度学习·yolo
庞传奇1 小时前
【LC】560. 和为 K 的子数组
java·算法·leetcode
SoraLuna1 小时前
「Mac玩转仓颉内测版32」基础篇12 - Cangjie中的变量操作与类型管理
开发语言·算法·macos·cangjie
daiyang123...2 小时前
Java 复习 【知识改变命运】第九章
java·开发语言·算法