【机器学习】反向传播如何求梯度(公式推导)

写在前面

前期学习深度学习的时候,很多概念都是一笔带过,只是觉得它在一定程度上解释得通就行,但是在强化学习的过程中突然意识到,反向传播求梯度其实并不是一件简单的事情,这篇博客的目的就是要讲清楚反向传播是如何对特定的某一层求梯度,进而更新其参数的

为什么反向传播不容易实现

首先,众所周知,深度网络就是一个广义的多元函数,

但在通常情况下,想要求一个函数的梯度,就必须知道这个函数的具体表达式

但是问题就在于,深度网络的"传递函数"并不容易获得,或者说并不容易显式地获得

进而导致反向传播的过程难以进行

为什么反向传播可以实现

损失函数是关于参数的函数

如果要将一个函数F对一个变量x求偏导,那偏导存在的前提条件就是F是关于x的函数,否则求导结果就是0

  • 符号定义(后续公式均据此展开)
    • x=[x1,x2,...xn]x=[x_1,x_2,\dots x_n]x=[x1,x2,...xn]
    • ypred=[y1,y2,...ym]y_{pred}=[y_1,y_2, \dots y_m]ypred=[y1,y2,...ym]
    • θ=[w1,b1,w2,b2,...wn,bn]\theta=[w_1,b_1,w_2,b_2,\dots w_n,b_n]θ=[w1,b1,w2,b2,...wn,bn]
    • ai=第i层网络激活函数的输出,最后一层的输出就是ypreda^{i}=第 i层网络激活函数的输出,最后一层的输出就是y_{pred}ai=第i层网络激活函数的输出,最后一层的输出就是ypred
    • zi=第i层网络隐藏层的输出z^{i}=第i层网络隐藏层的输出zi=第i层网络隐藏层的输出
    • gi ′(zi)第i层激活函数的导数,在输入=zi处的值g^i\ '(z^i)第i层激活函数的导数,在输入=z^i处的值gi ′(zi)第i层激活函数的导数,在输入=zi处的值
  • 关系式
    • 网络的抽象函数式 ypred=F(x;θ)y_{pred}=F(x;\theta)ypred=F(x;θ)
      即网络就是一个巨大的多元函数,接受两个向量(模型输入和参数)作为输入,经过内部正向传播后输出一个向量
    • 损失函数Loss的抽象函数式 Loss=L(ytrue,ypred)=L(ytrue,F(x;θ))Loss=L(y_{true},y_{pred})=L(y_{true},F(x;\theta))Loss=L(ytrue,ypred)=L(ytrue,F(x;θ))
      其中 ytruey_{true}ytrue 和 xxx 属于参变量,它虽然会变,但是和模型本身没什么关系,唯一属于模型自己的变量就是 θ\thetaθ,所以不难看出,损失函数L是关于模型参数 θ\thetaθ 的函数,损失值Loss完全由模型参数 θ\thetaθ 决定

链式法则

  • 这个法则是深层网络得以实现梯度计算的关键

    核心公式如下:
    ∂L∂θi=∂L∂zi⋅∂zi∂θi \frac{\partial L}{\partial \theta^i}=\frac{\partial L}{\partial z^i}·\frac{\partial z^i}{\partial \theta^i} ∂θi∂L=∂zi∂L⋅∂θi∂zi

    其中,∂L∂zi\frac{\partial L}{\partial z^i}∂zi∂L是损失L对第i层加权输入ziz^izi的梯度,∂zi∂θi\frac{\partial z^i}{\partial \theta^i}∂θi∂zi是第i层加权输入ziz^izi对本层参数θi\theta^iθi的梯度

  • 进一步深究可以发现∂zi∂θi\frac{\partial z^i}{\partial \theta^i}∂θi∂zi相对容易求 ,因为它只涉及到当前层的当前神经元的求解,在面向对象语言中,很容易为每个属于同一个类的实例增加一个方法,比如像这里的输入对参数求导,举例来说;
    if θi=Wi and Zi=Wi∗ai−1+bi,then ∂zi∂θi=(ai−1)T if\ \theta^i=W^i\ and\ Z^i=W^i*a^{i-1}+b^i,\\ then\ \frac{\partial z^i}{\partial \theta^i}=(a^{i-1})^T if θi=Wi and Zi=Wi∗ai−1+bi,then ∂θi∂zi=(ai−1)T

    其中,

    (说实话,我非常想把隐藏层称为"传递函数",控制和机器学习实际上有非常多可以相互借鉴的地方,而且在事实上,二者也确实是不可分割的关系)

  • 然后我们要来处理相对麻烦的 ∂L∂zi\frac{\partial L}{\partial z^i}∂zi∂L

    • 多层感知机为例,共k层,已知网络输出,求网络第i层的梯度
    • 用数学归纳法在这种递归系统中比较合适
      • 归纳奠基
        L=L(ytrue,ypred)=L(ytrue,F(x;θ))L=L(y_{true},y_{pred})=L(y_{true},F(x;\theta))L=L(ytrue,ypred)=L(ytrue,F(x;θ))

        ∂L∂zk=∂L∂ak⋅∂ak∂zk=∂L∂ak⊗(gk)′(zk)\frac{\partial L}{\partial z^k}=\frac{\partial L}{\partial a^k}·\frac{\partial a^k}{\partial z^k}=\frac{\partial L}{\partial a^k}\otimes (g^k)'(z^k)∂zk∂L=∂ak∂L⋅∂zk∂ak=∂ak∂L⊗(gk)′(zk)

        上面的公式说明:损失对隐藏层输出的偏导,等价于损失函数 对最终输出的偏导 ,再逐元素乘 上最后层激活函数 在隐藏层输出处 的导数

        其中,激活函数在创建网络时就明确已知,因此求导取值并没有难度

        由于Loss=L(ytrue,ypred)Loss=L(y_{true},y_{pred})Loss=L(ytrue,ypred)直接与网络最终输出ypredy_{pred}ypred相关,因此损失对最终输出的偏导并不难求;

        比如将损失函数定义为均方差MSE:(其他网络基本同理)
        L=12∑j=1m(yj−ajk)2∂L∂ak=−(yj−ajk) L=\frac{1}{2}\sum^m_{j=1}(y_j-a_j^k)^2\\ \frac{\partial L}{\partial a^k}=-(y_j-a_j^k) L=21j=1∑m(yj−ajk)2∂ak∂L=−(yj−ajk)

      • 归纳递推(从第 i 层到第 i-1 层)

        假设已知 ∂L∂zi\frac{\partial L}{\partial z^i}∂zi∂L(反向传播,因此我们假设的是后一层已知)

        由链式法则可得:
        ∂L∂zi−1=(∂L∂zi)⋅(∂zi∂ai−1)⋅(∂ai−1∂zi−1) \frac{\partial L}{\partial z^{i-1}}=(\frac{\partial L}{\partial z^i})·(\frac{\partial z^i}{\partial a^{i-1}})·(\frac{\partial a^{i-1}}{\partial z^{i-1}}) ∂zi−1∂L=(∂zi∂L)⋅(∂ai−1∂zi)⋅(∂zi−1∂ai−1)

        其中,第一个因子已知

        第二个因子∂zi∂ai−1\frac{\partial z^i}{\partial a^{i-1}}∂ai−1∂zi,分子为第 i 层隐藏层的输出,分母为第 i 层隐藏层的输入(即第 i-1 层激活层的输出),因此其值就是第 i 层隐藏层的权重矩阵WiW^iWi本身

        第三个因子∂ai−1∂zi−1\frac{\partial a^{i-1}}{\partial z^{i-1}}∂zi−1∂ai−1,分子为第 i-1 层激活层的输出,分母为第 i-1 层激活层的输入,因此其值就是 第 i-1 层激活函数 在隐藏层输出处 的导数

        综上:在已知第 i 层损失对输出的梯度的情况下,可以推出第 i-1 层损失对输出的梯度,递推成立

      • 归纳总结

        综上所述,反向传播求梯度完全可行,按照上面的过程撰写程序,就可以很方便地反向逐层 根据损失梯度 更新参数

相关推荐
嘗_3 分钟前
机器学习/深度学习训练day1
人工智能·深度学习·机器学习
shelgi7 分钟前
unsloth微调Qwen3实现知识总结
人工智能·aigc
菜鸡00017 分钟前
存在两个cuda环境,在conda中切换到另一个
linux·人工智能·conda
阿里云大数据AI技术31 分钟前
阿里云 EMR Serverless Spark: 面向 Data+AI 的高性能 Lakehouse 产品
大数据·人工智能·数据分析
新智元40 分钟前
刚刚,H20重返中国!老黄亲自斡旋,还有特供版RTX PRO
人工智能·openai
我爱一条柴ya1 小时前
【AI大模型】BERT微调文本分类任务实战
人工智能·pytorch·python·ai·分类·bert·ai编程
学废了wuwu1 小时前
【终极指南】ChatGPT/BERT/DeepSeek分词全解析:从理论到中文实战
人工智能·chatgpt·bert
杨小扩1 小时前
AI驱动的软件工程(中):文档驱动的编码与执行
大数据·人工智能·软件工程
墨尘游子1 小时前
一文读懂循环神经网络(RNN)—语言模型+n元语法(1)
人工智能·python·rnn·深度学习·神经网络·语言模型
Listennnn2 小时前
Agent自动化与代码智能
人工智能·自动化