手撕深度学习:矩阵求导链式法则与矩阵乘法反向传播公式,深度学习进阶必备!

本文首发于本人的微信公众号,原文链接:https://mp.weixin.qq.com/s/TaWYEORpI06ycofXWSZ-Xg

摘要

本文总结了矩阵导数的本质定义,对矩阵求导链式法则公式进行了讨论,并使用矩阵求导链式法则公式推导了矩阵乘法的反向传播公式。

写在前面

最近在做CMU 10-714(Deep Learning Systems)的Homework 1时,遇到了一个困扰了我很久的问题:

矩阵乘法是如何进行反向传播的?

于是借此机会,我重新学了一遍矩阵微分。在花了一个下午推导出矩阵乘法的反向传播公式后,我决定用一篇文章来总结和记录一下这次的学习收获。

矩阵的导数是什么?

矩阵求导是深度学习里的一个基础操作,在通过反向传播求参数梯度时,就需要对一个矩阵函数进行求导,那么矩阵的导数是什么呢?

一句话概括就是:矩阵函数\(f(A) \to B\)可以被视为是有\(m\)个输入变量,\(n\)个输出变量的函数,其导数就是这\(n\)个输出变量分别针对这\(m\)个输入变量求导,最后经过排列得到的矩阵(或高维的张量)

也就是说,对这个矩阵函数求导会得到\(m \times n\)个导数,将这些导数按一定规则排列就得到了这个矩阵函数的导数。

例1:对于函数\(f(x) = x^T x = x_1^2 + x_2^2 + \cdots + x_n^2\),其中\(\mathbf{x} = [x_1, x_2, \ldots, x_n]^T\),这是一个输入变量为\(n\)个,输出变量为\(1\)个的函数。

那么对这个函数求导就会得到\(n\)个导数,即:对\(x_1\)求导得到\(2x_1\),对\(x_2\)求导得到\(2x_2\),以此类推。如果把它们按照和\(\mathbf{x}\)一致的排列方式排列起来,即排列为列向量,那就得到了这个向量函数导数的分母布局

(分子布局和分母布局这里不展开赘述了,一般我们在深度学习里使用的都是分母布局,因为分母布局可以直接求得梯度;两种布局只是导数的排列方式不同,其余并无区别)。

例2:对于函数\(f(x) = [x_1 + x_2, x_2(x_1 + x_3)]^T\),其中\(\mathbf{x} = [x_1, x_2, x_3]^T\),这是一个输入变量为\(3\)个,输出变量为\(2\)个的函数,那么求导就会得到\(6\)个导数,按照分母布局可以组成如下矩阵

\[\begin{bmatrix} 1 & x_2 \\ 1 & x_1 \\ 0 & x_2 \end{bmatrix} \]

例3:对于函数\(f(A, B) = A B\),其中\(A\)为\(m \times n\)的矩阵,\(B\)为\(n \times p\)的矩阵,这个函数同样可以视为输入\(m \times n\)个变量,输出\(n \times p\)个变量的函数。

那么导数一共有\(m \times n \times n \times p\)个。排列这些导数需要用到四维张量。

高维张量这里就不详细展开了,因为通常情况下我们都会用一些技巧来绕过对于高维张量的处理。

比如在对损失函数进行反向传播的时候,损失函数值一般都是标量,所以损失函数对矩阵求导的结果也一定是一个二维矩阵,这一事实表明我们可以通过套反向传播的公式来绕过对四维张量的显式构造。

反向传播

反向传播是目前深度学习框架自动求梯度的主流算法,其背后的数学依据是求导的链式法则

即\(\frac{dg(f(x))}{dx} = \frac{dg(f(x))}{df} \cdot \frac{df(x)}{dx}\),在构造出计算图之后就可以通过递归的方式求出导数,这部分可以见CMU 10-714的Lecture 4,这里不再赘述。

矩阵函数的链式法则

矩阵函数的链式法则和标量函数的类似,这里参考了《矩阵分析与应用(第二版)》(张贤达 著),书中148页给出了两种情形下的矩阵函数的链式法则公式:

当函数输出为标量时

公式如下:

\[\frac{\partial g(f(\boldsymbol{X}))}{\partial\boldsymbol{X}} = \frac{\mathrm{d} g(y)}{\mathrm{d} y} \frac{\partial f(\boldsymbol{X})}{\partial\boldsymbol{X}} \]

其中\(\boldsymbol{X}\)为\(m \times n\)矩阵,\(f(\boldsymbol{X})\)为标量,\(g(f(\boldsymbol{X}))\)为标量。

公式后面的\(\frac{\partial f(\boldsymbol{X})}{\partial\boldsymbol{X}}\)结果是一个二维矩阵,求导方式和上文提到的例2相同,这里不再赘述。

当函数输出为矩阵时

公式如下:

\[\left[\frac{\partial g(\boldsymbol{F})}{\partial\boldsymbol{X}}\right]{i j} = \frac{\partial g(\boldsymbol{F})}{\partial x{ij}} = \sum_{k=1}^{p} \sum_{l=1}^{q} \frac{\partial g(\boldsymbol{F})}{\partial f_{kl}} \frac{\partial f_{kl}}{\partial x_{ij}} \]

其中\(\boldsymbol{X}\)为\(m \times n\)矩阵,元素为\(x_{ij}\),\(\boldsymbol{F}(\boldsymbol{X})\)为\(p \times q\)矩阵,元素为\(f_{kl}\),\(g(\boldsymbol{F}(\boldsymbol{X}))\)为标量。

这个公式告诉我们,\(g\)关于\(\boldsymbol{X}\)求导后会得到一个二维矩阵,这个矩阵第\(i\)行\(j\)列的元素计算方式为公式最右边的那个计算式,代入计算式即可求出。

对公式来源的猜测

这个公式的来源是《The Matrix Cookbook》的第2.8.1节,但是这本书中也只是给出了该公式,没有额外的解释。

在和gpt-5探讨了一番后,有了如下的猜测:
\(\boldsymbol{X}\)为\(m \times n\)矩阵,\(\boldsymbol{F}(\boldsymbol{X})\)为\(p \times q\)矩阵,\(g(\boldsymbol{F}(\boldsymbol{X}))\)为标量,那么\(g\)对\(\boldsymbol{F}\)求导得到一个二维矩阵,\(\boldsymbol{F}\)对\(\boldsymbol{X}\)求导得到一个四维张量。

然后通过tensor contraction操作就得到了二维矩阵,这其中似乎还涉及到了Fréchet derivative。

这里由于知识水平不足,无法继续深究了。如果有数学专业的大佬对这一问题有了解,还恳请大佬指个路,也欢迎大家在评论区交流。

矩阵乘法的反向传播公式

在这一节里,我们将使用刚才学到的公式完成矩阵乘法反向传播公式的推导。

问题定义

有损失函数\(l(F(A,B))\),其中\(A\)为\(m \times n\)矩阵,\(B\)为\(n \times p\)矩阵,\(F(A, B) = A B\),且\(l\)的输出为标量,已知\(\partial I / \partial F\),要求\(\nabla = \partial l / \partial A\)和\(\partial I / \partial B\)

求解过程

这里再贴一次求解要用到的公式:

\[\left[\frac{\partial g(\boldsymbol{F})}{\partial\boldsymbol{X}}\right]{i j} = \frac{\partial g(\boldsymbol{F})}{\partial x{ij}} = \sum_{k=1}^{p} \sum_{l=1}^{q} \frac{\partial g(\boldsymbol{F})}{\partial f_{kl}} \frac{\partial f_{kl}}{\partial x_{ij}} \]

记已知的\(\partial I / \partial F = G\),\(G\)就是反向传播从后面的节点传播来的梯度。

那么公式里的\(\frac{\partial g(\boldsymbol{F})}{\partial f_{kl}}\)就是\(G_{k,l}\)

根据矩阵乘法的定义,可以得到\(f_{kl} = \sum_{q=1}^{n} a_{kq} b_{ql}\)

由于只有当\(k=i\)时,\(a_{kq}\)这一项才可能取到\(a_{ij}\),并且只有当\(q=j\)时,\(a_{kq}\)才是\(a_{ij}\),所以可以得到如下结果:

\[\frac{\partial f_{kl}}{\partial a_{ij}} = \begin{cases} 0, & k \neq i \\ b_{jl}, & k = i \end{cases} \]

由此可以得到,\(\nabla_{ij} = \sum_{k=1}^{m} \sum_{l=1}^{p} G_{kl} \begin{cases} 0, & k \neq i \\ b_{jl}, & k = i \end{cases} = \sum_{l=1}^{P} G_{il} \cdot b_{jl}\)

注意到,等式最右边的值恰好等于\(G \cdot B^{T}\)

所以最终得出结论:\(\nabla = G \cdot B^{T}\)

同理可得,\(\partial l / \partial B = A^{T} \cdot G\),具体推导过程和上面类似,就留作习题供读者自行练习了。

一些其他小知识点:

  • 提到\(n\)维向量,一般默认是列向量,即大小为\(n \times 1\)
  • 前向传播同样也可以进行自动求导,但是如果要对\(n\)个输入变量求梯度,就需要跑\(n\)次前向传播。相比之下,反向传播通常只需要跑一次就能获取到所有输入变量的梯度(因为通常最后的损失函数都只有一个),这是反向传播的优势之一。