上一篇我们介绍了 T5 的偏置型 RPE,仅仅使用一个标量偏置,配合分桶策略,就用极低的复杂度实现了 NLP 的高效位置编码。
而下一个问题就是:
一维序列上的标量偏置,到了二维图像上要怎么做?
这一篇我们来补上之前的 Swin Transformer 中一个当时没有展开的细节:二维 RPE。
1. 为什么需要二维 RPE?
在 T5 中,相对位置是一个标量:\(i-j\)。因为文本是一维序列,两个 token 之间的关系只需要一个数字就能描述。
但图像数据不同。一张图像被划分为 \(M \times M\) 的 patch 网格后,两个 patch 之间的相对位置是二维的。
一个 patch 到另一个 patch 的偏移,不仅有"水平方向的距离",还有"垂直方向的距离"。
具体来说,对于图像中的位置 \((x_1, y_1)\) 和 \((x_2, y_2)\),相对位置是分成下面两部分:
\[\Delta x = x_1 - x_2, \quad \Delta y = y_1 - y_2 \]
这时如果再用一维标量来描述这个二维偏移,必然丢失方向信息。
好在,Swin 本身的 Window Attention 设计其实已经为 RPE 减了负:注意力只在 \(7 \times 7\) 的窗口内进行。
这种设计让我们可以不再过多考虑 NLP 中的编码外推问题,但相应的,在这个局部范围内,精确的二维相对关系对建模视觉结构至关重要。
因此,Swin 设计了一套二维的相对位置编码方案。
2. 二维 RPE 如何构造?
我们直接来看 Swin 在窗口注意力中使用的公式:
\[\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt d} + B\right)V \]
公式本身在形式上和 T5 是完全相同的,关键在于偏置矩阵 \(B\) 的构造上。
我们分点来展开:
2.1 直接将 RPE 推广到二维
我们先来看看最直接的方法 :
对于一个 \(M \times M\) 的窗口,直接设计 \(B \in \mathbb{R}^{M^2 \times M^2}\),其中 \(B_{ij}\) 表示窗口内第 \(i\) 个 patch 和第 \(j\) 个 patch 之间的偏置值。
我们用一个简单的例子来演示为什么是 \(M^2 \times M^2\) ,假设窗口大小:\(M=2\) ,那么窗口就是:
\[\begin{bmatrix}t_1&t_2\\t_3&t_4\end{bmatrix} \]
现在,每个 token 都要和另外所有 token 建立关系。那么 \(QK^T\) 计算的注意力得分矩阵形状就是这样的:
\[\begin{bmatrix} t_1\to t_1 & t_1\to t_2 & t_1\to t_3 & t_1\to t_4 \\ t_2\to t_1 & t_2\to t_2 & t_2\to t_3 & t_2\to t_4 \\ t_3\to t_1 & t_3\to t_2 & t_3\to t_3 & t_3\to t_4 \\ t_4\to t_1 & t_4\to t_2 & t_4\to t_3 & t_4\to t_4 \end{bmatrix} \]
偏置矩阵必须和注意力矩阵一一对应。所以 \(B\in\mathbb{R}^{M^2\times M^2}\)。
这种方法当然是可以跑通的,但我们要考虑二维带来的参数问题:
如果直接学习一个 \(M^2 \times M^2\) 的参数矩阵,那每个注意力头就得维护 \(M^4\) 个参数。一个 Swin 有多个头和多个层,累计下来参数巨大。
因此, Swin 自然有对应的改进。
2.2 空间关系的平移不变性
在 NLP 中,我们只针对每种相对位置设计偏置,但是在上面方案里,你会发现直接推广会带来很多无意义的参数,核心是因为:
在二维数据中相对逻辑更加凸显,窗口内大量位置对其实拥有相同的相对偏移。
比如,patch (0,0) 和 (1,0) 之间的偏移是 \((\Delta x=1, \Delta y=0)\),而 patch (2,0) 和 (3,0) 之间的偏移同样是 \((\Delta x=1, \Delta y=0)\)。
它们本质上描述的是同一种空间关系,理应共享同一个偏置值。
于是 Swin 的做法是:推广相对逻辑,不直接学习 \(B\),而是学习一个小得多的偏置表,再通过二维索引从中查值。
3. 紧凑偏置表与查表逻辑
3.1 二维相对位置的计算
首先,对于一个 \(M \times M\) 的窗口,给每个位置一个坐标 \((x, y)\),显然:
\[x, y \in [0, M-1] \]
对于任意两个 patch ,二维相对偏移是:
\[\Delta x = x_i - x_j, \quad \Delta y = y_i - y_j \]
那么,\(\Delta x\) 的取值范围就是 \([-(M-1), M-1]\),一共 \(2M-1\) 种可能。
\(\Delta y\) 同理,这部分的计算逻辑和 T5 是完全相同的。
现在,我们知道了:所有可能的 \((\Delta x, \Delta y)\) 组合一共有 \((2M-1)^2\) 种,也就是说:
我们只需要一个 \((2M-1) \times (2M-1)\) 的偏置表,就能覆盖窗口内所有可能的位置关系。
这就是 Swin 的紧凑偏置表 \(\hat{B}\):
\[\hat{B} \in \mathbb{R}^{(2M-1) \times (2M-1)} \]
建表本身的逻辑到此结束,但现在还有一个小问题:
\(QK^T\) 和 \(\hat{B}\) 大小不一,对于每组注意力计算,我要如何查表注入相应偏置?
3.2 查表过程
其实这步可以理解为:如何将 \(\hat{B}\) 内的值映射到总公式里的 \(B\) 中?
首先,前面我们已经知道了:\(QK^T \in \mathbb{R}^{M^2 \times M^2}\)
因此,真正参与 Attention 计算的偏置矩阵 \(B\),也必须是 \(M^2 \times M^2\)。
但我们刚刚学习的紧凑偏置表只有:
\[\hat{B}\in\mathbb{R}^{(2M-1)\times(2M-1)} \]
不难理解,为了让二者适配,Swin 的设计是这样的:
对于 Attention Matrix 中的每一个元素,都先计算两个 patch 的相对位移,再去 \(\hat{B}\) 中查对应 bias。
展开来说, \(QK^T\) 中的每一个元素本质上都对应"一对 patch 的关系",而每一对 patch 都有自己的 \((\Delta x,\Delta y)\),因此,我们可以计算相对位移,实现查表取值:
\[B_{ij}=\hat{B}[\Delta x,\Delta y] \]
这就实现了相同相对位移的 patch 对,共享同一个偏置。
不过这在实现中还有一个问题:
数组索引没有负数,负偏移并不能和其索引直接对应。
而 \(\Delta x,\Delta y \in [-(M-1),M-1]\),因此 Swin 会先做一次平移去寻找正确索引:
\[\Delta x'=\Delta x+(M-1) \]
\[\Delta y'=\Delta y+(M-1) \]
现在:
\[\Delta x',\Delta y' \in [0,2M-2] \]
于是查表过程就变成:
\[B_{ij}=\hat{B}[\Delta x',\Delta y'] \]

字母还是有些抽象,我们再举一个实例:设 \(M=3\) ,那么 patch 网格可以就是:
\[\begin{bmatrix} (0,0) & (0,1) & (0,2) \\ (1,0) & (1,1) & (1,2) \\ (2,0) & (2,1) & (2,2) \end{bmatrix} \]
此时 \(2M-1=5\) ,因此 \(\hat{B}\in\mathbb R^{5\times5}\),如果当前 patch 为 \((0,0)\),它去关注 \((2,2)\) ,那么:
\[\Delta x=0-2=-2 ,\Delta y=0-2=-2 \]
现在,我们需要查:
\[\hat B[-2,-2] \]
显然,数组索引不能为负数。 所以进行平移:
\[M-1=2 \]
于是:
\[\Delta x'=\Delta x+2 \]
\[\Delta y'=\Delta y+2 \]
原本的 \([-2,-2]\) ,就被平移成 \([0,0]\):
\[\hat B[-2,-2] \quad\Rightarrow\quad \hat B[0,0] \]
这里可能容易疑惑的一点是:
\(\hat{B}\in\mathbb R^{5\times5}\) 中存储的并不是"偏移坐标本身",而是"对应相对位移的偏移参数"。
展开来说:数学意义上的 \((Δx,Δy)=(−2,−2)\) 会被映射到数组索引\((0,0)\),因此,\(\hat B[0,0]\) 实际存储的就是相对位移为 \((-2,-2)\) 时对应的偏置。
这样,所有原本可能为负数的二维位移都被映射到了合法数组索引,可以稳定完成查表。
最终所有 patch 两两之间都会完成一次查表从而动态构造出完整的偏置矩阵:
\[B\in\mathbb R^{M^2\times M^2} \]
随后:
\[\frac{QK^T}{\sqrt d}+B \]
即可完成二维相对位置信息的注入。
值得一提的是在具体实现中,二维紧凑表会被展平成一维,以类似"编号"的逻辑取值,根本逻辑没变,明白即可。
3.3 参数对比
来看看两种方式的参数对比:
| 方式 | \(M=7\)(Swin 默认) | \(M=14\) |
|---|---|---|
| 暴力直接法 | \(49 \times 49 = 2401\) | \(196 \times 196 = 38416\) |
| Swin 紧凑法 | \(13 \times 13 = 169\) | \(27 \times 27 = 729\) |
| 压缩比 | 约 14 倍 | 约 53 倍 |
很明显,随着窗口增大,紧凑表的优势会更加明显。
这便是 Swin 的二维 RPE 的完整逻辑,它十分符合 Swin 的整体设计逻辑,配合其实现了 Attention 在 CV 领域的推广,也在后续的很多混合架构中被使用。