计算机视觉CV领域————Swin Transformer

一、Swin Transformer核心概念

Swin Transformer(Shifted Window Transformer)是专为视觉任务设计的 Transformer 变体,解决了原始 Transformer 在处理高分辨率图像时计算量爆炸的问题,核心创新是分层结构和移位窗口注意力机制。

核心概念:

  • 分层特征提取:模仿 CNN 的层级结构,通过 Patch Merging 逐步缩小特征图尺寸、提升通道数,适配不同尺度的视觉特征。
  • 窗口注意力(Window Attention) :将特征图划分为不重叠的窗口,仅在窗口内计算自注意力,把复杂度从 降低到,其中M 为窗口大小。
  • 移位窗口注意力(Shifted Window Attention):解决窗口间信息孤立问题,通过循环移位生成新窗口,同时用「掩码(Mask)」避免无效计算,保证窗口内注意力的正确性。

二、Swin Transformer数学公式

(1)Patch Partition(分块)

将原始图像(H×W×3)划分为大小为 4×4 的不重叠 Patch,每个 Patch 展平为一维向量:

最终得到 (H​/4×W/4​)×48 的特征图,记为

(2)Window Attention 计算

在每个窗口内计算自注意力,核心公式与标准自注意力一致,但作用域限制在窗口内:

  • Q,K,V:查询 / 键 / 值矩阵,由输入特征线性变换得到,dk 为 Q/K 的维度;
  • M:掩码矩阵在仅 Shifted Window 时生效,用于屏蔽移位后跨原始窗口的无效注意力计算。
(3)Shifted Window 移位操作

设窗口大小为 M,特征图尺寸为 H×W,移位量为 ⌊M/2⌋,移位后坐标变换:

;

(4)Patch Merging(分层下采样)

将 2×2 相邻 Patch 拼接,通道数翻倍,尺寸减半:

最终特征图尺寸变为 H/2×W/2×2C。

三、实例代码解释

模块一:核心代码导入
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
模块二:窗口注意力
python 复制代码
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        """
        窗口注意力模块初始化
        Args:
            dim: 输入特征的通道数(如 96、192)
            window_size: 窗口大小(M),默认7,代表7×7的窗口
            num_heads: 多头注意力的头数,需满足 dim % num_heads == 0
        """
        super().__init__()
        # 保存基础参数
        self.dim = dim                      # 输入通道数
        self.window_size = window_size      # 窗口大小 M
        self.num_heads = num_heads          # 注意力头数
        self.head_dim = dim // num_heads    # 每个注意力头的维度
        self.scale = self.head_dim ** -0.5  # 缩放因子 1/√d_k,防止内积值过大

        # 线性变换层:将输入特征一次性映射为Q、K、V(效率高于3个独立线性层)
        # 输入dim → 输出3*dim(Q/K/V各占dim)
        self.qkv = nn.Linear(dim, dim * 3)
        # 输出投影层:将注意力计算后的特征映射回原维度
        self.proj = nn.Linear(dim, dim)

        # 相对位置偏置表:解决绝对位置编码的局限性,捕捉窗口内位置关系
        # 尺寸:(2M-1)×(2M-1) × num_heads → 覆盖窗口内所有可能的相对位置
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )

        # -------------------------- 预计算相对位置索引 --------------------------
        # 生成窗口内的坐标网格:[0,1,...,M-1]
        coords = torch.arange(self.window_size)
        # 生成2×M×M的坐标矩阵:coords_grid[0]是行坐标,coords_grid[1]是列坐标
        coords_grid = torch.stack(torch.meshgrid([coords, coords], indexing="ij"))
        # 展平坐标:2 × M²(把M×M的坐标拉成一维)
        coords_flatten = torch.flatten(coords_grid, 1)
        # 计算所有位置对的相对坐标:(2, M², M²) → 每个位置相对于其他位置的偏移
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        # 维度重排:(M², M², 2) → [位置i, 位置j, 行/列偏移]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # 将相对坐标从[-M+1, M-1]映射到[0, 2M-2](避免负数索引)
        relative_coords[:, :, 0] += self.window_size - 1  # 行偏移修正
        relative_coords[:, :, 1] += self.window_size - 1  # 列偏移修正
        # 行索引编码:行偏移 × (2M-1) + 列偏移 → 唯一标识每个相对位置
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        # 求和得到最终的相对位置索引:(M², M²)
        relative_position_index = relative_coords.sum(-1)
        # 注册为缓冲区(不参与梯度更新)
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        """
        窗口注意力前向传播
        Args:
            x: 输入特征,形状 [num_windows*B, M², dim]
               - num_windows: 特征图划分的窗口总数
               - B: batch size
               - M²: 单个窗口的像素数
               - dim: 通道数
            mask: 掩码矩阵(仅移位窗口时使用),形状 [num_windows, M², M²]
        Returns:
            output: 注意力计算后的特征,形状 [num_windows*B, M², dim]
        """
        # 获取输入维度:B_=num_windows*B, N=M², C=dim
        B_, N, C = x.shape
        
        # -------------------------- 生成Q/K/V --------------------------
        # 1. 线性变换:[B_, N, C] → [B_, N, 3*C]
        # 2. 维度重排:[B_, N, 3, num_heads, head_dim] → 拆分3个维度给Q/K/V
        # 3. 维度置换:[3, B_, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # 拆分Q/K/V:每个的形状都是 [B_, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # -------------------------- 计算注意力分数 --------------------------
        q = q * self.scale  # 缩放Q,对应公式中的 1/√d_k
        # Q @ K^T:[B_, num_heads, N, head_dim] × [B_, num_heads, head_dim, N] → [B_, num_heads, N, N]
        attn = (q @ k.transpose(-2, -1))

        # -------------------------- 添加相对位置偏置 --------------------------
        # 从偏置表中取出对应位置的偏置:[M²*M², num_heads] → [M², M², num_heads]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size * self.window_size, self.window_size * self.window_size, -1
        )
        # 维度重排:[num_heads, M², M²]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        # 添加偏置:[B_, num_heads, N, N] + [1, num_heads, N, N] → 广播相加
        attn = attn + relative_position_bias.unsqueeze(0)

        # -------------------------- 应用掩码(移位窗口专用) --------------------------
        if mask is not None:
            nW = mask.shape[0]  # 获取窗口数量
            # 1. 维度适配:将attn拆分为 [B//nW, nW, num_heads, N, N]
            # 2. 掩码广播:mask [nW, N, N] → [1, nW, 1, N, N]
            # 3. 相加:将掩码值(-100)加到跨窗口的注意力分数上
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # 还原维度:[B_, num_heads, N, N]
            attn = attn.view(-1, self.num_heads, N, N)
            # Softmax归一化:掩码位置的-100会被Softmax为0,不参与计算
            attn = F.softmax(attn, dim=-1)
        else:
            # 无掩码时直接Softmax
            attn = F.softmax(attn, dim=-1)

        # -------------------------- 注意力加权求和V --------------------------
        # attn [B_, num_heads, N, N] × v [B_, num_heads, N, head_dim] → [B_, num_heads, N, head_dim]
        # 维度置换:[B_, N, num_heads, head_dim] → 合并头维度 → [B_, N, dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        # -------------------------- 输出投影 --------------------------
        x = self.proj(x)  # 线性变换,保持维度不变
        return x

# ========================== 辅助函数:窗口划分与还原 ==========================
def window_partition(x, window_size):
    """
    将特征图划分为不重叠的窗口
    Args:
        x: 输入特征,形状 [B, H, W, C]
        window_size: 窗口大小 M
    Returns:
        windows: 窗口化特征,形状 [num_windows*B, M, M, C]
                 num_windows = (H/M) × (W/M)
    """
    B, H, W, C = x.shape
    # 维度拆分:[B, H, W, C] → [B, H//M, M, W//M, M, C]
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # 维度置换:[B, H//M, W//M, M, M, C] → 合并前三维 → [num_windows*B, M, M, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    将窗口化特征还原为完整特征图(window_partition的逆操作)
    Args:
        windows: 窗口化特征,形状 [num_windows*B, M, M, C]
        window_size: 窗口大小 M
        H, W: 原始特征图的高和宽
    Returns:
        x: 还原后的特征图,形状 [B, H, W, C]
    """
    # 计算batch size:num_windows = (H*W)/(M*M) → B = total_windows / num_windows
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # 维度拆分:[num_windows*B, M, M, C] → [B, H//M, W//M, M, M, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, C)
    # 维度置换:[B, H//M, M, W//M, M, C] → 合并维度 → [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, C)
    return x
模块三:Swin Transformer Block
python 复制代码
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        """
        Swin Transformer 基础块(包含窗口注意力/移位窗口注意力)
        Args:
            dim: 输入通道数
            num_heads: 注意力头数
            window_size: 窗口大小 M
            shift_size: 移位量(0=普通窗口,M//2=移位窗口)
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size  # 移位量,核心参数

        # 层归一化(Transformer标准操作,放在注意力前)
        self.norm1 = nn.LayerNorm(dim)
        # 实例化窗口注意力模块
        self.attn = WindowAttention(dim, window_size, num_heads)

    def forward(self, x):
        """
        Swin Block 前向传播
        Args:
            x: 输入特征,形状 [B, H, W, C]
        Returns:
            x: 输出特征,形状 [B, H, W, C](残差连接后)
        """
        B, H, W, C = x.shape
        shortcut = x  # 保存残差连接的输入

        # 1. 层归一化
        x = self.norm1(x)

        # -------------------------- 移位操作(Shifted Window) --------------------------
        if self.shift_size > 0:
            # 循环移位:向左、向上移动shift_size个像素(负数表示左/上移)
            # 例如M=7,shift_size=3 → 左移3,上移3
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            
            # -------------------------- 生成移位窗口的掩码 --------------------------
            # 1. 初始化掩码矩阵:[1, H, W, 1],用于标记不同原始区域
            img_mask = torch.zeros((1, H, W, 1), device=x.device)
            # 2. 划分移位后的区域切片(共3×3=9个区域)
            h_slices = (slice(0, -self.window_size),          # 上半部分
                        slice(-self.window_size, -self.shift_size),  # 中间过渡区
                        slice(-self.shift_size, None))        # 下半部分
            w_slices = (slice(0, -self.window_size),          # 左半部分
                        slice(-self.window_size, -self.shift_size),  # 中间过渡区
                        slice(-self.shift_size, None))        # 右半部分
            # 3. 为每个区域分配唯一标签(0-8)
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            # 4. 将掩码划分为窗口:[num_windows, M, M, 1]
            mask_windows = window_partition(img_mask, self.window_size)
            # 5. 展平掩码:[num_windows, M²]
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # 6. 计算注意力掩码:判断两个位置是否属于同一原始区域
            #    - 同一区域:mask=0 → Softmax后正常计算
            #    - 不同区域:mask=-100 → Softmax后为0,不参与计算
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            # 普通窗口:不移位,无掩码
            shifted_x = x
            attn_mask = None

        # -------------------------- 窗口注意力计算 --------------------------
        # 1. 划分窗口:[B, H, W, C] → [num_windows*B, M, M, C]
        x_windows = window_partition(shifted_x, self.window_size)
        # 2. 展平窗口:[num_windows*B, M², C](适配WindowAttention输入)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # 3. 窗口注意力前向计算
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # -------------------------- 还原窗口为特征图 --------------------------
        # 1. 还原窗口形状:[num_windows*B, M², C] → [num_windows*B, M, M, C]
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # 2. 窗口还原为特征图:[num_windows*B, M, M, C] → [B, H, W, C]
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # -------------------------- 逆移位(恢复原始位置) --------------------------
        if self.shift_size > 0:
            # 向右、向下移动shift_size个像素,还原到原始位置
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # -------------------------- 残差连接 --------------------------
        x = shortcut + x  # 残差相加,提升梯度传播

        return x
模块四:Patch Merging(分层下采样)
python 复制代码
class PatchMerging(nn.Module):
    def __init__(self, dim):
        """
        Patch Merging 模块:将2×2相邻Patch合并,实现下采样
        Args:
            dim: 输入通道数
        """
        super().__init__()
        self.dim = dim
        # 线性变换:将4*dim通道压缩为2*dim(下采样后通道数翻倍)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        # 层归一化:放在线性变换前,提升稳定性
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        """
        Patch Merging 前向传播
        Args:
            x: 输入特征,形状 [B, H, W, C]
        Returns:
            x: 下采样后的特征,形状 [B, H/2, W/2, 2C]
        """
        B, H, W, C = x.shape
        
        # -------------------------- 2×2 Patch合并 --------------------------
        # 1. 维度拆分:[B, H, W, C] → [B, H//2, 2, W//2, 2, C]
        x = x.view(B, H // 2, 2, W // 2, 2, C)
        # 2. 维度置换:[B, H//2, W//2, 2, 2, C] → 合并最后三维 → [B, H//2, W//2, 4C]
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H // 2, W // 2, -1)
        
        # -------------------------- 归一化+通道压缩 --------------------------
        x = self.norm(x)       # 层归一化
        x = self.reduction(x)  # 4C → 2C,通道数翻倍
        
        return x
模块五:测试代码
python 复制代码
if __name__ == "__main__":
    # 模拟输入:batch_size=2,特征图56×56,通道数96(Swin-T的第一层特征)
    x = torch.randn(2, 56, 56, 96)
    
    # 1. 测试普通窗口注意力块(无移位)
    block1 = SwinTransformerBlock(dim=96, num_heads=8, window_size=7, shift_size=0)
    out1 = block1(x)
    print("普通窗口注意力输出形状:", out1.shape)  # 预期:torch.Size([2, 56, 56, 96])
    
    # 2. 测试移位窗口注意力块(移位量=3)
    block2 = SwinTransformerBlock(dim=96, num_heads=8, window_size=7, shift_size=3)
    out2 = block2(x)
    print("移位窗口注意力输出形状:", out2.shape)  # 预期:torch.Size([2, 56, 56, 96])
    
    # 3. 测试Patch Merging下采样
    patch_merge = PatchMerging(dim=96)
    out3 = patch_merge(x)
    print("Patch Merging输出形状:", out3.shape)  # 预期:torch.Size([2, 28, 28, 192])

用 "分窗口算注意力" 降低计算量,用 "挪窗口 + 掩码" 打通窗口间信息,用 "合并像素块" 构建分层特征,最终让 Transformer 能高效处理图片,既快又能学到有用的特征。、

运行结果:

普通窗口注意力输出形状: torch.Size([2, 56, 56, 96])

移位窗口注意力输出形状: torch.Size([2, 56, 56, 96])

Patch Merging输出形状: torch.Size([2, 28, 28, 192])

  • 输入一张真实图片(比如猫、狗的照片);
  • 模型输出这张图片的 "高级特征"(不是像素,是能描述 "这是猫、那是狗" 的特征);
  • 再搭配简单的分类 / 检测头,就能实现图片分类、目标检测、语义分割等视觉任务(比如识别图片里有什么、找到物体的位置)。

四、总结

  • 核心创新:Swin Transformer 用「窗口注意力」降低计算复杂度,用「移位窗口注意力」解决窗口间信息孤立问题,适配高分辨率视觉任务;
  • 数学核心:窗口内自注意力公式,其中掩码 M 是移位窗口的关键;
  • 代码核心:
    • 窗口划分 / 还原是窗口注意力的基础;
    • 移位 + 掩码是实现 Shifted Window 的核心;
    • Patch Merging 实现分层下采样,模仿 CNN 的层级特征。
相关推荐
人工智能AI技术3 小时前
AI如何学会理解和推理的
人工智能
浩瀚之水_csdn4 小时前
AVCodecParameters详解
计算机视觉
Emotional。7 小时前
2025 年度技术总结与规划:AI 时代的开发者成长之路
人工智能·python·ai·langchain
阿星AI工作室9 小时前
一个简单Demo彻底理解前后端怎么连的丨Figma + Supabase + Vercel
前端·人工智能
普通网友10 小时前
Android Jetpack组件:WorkManager的使用
人工智能
CoderJia程序员甲10 小时前
GitHub 热榜项目 - 日榜(2026-02-22)
人工智能·ai·大模型·github·ai教程
witAI11 小时前
**AI仿真人剧制作软件2025推荐,解锁沉浸式数字内容创作
人工智能·python
重生之我要成为代码大佬11 小时前
AI框架设计与选型
人工智能·langchain·大模型·llama·qwen
Dev7z11 小时前
原创论文:基于LSTM神经网络的共享单车需求预测系统设计与实现
人工智能·神经网络·lstm