张量链式法则(下篇):揭秘Transpose、Summation等复杂算子反向传播,彻底掌握深度学习求导精髓!

本文首发于本人的微信公众号,链接:https://mp.weixin.qq.com/s/eEDo6WF0oJtRvLYTeUYxYg

摘要

本文紧接系列的上篇,介绍了 transpose,summation,broadcast_to 等更为复杂的深度学习算子的反向传播公式推导。

写在前面

本系列文章的上篇介绍了张量函数链式法则公式,并以几个简单的算子为例子展示了公式的使用方法。本文将继续以更复杂的算子为例子演示公式的使用方法,求解这些算子的反向传播公式也是我研究张量函数链式法则的目的:因为对于 transpose,broadcast_to 这类会根据传入的参数改变输出张量维度数量的算子,常规的矩阵链式法则公式已无法解决。

常见算子的反向传播推导(下半部分)

复习一下

张量函数链式法则的公式为:

\[\nabla_{\lambda_1 \lambda_2 \cdots \lambda_n} = \sum_{\substack{\mu_1 \in [1, e_1] \\ \mu_2 \in [1, e_2] \\ \vdots \\ \mu_m \in [1, e_m]}} g_{\mu_1 \mu_2 \cdots \mu_m} \frac{\partial}{\partial x_{\lambda_1 \lambda_2 \cdots \lambda_n}} f_{\mu_1 \mu_2 \cdots \mu_m} \]

求解步骤为:我们首先需要确定各个张量的形状,然后把注意力集中到自变量里的某个元素,写出这个元素的导数表达式,然后再推广到整个导数张量。

接下来我们继续常见算子的反向传播公式推导。

Summation

这个算子是对输入张量沿着某些轴求和,这个算子有一个参数 axes,表示求和的规约轴,例如,对于一个四维张量 \(X \in \mathbb{R}^{d_1 \times d_2 \times d_3 \times d_4}\),如果 axes=(2, 3),\(F(X) \in \mathbb{R}^{d_1 \times d_4}\),是一个二维张量,且 \(f_{ij} = \sum_{k=1}^{d_2} \sum_{l=1}^{d_3} X_{iklj}\)。

由此可见,对于多个轴的 summation 操作其实可以拆分为多次的对于一个轴的 summation,所以我们仅讨论 axes 只有一个轴的公式,对于有多个轴的场景可以将其视为复合函数,通过反复使用该公式来进行扩展。

单轴 Summation 问题描述

所以我们要解决的问题就变成了:函数 \(F\) 会对张量 \(X\) 的第 \(a\) 个维度进行求和,求该函数的反向传播公式。

(注:本文统一以 1 为起始下标,实际编程时 axes 是以 0 为起始下标,这个差异需要注意)

首先确定各个张量的形状,如果自变量 \(X \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_n}\) 是一个 \(n\) 维张量,那么 \(F(X) \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times d_{a+1} \times \cdots \times d_n}\) 为 \(n-1\) 维张量。

单轴 Summation 问题求解

接下来可以写出每个自变量的导数的表达式:

\[\nabla_{\lambda_1 \lambda_2 \cdots \lambda_n} = \sum_{\substack{\mu_1 \in [1, e_1] \\ \mu_2 \in [1, e_2] \\ \vdots \\ \mu_{a-1} \in [1, e_{a-1}] \\ \mu_{a+1} \in [1, e_{a+1}] \\ \vdots \\ \mu_m \in [1, e_m]}} g_{\mu_1 \mu_2 \cdots \mu_m} \cdot \frac{\partial f_{\mu_1 \mu_2 \cdots \mu_m}}{\partial x_{\lambda_1 \lambda_2 \cdots \lambda_n}} \]

\[= \sum_{\substack{\mu_1 \in [1, e_1] \\ \mu_2 \in [1, e_2] \\ \vdots \\ \mu_{a-1} \in [1, e_{a-1}] \\ \mu_{a+1} \in [1, e_{a+1}] \\ \vdots \\ \mu_m \in [1, e_m]}} g_{\mu_1 \mu_2 \cdots \mu_n} \cdot \frac{\partial}{\partial x_{\lambda_1 \lambda_2 \cdots \lambda_n}} \sum_{i=1}^{e_a} x_{\mu_1 \mu_2 \cdots \mu_{a-1} i \mu_{a+1} \cdots \mu_n} \]

注意到,当且仅当 \(\mu_1 = \lambda_1, \mu_2 = \lambda_2, \ldots, \mu_n = \lambda_n\) 时,这个表达式值不为 0,且满足上述条件时,只有当 \(i = \lambda_a\) 时,求和表达式值为 1,\(i\) 为其他值时都为 0,所以这一项的最终结果是 \(g_{\lambda_1 \lambda_2 \cdots \lambda_n}\)。

所以最终的 \(\nabla = \text{broadcast}(G, a)\),即把张量 \(G\) 在第 \(a\) 个轴做 broadcast_tobroadcast_to 操作的定义见下文)。

当然,这里实际操作时首先要对 \(G\) 做 reshape,把因为求和丢掉的轴 unsqueeze 回来,然后再通过 broadcast_to 操作广播到 \(X\) 的形状,具体可以参考下面的具体代码实现:

python 复制代码
a = node.inputs[0]
target_dim_num = len(a.shape)
grad_new_shape = []
for i in range(target_dim_num):
    if i in self.axes:
        grad_new_shape.append(1)
    else:
        grad_new_shape.append(a.shape[i])
return broadcast_to(reshape(out_grad, grad_new_shape), a.shape)

多轴 Summation 问题求解

接下来讨论 axes 有多个的情形,通过上面的讨论,容易想到:只需要把求和规约掉的多个轴通过 reshape 进行 unsqueeze,然后再进行 broadcast 就行了。

实际情况正是如此,以两个轴为例,这种情况可以认为是两个单轴 summation 操作的复合,在实际进行反向传播时,会先传播到第一个单轴 summation,此时会进行一次 broadcast_to,然后这个结果会作为 \(G\) 继续传播到第二个单轴 summation,此时又会进行一次 broadcast_to,最终结果等价于把这两次 broadcast_to 放到一起完成。

严格的数学推导这里就不展开了,留作习题自证不难。

所以对于 Summation,最终的导数结果为:

\[\nabla = \text{broadcast\_to}(G, X.\text{shape}) \]

BroadcastTo

这个算子是对一个张量进行广播操作,也就是把张量的元素在若干个轴上进行"复制"的操作,形成一个更"充实"的张量。numpy,pytorch 等框架在处理形状不同的张量时会自动进行广播操作。例如,\(A\) 的形状是 \((6, 6, 5, 4)\),\(B\) 的形状是 \((6, 5, 4)\),在执行 \(A \odot B\) 时,框架会自动在 \(B\) 的左边补上维度 1,变成 \((1, 6, 5, 4)\),然后再执行广播变成 \((6, 6, 5, 4)\),然后再做哈达马积。

这里我们同样先讨论只针对一个轴进行 broadcast_to 的情形,多轴的情形同样可以视为多个单轴 broadcast_to 的嵌套。

(注:以下讨论涉及到的参数和实际编程中的参数有差异,实际编程中是直接传入 broadcast_to 之后的形状作为参数)

单轴 BroadcastTo 算子定义

单轴 broadcast_to 算子有两个参数:

  • 参数 a,表示在哪一个轴进行广播,该算子要求自变量在这一维度的大小为 1
  • 参数 b,表示要将这一维度广播到多大

这一算子的形式化的定义为:

  • \(X \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_n}\),是 \(n\) 维张量,其中 \(d_a = 1\),\(F(X) = \text{broadcast\_to}(X, a)\)
  • 则 \(F(X) \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times b \times d_{a+1} \times \cdots \times d_n}\),其中 \(f_{\lambda_1 \lambda_2 \cdots \lambda_n} = x_{\lambda_1 \lambda_2 \cdots \lambda_{a-1} 1 \lambda_{a+1} \cdots \lambda_n}\)。

直观上来看就是把 \(X\) 在第 \(a\) 维的元素复制了 \(b\) 份。

单轴 BroadcastTo 问题求解

首先可以确认,\(X\) 和 \(\nabla\) 形状相同,为 \(\mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times 1 \times d_{a+1} \times \cdots \times d_n}\),\(G\) 和 \(F(X)\) 的形状相同,为 \(\mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times b \times d_{a+1} \times \cdots \times d_n}\)。

写出 \(\nabla\) 的表达式可得:

\[\nabla_{\lambda_1 \lambda_2 \cdots \lambda_n} = \sum_{\substack{ \mu_1 \in [1, e_1] \\ \mu_2 \in [1, e_2] \\ \vdots \\ \mu_a \in [1, b] \\ \vdots \\ \mu_n \in [1, e_n]}} g_{\mu_1 \mu_2 \cdots \mu_n} \cdot \frac{\partial f_{\mu_1 \mu_2 \cdots \mu_n}}{\partial x_{\lambda_1 \lambda_2 \cdots \lambda_n}} \]

把 \(F\) 的定义式代入,原式子可写作:

\[\sum_{\substack{ \mu_1 \in [1, e_1] \\ \mu_2 \in [1, e_2] \\ \vdots \\ \mu_a \in [1, b] \\ \vdots \\ \mu_n \in [1, e_n]}} g_{\mu_1 \mu_2 \cdots \mu_n} \cdot \frac{\partial x_{\mu_1 \mu_2 \cdots \mu_{a-1} 1 \mu_{a+1} \cdots \mu_n}}{\partial x_{\lambda_1 \lambda_2 \cdots \lambda_n}} \]

注意到,只有当 \(\mu_1 = \lambda_1, \mu_2 = \lambda_2, \ldots, \mu_{a-1} = \lambda_{a-1}, \mu_{a+1} = \lambda_{a+1}, \ldots, \mu_n = \lambda_n\) 时,求和式不为 0,所以这个式子可以进一步化简为:

\[\sum_{\mu_a \in [1, b]} g_{\lambda_1 \lambda_2 \cdots \lambda_{a-1} \mu_a \lambda_{a+1} \cdots \lambda_n} \]

这个表达式的值恰好就等于张量 \(G\) 在 \(a\) 轴做 Summation,所以有:

\[\nabla = \text{Summation}(G, a) \]

多轴 BroadcastTo 问题求解

和 Summation 类似,多轴情形下只需要对所有广播过的轴做 Summation 即可,由此可得,多轴情形下:

\[\nabla = \text{Summation}(G, (a_1, a_2, \ldots, a_m)) \]

其中 \(a_1, a_2, \ldots, a_m\) 是所有经过广播的轴的编号,具体可以参考以下代码实现:

python 复制代码
old_shape = node.inputs[0].shape
new_shape = self.shape
sum_axes = []
for i in range(len(new_shape)):
    if i >= len(old_shape) or (old_shape[i] == 1 and new_shape[i] != 1):
        sum_axes.append(i)

return reshape(summation(out_grad, tuple(sum_axes)), old_shape)

Reshape

顾名思义,这个算子的作用就是改变张量的形状。numpy 对于这个操作的描述是:在不改变数组内容的情况下为数组赋予新的形状。可以认为 numpy 存储的多维张量本质上是一个连续的一维数组,形状只是我们看这个数组的一个视角 ,以二维张量为例,假设这个一维数组是 \([1,2,3,4,5,6]\),如果以 \(2 \times 3\) 矩阵的视角去看,那就会是:

\[\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \]

如果以 \(6 \times 1\) 的矩阵视角去看,那就会是:

\[\begin{bmatrix} 1 \\ 2 \\ 3 \\ 4 \\ 5 \\ 6 \end{bmatrix} \]

Reshape 问题求解

这里我们可以猜一下,以三维张量为例,\(\nabla, X \in \mathbb{R}^{e_1 \times e_2 \times e_3}\),\(G, F(X) \in \mathbb{R}^{d_1 \times d_2 \times d_3}\),其中 \(e_1 \times e_2 \times e_3 = d_1 \times d_2 \times d_3\)。

注意到 \(\nabla\) 和 \(G\) 的元素数量相同,只是形状不同,那只需要进行一次 reshape 即可。

事实正是如此,对于 \(F(X) = \text{reshape}(X, \text{new\_shape})\),其反向传播导数:

\[\nabla = \text{reshape}(G, X.\text{shape}) \]

这里具体的数学推导就不再赘述了,留作习题供读者练习。

(提示:可以考虑定义一个辅助函数,将原来轴的参数映射到新的轴上的参数)

Transpose

这一算子的定义是做转置,二维矩阵的转置很显然,就是行列互换。推广到 \(n\) 维张量,就是选择两个轴,然后在这两个轴上做互换。

(注:这里的 transpose 是 CMU Homework1 里面定义的,而非 numpy 里的定义,这里只会转置两个轴,但是这里推导得到的结果可以轻易推广到多轴的情形)

Transpose 形式化定义

  • 这一算子有 2 个参数 ab,表示需要转置的两个轴
  • 若 \(X \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_n}\),是 \(n\) 维张量
  • 则 \(F(X) \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times d_b \times d_{a+1} \times \cdots \times d_{b-1} \times d_a \times d_{b+1} \times \cdots \times d_n}\) 也是 \(n\) 维张量,只是第 \(a\) 维和第 \(b\) 维的大小互换了
  • 且其中:

\[f_{\lambda_1 \lambda_2 \cdots \lambda_{a-1} \lambda_b \lambda_{a+1} \cdots \lambda_{b-1} \lambda_a \lambda_{b+1} \cdots \lambda_n} = f_{\lambda_1 \lambda_2 \cdots \lambda_{a-1} \lambda_a \lambda_{a+1} \cdots \lambda_{b-1} \lambda_b \lambda_{b+1} \cdots \lambda_n} \]

Transpose 问题求解

这里也很容易才到,对 \(G\) 做同样的转置即可得到,这里同样不展开赘述了,留作习题供读者练习。

(提示:同样可以考虑定义映射轴的辅助函数来解决)

MatMul

这一算子是矩阵乘法,二维矩阵的公式已经在上一篇文章里给出,这里主要补充一下 batch 模式下的矩阵乘法。根据 numpy 里的定义,进行 MatMul 的两个张量 \(X\),\(Y\) 可以是两个高维的张量,例如,当 \(X\) 的形状为 \((6, 6, 5, 3)\),\(Y\) 的形状为 \((6, 6, 3, 4)\) 时,会把 \(X\) 视为是 36 个 \(5 \times 3\) 矩阵按照 \(6 \times 6\) 的格式排列,然后把 \(Y\) 视为 36 个 \(3 \times 4\) 的矩阵按照 \(6 \times 6\) 排列,最后将两个大矩阵中对应位置的两个小矩阵做矩阵乘法,最终会得到 36 个 \(5 \times 4\) 的小矩阵,组成一个形状为 \((6, 6, 5, 4)\) 的张量。

这一操作同样支持广播,即:如果 \(X\) 形状为 \((6, 6, 5, 3)\),\(Y\) 的形状为 \((3,4)\),那么最终结果会是形状为 \((6, 6, 5, 4)\) 的张量,即 \(X\) 的 36 个小矩阵每一个都和 \(Y\) 做矩阵乘法。

这种情形下,如果记单个矩阵乘法的函数为 MatMul,批量矩阵乘法函数为 MatMul_Batch,那么此时 MatMul_Batch 实际上是 MatMul(X, broadcast_to(Y, X.shape)),所以在处理 MatMul_Batch 对 \(Y\) 求导时,需要考虑到这里实际上是嵌套了一层广播的,而广播的反向传播是做 Summation,所以在套用单矩阵 MatMul 的反向传播公式之后还需要做一个 Summation 将形状变回和 \(Y\) 相同的形状,具体过程可以参考如下的代码实现:

(注:理论上是需要先做 Summation 再做 Matmul 的反向传播,但是先做 Summation 和后做是等价的,为了代码实现方便就统一放到后面来做了)

python 复制代码
a, b = node.inputs
a_grad, b_grad = matmul(out_grad, transpose(b)), matmul(transpose(a), out_grad)

if len(a_grad.shape) > len(a.shape):
    sum_axes = tuple((i for i in range(len(a_grad.shape) - len(a.shape))))
    a_grad = summation(a_grad, sum_axes)

if len(b_grad.shape) > len(b.shape):
    sum_axes = tuple((i for i in range(len(b_grad.shape) - len(b.shape))))
    b_grad = summation(b_grad, sum_axes)

return a_grad, b_grad

一些剩下的简单算子

接下来放一些简单算子的反向传播公式,这里就只给出结果而省略推导过程了。

Negate

这个算子是把张量中所有元素取相反数,很显然:

\[\nabla = -G \]

Log

这个算子是对张量中所有元素取自然对数,很显然:

\[\nabla = \frac{G}{X} \]

Exp

这个算子是对张量中所有元素过一次指数函数 \(y = e^x\),很显然:

\[\nabla = G \odot \exp(X) \]

EWisePow

这个算子接收 2 个相同形状的自变量 \(X\) 和 \(Y\)(如果形状不同会进行广播到相同形状),对于 \(X\) 里的每一个 \(x\),取 \(Y\) 对应位置上的元素 \(y\),做 \(x^y\)。

很显然:

\[\nabla^X = G \odot Y \odot \text{EWisePow}(X, Y - 1) \]

\[\nabla^Y = G \odot \text{EWisePow}(X, Y) \odot \log(X) \]