Swin Transformer:深度解析其架构与代码实现

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

Swin Transformer是一种强大的视觉Transformer模型,它通过引入层次化结构和基于窗口偏移的自注意力机制,有效提升了特征提取的能力。在多个计算机视觉任务中,Swin Transformer已经达到了最先进的性能水平。本文将深入探讨Swin Transformer的架构,并尝试将其网络结构进行复现。

一、Swin Transformer 概述

Swin Transformer通过扩展原始Transformer模型的能力,引入了层次化结构和基于窗口偏移的自注意力机制,使其能够有效处理图像数据,并可应用于图像分类、目标检测和分割等任务。

1.背景介绍

Swin Transformer,由微软亚洲研究院孕育的新星,今年在学术界大放异彩,以其独特的魅力在图像分类、图像分割和目标检测等众多领域中斩获了无数荣誉。

然而,Swin Transformer的诞生并非一帆风顺。在它之前,Transformer在自然语言处理(NLP)领域已经取得了辉煌的成就,但在计算机视觉(CV)的舞台上却未能同样耀眼。Swin Transformer的创造者们深入分析了这一现象,发现主要有两个难题:首先,NLP中的token大小是固定的,而CV中的特征尺度变化莫测,如同变幻莫测的风;其次,CV对于分辨率的要求更高,而使用Transformer的计算复杂度与图像尺寸的平方成正比,这无疑给计算带来了巨大的压力。

为了克服这些挑战,Swin Transformer进行了两项创新性的改进:首先,它借鉴了CNN中常用的层次化构建方式,构建了层次化的Transformer;其次,它引入了locality的概念,对没有重叠的window区域进行self-attention计算,能够精准地聚焦于每一个角落。

Swin Transformer不仅仅是一个技术革新,它更是一个多才多艺的艺术家,能够灵活地应用于图像分类、目标检测和语义分割等任务,成为这些任务的通用骨干网络。有人说,Swin Transformer可能是CNN的完美替代者,但我认为,它更像是一位能够与CNN并肩作战的伙伴,共同推动计算机视觉技术的发展。

2.主要特点

  • 层次化结构:模型采用层次化的设计,逐步降低特征的空间维度,同时增加特征的深度。
  • 移位窗口自注意力:通过在局部窗口内移动注意力焦点,减少计算量,同时捕获更丰富的上下文信息。
  • 多尺度特征学习:模型能够学习从粗粒度到细粒度的多尺度特征表示。

3.对比

下图为Swin Transformer与ViT在处理图片方式上的对比,可以看出,Swin Transformer有着ResNet一样的残差结构和CNN具有的多尺度图片结构。

二、具体实现

首先Swin-Transformer 以一张图片作为起点,这是它的画布,准备在上面绘制出精彩的图案。

1.Patch Partition 层

在 Patch Partition 层,这张图片被巧妙地拆分成众多小块,就像是将一幅大画卷分解为易于管理的小片段。Patch Partition是模型对输入图像进行预处理的一种重要操作。该操作的主要目的是将原始的连续像素图像分割成一系列固定大小的图像块(patches),以便进一步转化为Transformer可以处理的序列数据。

2.Swin Transfomer

随后,Linear Embedding 层赋予了这些小块以特征的维度,让它们不再是静止的图像,而是活跃的数据点,为之后的表演做好准备。这些特征化的小块进入 Swin Transformer Block,这是第一阶段,它们在这里学会了如何与周围的伙伴协作,共同构建起初步的图像理解。

3.Patch Merging层

接下来的第二至第四阶段,每个阶段开始前,小块们会经历 Patch Merging 的过程,这就像是将多个小故事合并为一个更加宏大的叙事,每一次合并都让图像的表示更加深入和丰富。Patch Merging层主要是进行下采样,产生分层表示。 Patch Merging 是一种减少序列长度并增加每个补丁表示中通道数的操作。

4.AdaptiveAvgPool1d 层和全连接层

在第四阶段的末尾,所有的数据汇集到输出模块,这里有一个 LayerNorm 层,它确保了数据的平衡和稳定,就像是在演出中保持舞者的稳定和优雅。最后,AdaptiveAvgPool1d 层和全连接层相继登场,它们共同作用于数据,最终完成图像的分类,为这场演出画上完美的句点。

三、代码分析

1.ShiftWindowAttentionBlock 类

python 复制代码
class ShiftWindowAttentionBlock(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, x):
    # patch_num补成能够被window_size整除
    if x.size(-2) % self.window_size:
        x = nn.ZeroPad2d((0, 0, 0, self.window_size - x.size(-2) % self.window_size))(x)

    batch, modal_leng, patch_num, input_dim = x.size()
    short_cut = x # resdual

    # 窗口偏移
    if self.shift_size:
        x = torch.roll(x, shifts=-self.shift_size, dims=2) # 只在 patch_num 上 roll   [batch, modal_leng, patch_num, input_dim]

    # 窗口化 
    window_num = patch_num // self.window_size
    window_x = x.reshape(batch, modal_leng, window_num, self.window_size, input_dim) # [batch, modal_leng, window_num, window_size, input_dim]

    # 基于窗口的多头自注意力
    q = self.query(window_x).reshape(batch, modal_leng, window_num, self.window_size, self.head_num, self.att_size).permute(0, 1, 2, 4, 3, 5) 
    ....

ShiftWindowAttentionBlock 类实现了带有窗口移位的自注意力机制。它接收一个输入张量 x,对其进行自注意力操作,并根据是否启用移位来调整窗口的覆盖范围。

2.SwinTransformer 类

train_shape: 总体训练样本的shape

category: 类别数

embedding_dim: embedding 维度

patch_size: 一个patch长度

head_num: 多头自注意力

att_size: QKV矩阵维度

window_size: 一个窗口包含多少patchs

对于传感窗口数据来讲,在每个单独的模态轴上对时序轴进行patch切分,例如 uci-har 数据集窗口尺寸为 [128, 9],一个patch包含4个数据,那么每个模态轴上的patch_num为32, 总patch数为 32 * 9:

ini 复制代码
class SwinTransformer(nn.Module):
    def __init__(self, train_shape, category, embedding_dim=256, patch_size=4, head_num=4, att_size=64, window_size=8):
        super().__init__()
        self.series_leng = train_shape[-2]
        self.modal_leng = train_shape[-1]
        self.patch_num = self.series_leng // patch_size
        
        self.patch_conv = nn.Conv2d(
            in_channels=1,
            out_channels=embedding_dim,
            kernel_size=(patch_size, 1),
            stride=(patch_size, 1),
            padding=0
        )

        # 位置信息
        self.position_embedding = nn.Parameter(torch.zeros(1, self.modal_leng, self.patch_num, embedding_dim))

        # patch_num维度降采样一次后的计算方式
        swin_transformer_block1_input_patch_num = math.ceil(self.patch_num / window_size) * window_size
        swin_transformer_block2_input_patch_num = math.ceil(math.ceil(swin_transformer_block1_input_patch_num / 2) / window_size) * window_size
        swin_transformer_block3_input_patch_num = math.ceil(math.ceil(swin_transformer_block2_input_patch_num / 2) / window_size) * window_size

        # Shift_Window_Attention_Layer
        # 共3个swin_transformer_block,每个swin_transformer_block对时序维降采样1/2,共降采样1/8
        self.swa = nn.Sequential(
            # swin_transformer_block 1
            nn.Sequential( 
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block1_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block1_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)
            ),
            # swin_transformer_block 2
            nn.Sequential(
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block2_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block2_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)
            ),
            # swin_transformer_block 3
            nn.Sequential(
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block3_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=False),
                ShiftWindowAttentionBlock(patch_num=swin_transformer_block3_input_patch_num, input_dim=embedding_dim, head_num=head_num, att_size=att_size, window_size=window_size, shift=True)
            )
        )

        # classification tower
        self.dense_tower = nn.Sequential(
            nn.Linear(self.modal_leng * math.ceil(swin_transformer_block3_input_patch_num / 2) * embedding_dim, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, category)
        )

SwinTransformer 类构建了 Swin Transformer 的完整模型。它接收输入数据的形状 train_shape 和类别数 category,以及其他配置参数。

3.模型组件

  • 块卷积 (patch_conv):将输入数据分割成小块,并将其转换成嵌入维度。
  • 位置嵌入 (position_embedding):为每个块添加位置信息,帮助模型捕获空间关系。
  • Swin Transformer 块 (swa) :由多个 ShiftWindowAttentionBlock 组成,逐步降低特征的空间维度,同时增加深度。
  • 分类塔 (dense_tower):在模型的顶层,将特征展平并通过一系列线性层进行分类。

4.前向传播

python 复制代码
def forward(self, x):
    x = self.patch_conv(x) # [batch, embedding_dim, patch_num, modal_leng]
    x = self.position_embedding + x.permute(0, 3, 2, 1) # [batch, modal_leng, patch_num, embedding_dim]
    x = self.swa(x)
    x = nn.Flatten()(x)
    x = self.dense_tower(x)
    return x

forward 方法定义了模型的前向传播过程:

  1. 块卷积:输入数据通过卷积操作转换成嵌入维度。
  2. 位置嵌入:将位置信息添加到块特征中。
  3. Swin Transformer 块:通过多个 Swin Transformer 块进行特征提取。
  4. 分类塔:在模型顶部,将特征展平并通过线性层进行分类。

Swin Transformer 是一种创新的模型,它将 Transformer 架构的优势引入到计算机视觉领域。通过层次化处理和高效的自注意力机制,Swin Transformer 在多个视觉任务上展现出卓越的性能。提供的代码实现了 Swin Transformer 的核心功能,为进一步的研究和应用提供了基础。

四、致谢

本文和代码实现基于 Swin Transformer 的原始论文和相关研究工作。感谢所有为深度学习和计算机视觉领域做出贡献的研究人员和开发者。

注意 :具体的代码实现和模型细节可以联系作者获取,以便进一步的研究和应用。本文首发于稀土掘金,未经允许禁止转发和二次创作,侵权必究。

相关推荐
AI算法-图哥5 分钟前
pytorch量化训练
人工智能·pytorch·深度学习·文生图·模型压缩·量化
大山同学8 分钟前
DPGO:异步和并行分布式位姿图优化 2020 RA-L best paper
人工智能·分布式·语言模型·去中心化·slam·感知定位
机器学习之心8 分钟前
时序预测 | 改进图卷积+informer时间序列预测,pytorch架构
人工智能·pytorch·python·时间序列预测·informer·改进图卷积
天飓35 分钟前
基于OpenCV的自制Python访客识别程序
人工智能·python·opencv
檀越剑指大厂37 分钟前
开源AI大模型工作流神器Flowise本地部署与远程访问
人工智能·开源
声网40 分钟前
「人眼视觉不再是视频消费的唯一形式」丨智能编解码和 AI 视频生成专场回顾@RTE2024
人工智能·音视频
newxtc1 小时前
【AiPPT-注册/登录安全分析报告-无验证方式导致安全隐患】
人工智能·安全·ai写作·极验·行为验证
技术仔QAQ1 小时前
【tokenization分词】WordPiece, Byte-Pair Encoding(BPE), Byte-level BPE(BBPE)的原理和代码
人工智能·python·gpt·语言模型·自然语言处理·开源·nlp
神一样的老师1 小时前
去中心化联邦学习与TinyML联合调查:群学习简介
机器学习
陌上阳光1 小时前
动手学深度学习70 BERT微调
人工智能·深度学习·bert