深度学习进阶(八)Swin Transformer

上一篇中,我们已经明确了 DeiT 对 ViT 的改进思路:

通过蒸馏,引入 Teacher 的归纳偏置,缓解数据依赖问题。

但这条路线有一个明显局限:

它并没有改变 Transformer 本身的结构。

换句话说,DeiT 本身依然是一个全局 attention、无层级结构、内部缺乏局部归纳偏置的模型。

所以相应的改进思路水到渠成:

如果不只是"引导 Transformer 学习",而是"直接修改 Transformer 的结构",会怎样?

这便是在 21 年的论文:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows中提出的 ViT 的另一变体: Swin Transformer 的出发点。

1. ViT 在实际训练中的问题

在展开 Swin Transformer 的具体改进前,我们需要简单了解一下原始 ViT 在实际训练中的一些问题。

1.1 CV 中的注意力计算量

在原始 ViT 中,每一层 Transformer 都对所有 patch token 两两计算注意力 。这一步带来了一个非常关键的代价:计算复杂度过高

其实,这倒不能怪 ViT ,归根结底其实是自注意力本身的结构让这一原本应用于序列任务的设计在视觉任务中便显现出一些不适。

假设一张图像被切分为 \(N\) 个 patch,那么 self-attention 的计算复杂度就是:

\[\mathcal{O}(N^2) \]

复杂度本身自然是不变的,但问题在于:图像分辨率一旦提升,\(N\) 会快速增长,而计算量是平方级爆炸的。

我们用一个问题来对比说明这种情况:

为什么在 NLP 等序列任务中,这个问题并没有这么突出?

其核心原因不在于 attention 机制本身,而在于:序列长度的增长方式是一维的,而图像分辨率的增长是二维的。

如图所示,图像数据分辨率的增长带来的计算量增加是平方级的。

而且,更重要的一点是:高分辨率本就是视觉任务的刚需

在一些视觉任务,如小目标检测、边界精细分割等,都强依赖高分辨率输入。

因此,除去我们一直强调的弱偏置带来的数据依赖问题外,图像数据带来的高计算量也是一个非常现实的问题。

1.2 归纳偏置问题

对这点我们并不陌生了,我们再用 CNN 和 ViT 的差异简单复述一下:

  1. 因为卷积核,CNN 会优先关注"邻域信息",并逐步扩大感受范围。 而 ViT 一上来就做全局建模,但边缘、纹理等低级视觉特征需要靠数据自己学出来。
  2. 因为层级结构,CNN 输出的特征图在层级间分辨率逐渐降低,而语义逐渐增强。但在原始 ViT 所有 patch token 的尺度是固定的,不存在多尺度表示。

总之还是我们之前说的:Transformer 在视觉任务中缺乏"结构性约束",搜索过于自由,以至于收敛慢,数据要求高。

以上便是 原始 ViT 在实际训练中会产生的一些问题,而 Swin Transformer 便是针对它们进行了改进。

我们已经提到过 Swin Transformer 是通过优化自身结构增强了归纳配置,而在展开其具体逻辑前总结一下其核心思想,那就是:

让 Transformer 具备类似 CNN 的归纳偏置。

需要提前说明的是,在这里,包括下面的具体逻辑中,你很可能会产生这样一种疑惑:Swin Transformer 在尽力还原 CNN 的逻辑,就像开倒车一样,那既然都是同一套归纳偏置,两种结构到底差在哪里?这种改进的具体影响又是哪方面的?

我们先展开 Swin Transformer 的具体改进内容,了解完其本身后就来展开这个问题。

2. Window Attention

Swin 的第一步非常直接:

把"全局 attention"改为"局部 attention"

具体做法很简单:保持 patch token 的二维结构,在此基础上划分多个不重叠的窗口,即 window,每个窗口内部独立进行自注意力计算。

这部分是整个后续设计的前提,我们详细展开如下:

总结来说,这一设计其实就是把原本的全面注意力变成局部注意力了,很容易让我们联想到卷积核

实际上,这一改进就是依据局部性先验引入的归纳配置,于此同时,这一操作还引起了计算复杂度的变化:

\[O(N^2) \rightarrow O(N \times M) \]

假设输入图像大小为 \(224 \times 224\),patch size 为 \(4 \times 4\),那么 token 数量为:

\[N = \frac{224 \times 224}{4 \times 4} = 56 \times 56 = 3136 \]

在标准 ViT 中,每一层 attention 的计算复杂度为:

\[\mathcal{O}(N^2) = 3136^2 \approx 9.8 \times 10^6 \]

现在我们引入 window,假设 window size 为 \(7 \times 7\),那么 window 的数量为:

\[\frac{3136}{49} = 64 \]

总计算量就是:

\[64 \times 49^2 \approx 64 \times 2401 \approx 1.5 \times 10^5 \]

对比一下:

\[\frac{9.8 \times 10^6}{1.5 \times 10^5} \approx 65 \]

也就是说: 仅通过局部窗口划分,计算量就下降了约 60~70 倍。

而究其根本,是因为当窗口大小固定时,计算复杂度从"平方增长"变成了"线性增长"。

不过,这一改进也带来了相应的问题:

不同窗口之间的信息被完全隔离。

自然,Swin Transformer 也有其相应的策略:

3. Shifted Window(移位窗口)

就像刚刚说的,如果只有 Window Attention,模型就会退化为"块状 CNN"。

为了解决这个问题,Swin 提出了核心机制: Shifted Window

这步的做法看似简单,但实际上存在一些要注意的细节:

循环平移窗口 ,同时使用掩码遮蔽非法邻接,从而交互跨窗口信息。

我们具体展开如下:

3.1 循环平移

复述一下要点,向右下方向平移 token 网格,平移距离一般为:

\[shift = \frac{window\ size}{2} \]

这样就可以让窗口内容发生变化,交互不同区域里的信息。

但就像图里说的,这里会出现一个问题:

一些本不相邻的区域被强行拉到一起了,比如右下角和左上角。

对此,Swin 的解决方式是:

对"跨越原边界"的 token 进行 mask,不允许它们参与 attention。

而展开其具体逻辑前,我们要先解释清楚一个问题:

为什么要拒绝直接计算不相邻区域的注意力?这又并不像原始 transformer 的因果掩码一样会透露未来信息。

如果用一句话来总结就是:这是为了局部性偏置的刻意设计

简单展开一下:

如果 shift 后直接 attention,那就会变成:token 之间可以随意跨原 window 交互。

而 Swin 想要的是:"先局部建模,再逐层扩散 "。

一旦这么做,模型就成了"弱版 ViT",又回到了全局注意力的范畴,还多了层 window 的设计。
因此,这里的 mask 为了约束信息只能沿着"窗口重组后的局部结构"传播,从而保留 Swin 的归纳偏置设计。

这点内容,我们展开下面的 mask 具体设计后就可以更好的理解。

3.2 Shifted Window 的 mask 设计

先总结一下 mask 的设计逻辑就是:在 Shift 前构造一个"窗口分组编号图",Shift 后再根据编号差异生成 attention mask。

具体如下:

而在这一层的注意力计算中,掩码的设计方法就和这些编号紧密相关,其形式如下:

\[M_{ij} = \begin{cases} 0 & \text{相同编号} \\ -\infty & \text{不同编号} \end{cases} \]

再摆一下掩码公式:

\[\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V \]

总结来说就是:在新划分的 window 中,只计算和自己编号相同的 token 的注意力。

这样,我们就实现让信息跨窗口流动的同时保证了局部性偏置不被破坏

这里要专门强调一点:在完成 Shifted Window 的相关计算和融合后,输出的 token 会被拼接回原始图像的结构。

这是因为 Swin 的整体设计,也方便堆叠的同时,在平移窗口中交流的 token 也会在重组后"和本地人重新交流外部信息"。

4. patch merging

到这里,我们就可以解释一下最开始的网络结构图中的一些内容了。

首先,在上面的内容里,我们知道 Swin 改进了 Transformer 的 block 结构,我们称之为 Swin block : 要么 W-MSA,要么 SW-MSA,两个 Swin block 在网络中交替出现

补充一个细节:在这里对归一化的设计就变成了我们之前提到的Pre-Norm(前归一化)结构

这样,我们就明白了 Swin 的大部分结构,只剩下了最后一个改进点:patch merging

如果再类比一下,它有些像 CNN 里的池化,但又有所不同。

我们先具体展开其逻辑,再说它的作用,我们假设前一层输出的 token 网格尺寸是:

\[H \times W \times C \]

Patch Merging 做的事情是:把每 2×2 的 patch 合并成 1 个 token

具体操作就是把这 4 个 token 的特征拼接在通道维度上:

\[ C \rightarrow 4C \]

对整个 token 网格应用后,空间尺寸就会变为:

\[\frac{H}{2} \times \frac{W}{2} \]

还没完,Swin 不会直接用 \(4C\),而是再接一个线性层:

\[4C \rightarrow 2C \]

最终输出的尺寸大小为:

\[\frac{H}{2} \times \frac{W}{2} \times 2C \]

到这里,其实 CNN 的思想就非常明显了:

空间分辨率逐渐减少,通道尺度逐渐增大,信息密度逐渐增加,由局部扩散到全局。

这其实就是我们很早之前介绍过的LeNet-5 范式的 Transformer 实现。

现在再看 Swin 的整体网络结构图,你就会非常明显地感受到这种思路:

如图所示,值得一提的是这一改进让 Transfromer 可以获取多尺度特征,从而用于分割等下游任务,而原始的 ViT 是无法实现这一点的。

到这里,我们就完成了 Swin 本身的全部内容。

现在,就可以展开前面的一个问题:

5. Swin 和 CNN

我们已经知道了 Swin 的核心思想:改进结构引入归纳偏置

现在,一个终极问题来到面前:

Swin 和 CNN 用的其实都是同一套归纳偏置,那 Swin 这一更现代的模型到底提升在哪里?是可用性?是上限?对归纳的实现形式到底是怎么影响最终的结果的?

要回答这个问题,关键不在"是否引入归纳偏置",而在于:归纳偏置是如何被实现的。

在 CNN 中,归纳偏置是被硬编码进结构里的。卷积操作本质上是在固定的局部邻域内,用一组共享的卷积核进行加权求和,即:

\[y_i = \sum_{j \in \text{local}} w_j x_j \]

这里的"如何建模邻域关系"其实已经被提前规定好了:连接关系是局部的,权重是共享的,信息融合方式是线性的。

这意味着模型从一开始就被假设为"所有位置的局部模式是相似的",从而大幅缩小了搜索空间,提升了训练稳定性。

但与此同时,也限制了表达能力:模型只能在既定的结构范式内进行拟合。

而在 Swin Transformer 中,归纳偏置则被部分保留、部分放开。window 的划分依然限制了注意力只能在局部范围内计算,但在这个局部范围内,token 之间的关系不再由固定卷积核决定,而是通过 self-attention 动态计算:

\[\alpha_{ij} = \mathrm{softmax}(q_i \cdot k_j) \]

也就是说,Swin 并没有规定"邻域信息该如何组合",而是把这部分建模能力交给数据去学习。这时二者的差异所在。

进一步看信息传播方式,这种差异会被放大:在 CNN 中,感受野的扩大依赖于层层堆叠,远距离信息必须通过多次局部传播逐渐传递

而在 Swin 中,每一层 window 内部都是全连接的 attention,再配合 shifted window 的窗口重组机制,使得原本分离的区域在后续层中被重新组织。这意味着信息的传播路径不再是单纯的逐层扩散,而是通过"结构重排 + 全连接建模"实现跳跃式传递,缩短了建模长距离依赖的路径长度。

因此,Swin 的意义并不在于"用 Transformer 复现 CNN",而在于:在 Transformer 的表达能力之上,引入适度的结构约束,使其既具备视觉任务所需的归纳偏置,又不牺牲对复杂关系的建模能力。

实际上,Transformer 不断发展的同时,CNN 也并非止步不前,当 Swin 在学习"约束"的同时,CNN 也在学习"自由"。

所以,CNN 的 Transformer 化便是下一篇要展开的内容。

相关推荐
YoseZang3 小时前
【机器学习】【手工】Streaming Machine Learning 流数据学习 – 应对变化的机器学习方法(一)
人工智能·学习·机器学习
henrylin99995 小时前
Hermes Agent 核心运行系统调用流程--源码分析
开发语言·人工智能·python·机器学习·hermesagent
泰恒5 小时前
国内外大模型的区别与差距
人工智能·深度学习·yolo·机器学习·计算机视觉
zs宝来了6 小时前
LangChain RAG 架构:向量检索与生成流水线
机器学习·ai·基础设施
沅_Yuan6 小时前
基于LSTM神经网络的锂电池SOH估算模型(NASA数据集)【MATLAB】
神经网络·机器学习·matlab·锂电池·nasa·soh
yfndsb6 小时前
从入门到落地:OpenClaw 全面介绍与全平台本地部署保姆级教程
人工智能·python·ai
xixixi777776 小时前
AI自主挖洞 + 通信网络扩散:全域风险指数级放大,如何构建密码-沙箱-终端联动闭环?
开发语言·网络·人工智能·ai·大模型·php·通信
沅_Yuan7 小时前
基于KAN神经网络的锂电池SOH估算模型(NASA数据集)【MATLAB】
神经网络·机器学习·matlab·锂电池·nasa·soh