计算机视觉|Swin Transformer:视觉 Transformer 的新方向

一、引言

在计算机视觉领域的发展历程中,卷积神经网络(CNN) 长期占据主导地位。从早期的 LeNet 到后来的 AlexNet、VGGNet、ResNet 等,CNN 在图像分类、目标检测、语义分割等任务中取得了显著成果。然而,CNN 在捕捉全局信息和处理长距离依赖关系方面存在局限性。与此同时,Transformer Architektur 在自然语言处理(NLP)领域表现出色,凭借自注意力机制有效捕捉序列数据中的长距离依赖关系,例如 GPT 系列模型在语言生成和问答系统中的成功应用。

将 Transformer 直接应用于视觉任务面临挑战,例如计算复杂度高,尤其是在处理高分辨率图像时,计算量会随着图像尺寸增加而显著增长,对硬件资源和计算时间要求较高。此外,Transformer 最初为序列数据设计,在提取图像局部特征方面不如 CNN 有效。

Swin Transformer 通过引入 窗口注意力机制,将特征图划分为多个不重叠窗口,在每个窗口内进行自注意力计算,从而降低了计算复杂度。它采用分层结构,类似 CNN 的层次设计,能够提取不同尺度的特征,适应多尺度视觉任务。此外,补丁合并层 通过减少特征图尺寸并增加通道数进一步提升性能。Swin Transformer 在多个视觉任务中表现出色,成为计算机视觉领域的研究重点。本文将深入分析其原理、结构、优势及应用案例。


二、Swin Transformer 的背景

在深度学习发展中,卷积神经网络(CNN) 在计算机视觉领域占据重要地位。LeNet 在手写数字识别中取得初步成功,为 CNN 奠定了基础。2012 年,AlexNet 在 ImageNet 挑战赛中以更深的网络结构和 ReLU 激活函数大幅提升准确率,推动了深度学习在视觉领域的快速发展。此后,VGGNet 通过堆叠小卷积核减少参数,ResNet 通过残差连接解决深层网络的梯度问题,使网络能够更深层并学习复杂特征。

然而,CNN 在捕捉全局信息方面能力较弱。卷积操作主要提取局部特征,通过多层卷积扩大感受野,但对长距离依赖关系建模仍有限制。与此同时,Transformer 在 NLP 领域凭借自注意力机制和并行计算能力取得成功,GPT-3 等模型展示了其语言理解和生成能力。

研究者尝试将 Transformer 应用于视觉任务,但面临图像数据与文本数据的结构差异及高计算复杂度问题。Swin Transformer 的提出旨在将 Transformer 的能力引入视觉领域,通过窗口注意力机制和分层结构降低计算复杂度,提升特征提取能力。


三、核心原理剖析

Swin Transformer 是一种基于 Transformer 的视觉模型,其核心创新在于层次化架构设计(Hierarchical Architecture)和移位窗口自注意力(Shifted Window Self-Attention)。这一设计使其能够高效处理图像数据,同时兼容卷积神经网络(CNN)的多尺度特征提取能力,适用于分类、检测、分割等任务。

(一)整体架构与分层设计


Swin Transformer 的整体架构分为 4 个阶段(Stage),每个阶段通过 Patch Merging 操作逐步降低特征图分辨率,同时增加通道维度,形成金字塔式的层次化特征表示。整体流程如下:

  1. 输入处理

    Patch Partition:将输入图像划分为 4 × 4 4\times4 4×4 的非重叠块(Patch),每个块通过 线性投影 (Linear Embedding)转换为特征向量。例如,输入图像尺寸为 H × W × 3 H \times W \times 3 H×W×3,处理后得到 ( H 4 × W 4 × C \frac{H}{4} \times \frac{W}{4} \times C 4H×4W×C) 的特征图( C 为嵌入维度,默认 96 C 为嵌入维度,默认 96 C为嵌入维度,默认96)。

  2. Stage 1~4

    • 每个 Stage 包含若干 Swin Transformer Block 和一个 Patch Merging 层(最后一个 Stage 无 Patch Merging)。

    Swin Transformer Block:交替使用 窗口多头自注意力(W-MSA)移位窗口多头自注意力(SW-MSA),通过窗口划分减少计算复杂度。

    Patch Merging:将相邻的 2 × 2 2\times2 2×2 块合并为一个块(类似池化),分辨率降低为原来的 1 2 \frac{1}{2} 21,通道数增加为原来的 2 2 2 倍(例如从 C C C 到 2 C 2C 2C)。

典型配置示例

Stage 特征图分辨率 通道数 Swin Block 数量 窗口大小
1 H 4 × W 4 \frac{H}{4} \times \frac{W}{4} 4H×4W 96 2 7×7
2 H 8 × W 8 \frac{H}{8} \times \frac{W}{8} 8H×8W 192 2 7×7
3 H 16 × W 16 \frac{H}{16} \times \frac{W}{16} 16H×16W 384 6 7×7
4 H 32 × W 32 \frac{H}{32} \times \frac{W}{32} 32H×32W 768 2 7×7

(二)窗口注意力机制(W-MSA)

窗口注意力机制 是 Swin Transformer 的核心创新。传统 Transformer 的全局自注意力计算复杂度随图像尺寸平方增长,而 Swin Transformer 将特征图划分为不重叠窗口(如 7x7),在窗口内进行自注意力计算。窗口大小为 MxM 时,计算复杂度为 O(M² * H/W),远低于全局自注意力的 O(HW²)。

实现过程包括:将特征图划分为窗口,计算每个窗口内的 Query、Key 和 Value 矩阵,通过矩阵运算生成注意力权重并与 Value 相乘。这种方式降低计算量,同时保留局部特征提取能力。

(三)移位窗口机制(SW-MSA)

窗口注意力机制虽高效,但窗口间无交互可能限制全局信息捕捉。为此,Swin Transformer 引入 移位窗口机制。在连续自注意力层间,窗口位置移动(如右、下移),使相邻窗口部分重叠,促进信息交互。超出边界的区域通过填充处理。这一机制在保持低计算复杂度的同时增强全局上下文建模能力。

(四)补丁合并层(Patch Merging)

补丁合并层 用于构建层次特征。将特征图按 2x2 窗口切分,拼接为 4C 维向量(C 为原通道数),通过线性层降维至 2C,最终特征图尺寸减半,通道数翻倍。这一过程逐步整合局部特征,提取更具代表性的全局特征。

(五)多头自注意力机制(Multi-Head Self-Attention)

Swin Transformer 沿用 多头自注意力机制,通过多个线性变换生成多组 Query、Key 和 Value 矩阵,分别计算注意力并拼接输出。不同注意力头关注图像的不同特征(如形状、纹理),提升模型对复杂任务的适应性。


应用领域展示

(一)图像分类

在 ImageNet 数据集上,Swin Transformer 表现优异。例如,Swin-B 在 ImageNet-22K 预训练后,在 ImageNet-1K 上 Top-1 准确率达 87.3%,优于 ResNet50(约 76%)及 Vision Transformer(ViT)。其优势在于窗口注意力机制和移位窗口机制结合,有效捕捉全局和局部信息。

(二)目标检测

在 COCO 数据集上,以 Swin Transformer 为骨干网络的 Mask R-CNN 模型 mAP 达 49.5,超越 Faster R-CNN(约 38)。分层结构提取多尺度特征,窗口机制增强上下文信息捕捉,提升检测精度。

(三)语义分割

在 ADE20K 数据集上,基于 Swin Transformer 的 UperNet 模型 mIoU 达 44.5,高于 FCN(约 41)。其多尺度特征提取和上下文理解能力提升像素级分类准确性。


优势对比分析

(一)与传统 CNN 对比

与 CNN 相比,Swin Transformer 在全局信息和长距离依赖建模上更强。CNN 通过卷积提取局部特征,依赖多层堆叠扩大感受野,而 Swin Transformer 的自注意力机制直接捕捉全局依赖。计算复杂度方面,Swin Transformer 通过窗口机制实现近似线性增长,优于传统 Transformer 的平方级增长。

(二)与Vision Transformer 模型对比

相比 Vision Transformer(ViT),Swin Transformer 通过窗口机制降低计算复杂度,提升空间和计算效率。其多层次设计同时捕捉局部和全局特征,且灵活的窗口调整适应不同任务,性能更优。

特性 Swin Transformer Vision Transformer (ViT)
特征图分辨率 多尺度(4 个 Stage) 单尺度(固定分辨率)
计算复杂度 线性复杂度(窗口划分) 平方复杂度(全局注意力)
适用任务 分类、检测、分割 主要分类
位置编码 相对位置编码 绝对位置编码

代码实现示例

以下是使用 Python 和 PyTorch 框架实现 Swin Transformer 中关键模块的代码示例,并对代码进行详细解释,以帮助读者更好地理解模型的实现细节。

(一)窗口注意力机制(W-MSA)

WindowAttention 是 Swin Transformer 的核心模块之一,实现了窗口内的多头自注意力机制(Window-based Multi-head Self-Attention, W-MSA),通过限制注意力计算范围降低复杂度,并加入相对位置偏置以增强空间感知能力。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class WindowAttention(nn.Module):
    """基于窗口的多头自注意力模块,包含相对位置编码(Swin Transformer的核心组件)"""
    
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        """
        Args:
            dim (int): 输入特征维度
            window_size (tuple): 窗口大小 (h, w)
            num_heads (int): 注意力头的数量
            qkv_bias (bool): 是否在qkv线性层添加偏置
            qk_scale (float): 缩放因子,默认为 head_dim^-0.5
            attn_drop (float): 注意力dropout概率
            proj_drop (float): 输出投影层的dropout概率
        """
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads  # 每个注意力头的维度
        
        # 缩放因子,用于缩放点积注意力得分
        self.scale = qk_scale or head_dim ​**​ -0.5

        # 定义相对位置编码表:存储所有可能相对位置的位置偏置
        # 形状为 [(2h-1)*(2w-1), num_heads],用于表示不同相对位置的注意力偏置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )

        # 生成窗口内每个位置的坐标(用于计算相对位置索引)
        coords_h = torch.arange(self.window_size[0])  # 高度方向坐标 [0,1,...,h-1]
        coords_w = torch.arange(self.window_size[1])  # 宽度方向坐标 [0,1,...,w-1]
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='xy'))  # 网格坐标 [2, h, w]
        coords_flatten = torch.flatten(coords, 1)  # 展平为 [2, h*w]

        # 计算相对坐标差值(每个位置与其他位置的相对坐标差)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, h*w, h*w]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # 调整维度顺序为 [h*w, h*w, 2]

        # 将相对坐标偏移到非负数范围(方便作为索引)
        relative_coords[:, :, 0] += self.window_size[0] - 1  # 行偏移到 [0, 2h-2]
        relative_coords[:, :, 1] += self.window_size[1] - 1  # 列偏移到 [0, 2w-2]
        
        # 将二维相对坐标转换为一维索引(用于查表)
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 行坐标乘以跨度
        relative_position_index = relative_coords.sum(-1)  # 合并坐标得到一维索引 [h*w, h*w]
        
        # 注册为不参与梯度更新的缓冲区(在forward中通过索引获取位置偏置)
        self.register_buffer("relative_position_index", relative_position_index)

        # 定义qkv投影层:将输入特征映射为query, key, value
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        
        # 定义输出投影层和dropout
        self.proj = nn.Linear(dim, dim)         # 合并多头输出
        self.attn_drop = nn.Dropout(attn_drop)  # 注意力分数dropout
        self.proj_drop = nn.Dropout(proj_drop)  # 输出投影dropout

        # 初始化相对位置偏置表(正态分布)
        nn.init.normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (Tensor): 输入特征,形状为 [batch_size*num_windows, num_patches, dim]
            mask (Tensor): 窗口注意力掩码(用于SW-MSA),形状为 [num_windows, num_patches, num_patches]
        
        Returns:
            Tensor: 输出特征,形状同输入
        """
        B_, N, C = x.shape  # B_ = batch_size * num_windows, N = num_patches (h*w), C = dim
        
        # 生成qkv向量,并重塑为多头形式
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # 分离q/k/v [B_, num_heads, N, head_dim]

        # 缩放点积得分
        q = q * self.scale  # 缩放query
        
        # 计算原始注意力分数 [B_, num_heads, N, N]
        attn = (q @ k.transpose(-2, -1))  # 矩阵乘法计算注意力分数

        # 添加相对位置偏置(从预定义的表中获取)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)  # 将索引展平查表
        ].view(
            self.window_size[0] * self.window_size[1],  # 窗口内总位置数(h*w)
            self.window_size[0] * self.window_size[1], 
            -1
        )  # 形状变为 [h*w, h*w, num_heads]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [num_heads, h*w, h*w]
        attn = attn + relative_position_bias.unsqueeze(0)  # 广播到batch维度 [B_, num_heads, N, N]

        # 应用掩码(用于SW-MSA的移位窗口)
        if mask is not None:
            nW = mask.shape[0]  # 窗口数量
            # 将attn拆分为不同窗口的注意力并添加掩码
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)  # 重新合并batch维度
            attn = F.softmax(attn, dim=-1)  # 带掩码的softmax
        else:
            attn = F.softmax(attn, dim=-1)  # 普通softmax

        attn = self.attn_drop(attn)  # 应用注意力dropout

        # 计算加权值向量并合并多头输出
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # [B_, N, dim]
        
        # 输出投影和dropout
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

(二)补丁合并层(Patch Merging)

PatchMerging 是 Swin Transformer 中用于降采样和通道扩展的模块,通过合并相邻补丁减少空间分辨率并增加特征维度。

python 复制代码
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """空间下采样模块,用于Swin Transformer的层次化特征提取(类似CNN中的池化层)
    功能:将特征图分辨率降低为1/2,通道数增加为2倍(通过合并相邻2x2区域的特征)
    """
    
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        """
        Args:
            dim (int): 输入特征维度
            norm_layer (nn.Module): 归一化层,默认为LayerNorm
        """
        super().__init__()
        self.dim = dim
        
        # 定义线性投影层:将4*dim维特征映射到2*dim维(通道数翻倍)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        
        # 归一化层:作用于合并后的特征(输入维度为4*dim)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        Args:
            x (Tensor): 输入特征,形状为 [batch_size, H*W, dim]
            H, W (int): 特征图的高度和宽度
        
        Returns:
            Tensor: 下采样后的特征,形状为 [batch_size, (H//2)*(W//2), 2*dim]
        """
        B, L, C = x.shape
        assert L == H * W, "输入特征长度必须等于H*W"
        
        # 重塑为空间结构 [B, H, W, C]
        x = x.view(B, H, W, C)
        
        # 处理奇数尺寸:当H或W为奇数时,通过padding补充1行/列(右下补零)
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # padding格式:(左, 右, 上, 下, 前, 后) -> 此处仅padding高度和宽度的右侧/底部
            x = nn.functional.pad(x, (0, 0,  # 通道维度不padding
                                      0, W % 2,  # 宽度右侧补 (0或1列)
                                      0, H % 2)) # 高度底部补 (0或1行)
        
        # 划分2x2区域并拼接(空间下采样核心操作)
        # 通过切片操作提取相邻2x2区域的四个子块
        x0 = x[:, 0::2, 0::2, :]  # 左上块 [B, H//2, W//2, C]
        x1 = x[:, 1::2, 0::2, :]  # 左下块
        x2 = x[:, 0::2, 1::2, :]  # 右上块
        x3 = x[:, 1::2, 1::2, :]  # 右下块
        
        # 沿通道维度拼接 -> 通道数变为4倍 [B, H//2, W//2, 4*C]
        x = torch.cat([x0, x1, x2, x3], dim=-1)
        
        # 展平空间维度 -> [B, (H//2)*(W//2), 4*C]
        x = x.view(B, -1, 4 * C)  
        
        # 归一化处理
        x = self.norm(x)
        
        # 线性投影降维:4*C -> 2*C(通道数翻倍)
        x = self.reduction(x)
        
        return x

(三)demo调用及输出结果展示

我们创建一个 demo 函数来演示 Swin Transformer 中 PatchMergingWindowAttention 的调用流程

python 复制代码
import torch

def demo():
    """演示 Swin Transformer 中 PatchMerging 和 WindowAttention 的调用流程"""
    
    # --------------- 参数设置 ---------------
    batch_size = 2             # 批大小
    height, width = 16, 16     # 输入特征图的高和宽
    dim = 96                   # 输入特征的维度(通道数)
    window_size = (4, 4)       # 窗口大小(高方向4像素,宽方向4像素)
    num_heads = 4              # 多头注意力头数
    
    # 创建随机输入特征图 [B, H*W, C]
    x = torch.randn(batch_size, height * width, dim)
    print(f"原始输入形状: {x.shape}")  # 预期输出: [2, 256, 96]

    # --------------- 调用 PatchMerging 模块 ---------------
    patch_merge = PatchMerging(dim=dim)
    x_patch = patch_merge(x, height, width)
    new_height, new_width = height // 2, width // 2  # 下采样后特征图尺寸
    print(f"PatchMerging 输出形状: {x_patch.shape}")  # 预期输出: [2, 64, 192]
    # 注:H/2 * W/2 = 8 * 8=64,通道数从96扩展到192

    # --------------- 调整形状以适应 WindowAttention 输入 ---------------
    # 计算窗口划分后的参数
    num_windows_h = new_height // window_size[0]  # 窗口行数 8//4=2
    num_windows_w = new_width // window_size[1]   # 窗口列数 8//4=2
    num_windows = num_windows_h * num_windows_w   # 总窗口数 2 * 2=4
    tokens_per_window = window_size[0] * window_size[1]  # 每个窗口的token数 4 * 4=16
    new_dim = 2 * dim  # PatchMerging后的通道数 96 * 2=192

    # 重塑特征图为窗口形式 [B * num_windows, tokens_per_window, new_dim]
    x_window = x_patch.view(
        batch_size, 
        num_windows_h, num_windows_w,  # 窗口的行列数
        tokens_per_window,             # 每个窗口的token数
        new_dim                        # 新通道数
    )
    x_window = x_window.permute(0, 1, 2, 3, 4).contiguous()  # 维度调整 [B, num_h, num_w, tokens, C]
    x_window = x_window.view(batch_size * num_windows, tokens_per_window, new_dim)
    print(f"调整为窗口输入形状: {x_window.shape}")  # 预期输出: [8, 16, 192] (2 * 4=8窗口)

    # --------------- 调用 WindowAttention 模块 ---------------
    window_attn = WindowAttention(
        dim=new_dim, 
        window_size=window_size, 
        num_heads=num_heads
    )
    x_out = window_attn(x_window)  # 输入形状 [8, 16, 192]
    print(f"WindowAttention 输出形状: {x_out.shape}")  # 预期输出: [8, 16, 192]

    # --------------- 将输出还原为特征图形式 ---------------
    # 逆向重塑操作(仅用于展示,实际可能不需要)
    x_out = x_out.view(
        batch_size, 
        num_windows_h, num_windows_w,  # 窗口行列数
        tokens_per_window,              # 每个窗口的token数
        new_dim                         # 通道数
    )
    x_out = x_out.permute(0, 1, 3, 2, 4).contiguous()  # [B, num_h, tokens, num_w, C]
    x_out = x_out.view(batch_size, new_height, new_width, new_dim)
    print(f"最终特征图形状: {x_out.shape}")  # 预期输出: [2, 8, 8, 192]

if __name__ == "__main__":
    demo()

输出结果如下:

通过代码示例,我们不仅理解了 Swin Transformer ​层次化架构、窗口注意力和移位窗口机制的实现细节,更深入认识到其设计哲学:在保持 Transformer 全局建模能力的同时,通过局部计算和层次化设计逼近 CNN 的效率优势。这种平衡使其成为视觉任务的通用 Backbone,为后续研究(如 SwinV2、Uniformer)提供了重要参考。


总结与展望

Swin Transformer 在计算机视觉领域具有重要地位,通过 窗口注意力机制移位窗口机制补丁合并层 等设计降低计算复杂度,提升特征提取能力,在多项任务中表现出色。对于研究者而言,它是一个值得深入探索的模型,未来可在更多领域发挥作用。

Swin Transformer 的研究正处于快速发展阶段。优化方向包括改进窗口注意力机制(如动态窗口划分)和降低计算复杂度(如稀疏注意力)。多模态融合 是另一热点,与 NLP 结合可实现图像描述和视觉问答等任务。在应用上,Swin Transformer 在自动驾驶(车辆检测)和医疗影像分析(肿瘤检测)中展现潜力。未来,其通用性和计算效率有望进一步提升,应用范围将更广泛。


延伸阅读


相关推荐
阿正的梦工坊8 分钟前
解析 PyTorch 中的 torch.multinomial 函数
人工智能·pytorch·python
芥子沫11 分钟前
一文了解Conda使用
人工智能
巫山老妖28 分钟前
全球首款通用 AI 智能体 Manus 来袭,AI 圈沸腾了!
人工智能
虾球xz32 分钟前
游戏引擎学习第137天
人工智能·学习·游戏引擎
一水鉴天1 小时前
为AI聊天工具添加一个知识系统 之135 详细设计之76 通用编程语言 之6
开发语言·人工智能·架构
He.Tech1 小时前
DeepSeek大模型+RAGFlow实战指南:构建知识驱动的智能问答系统
人工智能·ai
康谋自动驾驶1 小时前
康谋分享 | 3DGS:革新自动驾驶仿真场景重建的关键技术
人工智能·科技·3d·数据分析·自动驾驶·汽车
麦麦大数据1 小时前
vue+neo4j 四大名著知识图谱问答系统
vue.js·人工智能·python·django·问答系统·知识图谱·neo4j
CodeJourney.1 小时前
Deepseek助力思维导图与流程图制作:高效出图新选择
数据库·人工智能·算法
几道之旅1 小时前
扣子(Coze):重构AI时代的工作流革命
人工智能