图像修复-SwinIR: Image Restoration Using Swin Transformer

图像修复-SwinIR: Image Restoration Using Swin Transformer

SwinIR 是一个专门用于图像修复任务的基线模型,它基于Swin Transformer 架构。相比于基于卷积神经网络的传统方法,SwinIR利用了Transformer在高层次视觉任务中的优异表现。

文章目录

在阅读本篇文章之前,必须对Swin Transformer架构有一定了解,可以查看Swin Transformer详情

SwinIR架构图

三部分组成浅层特征提取 (shallow feature extraction)、深层特征提取 (deep feature extraction)、深层特征提取(deep feature extraction)

浅层特征提取(shallow feature extraction):负责从输入的低质量图像中提取初始特征。

深层特征提取 (deep feature extraction):由多个残差Swin Transformer块(Residual Swin Transformer Blocks,RSTB)组成,每个块内部包含多个Swin Transformer层,结合了残差连接。这一模块用于提取深层次的图像特征。

高质量图像重建(high-quality image reconstruction):将深层次提取的特征转换为高质量的输出图像。

浅层特征提取

首先,对于给定的低质量图像LQ∈H×W×C(其中 H 为高度,W 为宽度,C为输入通道数),通过一个 3 × 3 卷积层 HSF 提取浅层特征 F0。卷积层通过对早期视觉处理的良好效果,稳定了优化过程并提升了结果,同时将输入从图像空间映射到高维特征空间。卷积层通过对早期视觉处理的良好效果,稳定了优化过程并提升了结果,同时将输入从图像空间映射到高维特征空间。

架构代码
python 复制代码
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) # 就是简单的卷积,下采样
深层特征提取

由上图中的(a)和(b)作出详细解释,6个STL之后连接一个Conv,再加上残差构成一个RSTB,深层则由6个RSTB组成,后面连接一Conv和残差构成深层特征提取。
STL也就是Swin Transformer Layer,简单来说由两部分组成,就是图中MSA(实际上是一对W-MSA和SW-MSA),然后连接归一化和MPL,同时加上残差连接。
Swin Transformer详情

具体可以看:

架构代码
python 复制代码
self.num_layers = len(depths)  # 设置网络层数,基于给定的 depths 列表长度
self.embed_dim = embed_dim  # 设置嵌入维度
self.ape = ape  # 是否使用绝对位置编码(absolute position embedding)
self.patch_norm = patch_norm  # 是否对patch进行归一化
self.num_features = embed_dim  # 特征通道数等于嵌入维度
self.mlp_ratio = mlp_ratio  # MLP比例系数

# 将图像划分为不重叠的patches
self.patch_embed = PatchEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)  # 设置patch嵌入模块
num_patches = self.patch_embed.num_patches  # 获取patch数量
patches_resolution = self.patch_embed.patches_resolution  # 获取patches的分辨率
self.patches_resolution = patches_resolution  # 保存patches的分辨率

# 将不重叠的patches合并回图像
self.patch_unembed = PatchUnEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)  # 设置patch解嵌入模块

# 绝对位置编码
if self.ape:
    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))  # 初始化绝对位置编码参数
    trunc_normal_(self.absolute_pos_embed, std=.02)  # 对位置编码进行截断正态分布初始化

    self.pos_drop = nn.Dropout(p=drop_rate)  # 位置编码的dropout层

    # 随机深度(Stochastic Depth)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 随机深度的递减规则

    # 构建残差Swin Transformer块(RSTB)
    self.layers = nn.ModuleList()  # 保存网络层
    for i_layer in range(self.num_layers):
        # 初始化每层的RSTB模块
        layer = RSTB(dim=embed_dim,
                     input_resolution=(patches_resolution[0], patches_resolution[1]),  # 输入分辨率
                     depth=depths[i_layer],  # 当前层的深度
                     num_heads=num_heads[i_layer],  # 多头注意力机制的头数
                     window_size=window_size,  # 窗口大小
                     mlp_ratio=self.mlp_ratio,  # MLP的比例系数
                     qkv_bias=qkv_bias, qk_scale=qk_scale,  # QKV相关参数
                     drop=drop_rate, attn_drop=attn_drop_rate,  # dropout参数
                     drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # 随机深度的drop路径
                     norm_layer=norm_layer,  # 归一化层
                     downsample=None,  # 不进行下采样
                     use_checkpoint=use_checkpoint,  # 是否使用梯度检查点
                     img_size=img_size,  # 图像大小
                     patch_size=patch_size,  # patch大小
                     resi_connection=resi_connection  # 残差连接类型
                    )
        self.layers.append(layer)  # 将构建的层添加到层列表中
        self.norm = norm_layer(self.num_features)  # 为每层添加归一化操作

        # 构建深度特征提取中的最后一个卷积层
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)  # 单层卷积作为残差连接
        elif resi_connection == '3conv':
            # 为了节省参数和内存,使用三层卷积残差连接
            self.conv_after_body = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),  # 第一个卷积层,减少通道数
                nn.LeakyReLU(negative_slope=0.2, inplace=True),  # 激活函数LeakyReLU
                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),  # 第二个1x1卷积层,保持通道数
                nn.LeakyReLU(negative_slope=0.2, inplace=True),  # 激活函数LeakyReLU
                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)  # 第三个卷积层,恢复通道数
            )
高质量图像重建

图像恢复中的高质量图像重建,针对不同任务(如超分辨率、去噪和JPEG伪影去除)选择不同的上采样策略。

  • pixelshuffle 模式用于经典超分辨率任务,通过 Pixel Shuffle 技术上采样。
  • pixelshuffledirect 模式适用于轻量级超分辨率任务,减少参数量。
  • nearest+conv 模式则针对真实世界超分辨率,结合最近邻插值和卷积减少伪影。而对于去噪和伪影去除等任务,直接通过卷积层输出高质量图像。
架构代码
python 复制代码
if self.upsampler == 'pixelshuffle':
    # 针对经典的超分辨率(SR)
    # 在上采样之前的卷积层,3x3卷积用于特征处理,LeakyReLU用于激活
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    # 使用PixelShuffle进行上采样
    self.upsample = Upsample(upscale, num_feat)
    # 最后的卷积层,用于生成输出图像,3x3卷积
    self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

elif self.upsampler == 'pixelshuffledirect':
    # 针对轻量级超分辨率(SR),为了减少参数量
    # 使用一步到位的PixelShuffle直接上采样
    self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                    (patches_resolution[0], patches_resolution[1]))

elif self.upsampler == 'nearest+conv':
    # 针对真实世界超分辨率(SR),减少伪影
    # 在上采样之前的卷积层,3x3卷积处理特征,LeakyReLU用于激活
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    # 第一个上采样卷积层,使用3x3卷积
    self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
    
    if self.upscale == 4:
        # 如果上采样比例为4,则需要第二个卷积层
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # 高分辨率特征的卷积层,3x3卷积
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # 最后的卷积层,生成最终输出图像,3x3卷积
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        # LeakyReLU激活函数
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
    # 针对图像去噪和JPEG压缩伪影减少
    # 最后的卷积层,3x3卷积,用于生成输出图像
    self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
损失函数

不同任务下的损失函数根据具体需求有所调整,SR任务中使用L1像素损失,真实世界SR中增加了GAN和感知损失,去噪与压缩伪影去除任务中采用Charbonnier损失以增强稳定性。

图像超分辨率(SR)
图像去噪和JPEG压缩伪影去除

需要源码讲解可以联系我

相关推荐
机器学习之心2 分钟前
198种组合算法+优化CNN-LSTM+SHAP分析+新数据预测+多输出!深度学习可解释分析,强烈安利,粉丝必备
深度学习·算法·cnn-lstm·shap分析·198种组合算法
我登哥MVP2 分钟前
VS Code 安装 Claude Code 并接入 DeepSeek V4 Model
人工智能·python·node.js·agent·codex·deepseek·claude code
unique3 分钟前
AI Native 调研报告
人工智能
云烟成雨TD3 分钟前
Spring AI Alibaba 1.x 系列【73】两步 RAG
java·人工智能·spring
ai产品老杨4 分钟前
解耦视频高并发与边缘计算AI布控:基于Docker的高性能安防平台,破局GB28181/RTSP协议兼容与源码交付痛点
人工智能·音视频·边缘计算
CHrisFC6 分钟前
LIMS 系统 AI 建设路径:从自动化到智能化的演进之路
运维·人工智能·自动化
饼干哥哥6 分钟前
一口气搭了300个AI Agents并发处理跨境运营的dirty work
人工智能
AI行业学习7 分钟前
CC‑Switch v3.16.1-下载、配置、安装(2026‑06‑01 最新官方版)
开发语言·人工智能·windows·python
小糖学代码8 分钟前
机器学习:5.深度学习
人工智能·深度学习·机器学习
轮子飞了10 分钟前
Spring Ai 集成 DashScope 多模态模型实现身份证信息识别
java·人工智能·spring