图像修复-SwinIR: Image Restoration Using Swin Transformer

图像修复-SwinIR: Image Restoration Using Swin Transformer

SwinIR 是一个专门用于图像修复任务的基线模型,它基于Swin Transformer 架构。相比于基于卷积神经网络的传统方法,SwinIR利用了Transformer在高层次视觉任务中的优异表现。

文章目录

在阅读本篇文章之前,必须对Swin Transformer架构有一定了解,可以查看Swin Transformer详情

SwinIR架构图

三部分组成浅层特征提取 (shallow feature extraction)、深层特征提取 (deep feature extraction)、深层特征提取(deep feature extraction)

浅层特征提取(shallow feature extraction):负责从输入的低质量图像中提取初始特征。

深层特征提取 (deep feature extraction):由多个残差Swin Transformer块(Residual Swin Transformer Blocks,RSTB)组成,每个块内部包含多个Swin Transformer层,结合了残差连接。这一模块用于提取深层次的图像特征。

高质量图像重建(high-quality image reconstruction):将深层次提取的特征转换为高质量的输出图像。

浅层特征提取

首先,对于给定的低质量图像LQ∈H×W×C(其中 H 为高度,W 为宽度,C为输入通道数),通过一个 3 × 3 卷积层 HSF 提取浅层特征 F0。卷积层通过对早期视觉处理的良好效果,稳定了优化过程并提升了结果,同时将输入从图像空间映射到高维特征空间。卷积层通过对早期视觉处理的良好效果,稳定了优化过程并提升了结果,同时将输入从图像空间映射到高维特征空间。

架构代码
python 复制代码
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) # 就是简单的卷积,下采样
深层特征提取

由上图中的(a)和(b)作出详细解释,6个STL之后连接一个Conv,再加上残差构成一个RSTB,深层则由6个RSTB组成,后面连接一Conv和残差构成深层特征提取。
STL也就是Swin Transformer Layer,简单来说由两部分组成,就是图中MSA(实际上是一对W-MSA和SW-MSA),然后连接归一化和MPL,同时加上残差连接。
Swin Transformer详情

具体可以看:

架构代码
python 复制代码
self.num_layers = len(depths)  # 设置网络层数,基于给定的 depths 列表长度
self.embed_dim = embed_dim  # 设置嵌入维度
self.ape = ape  # 是否使用绝对位置编码(absolute position embedding)
self.patch_norm = patch_norm  # 是否对patch进行归一化
self.num_features = embed_dim  # 特征通道数等于嵌入维度
self.mlp_ratio = mlp_ratio  # MLP比例系数

# 将图像划分为不重叠的patches
self.patch_embed = PatchEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)  # 设置patch嵌入模块
num_patches = self.patch_embed.num_patches  # 获取patch数量
patches_resolution = self.patch_embed.patches_resolution  # 获取patches的分辨率
self.patches_resolution = patches_resolution  # 保存patches的分辨率

# 将不重叠的patches合并回图像
self.patch_unembed = PatchUnEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)  # 设置patch解嵌入模块

# 绝对位置编码
if self.ape:
    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))  # 初始化绝对位置编码参数
    trunc_normal_(self.absolute_pos_embed, std=.02)  # 对位置编码进行截断正态分布初始化

    self.pos_drop = nn.Dropout(p=drop_rate)  # 位置编码的dropout层

    # 随机深度(Stochastic Depth)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 随机深度的递减规则

    # 构建残差Swin Transformer块(RSTB)
    self.layers = nn.ModuleList()  # 保存网络层
    for i_layer in range(self.num_layers):
        # 初始化每层的RSTB模块
        layer = RSTB(dim=embed_dim,
                     input_resolution=(patches_resolution[0], patches_resolution[1]),  # 输入分辨率
                     depth=depths[i_layer],  # 当前层的深度
                     num_heads=num_heads[i_layer],  # 多头注意力机制的头数
                     window_size=window_size,  # 窗口大小
                     mlp_ratio=self.mlp_ratio,  # MLP的比例系数
                     qkv_bias=qkv_bias, qk_scale=qk_scale,  # QKV相关参数
                     drop=drop_rate, attn_drop=attn_drop_rate,  # dropout参数
                     drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # 随机深度的drop路径
                     norm_layer=norm_layer,  # 归一化层
                     downsample=None,  # 不进行下采样
                     use_checkpoint=use_checkpoint,  # 是否使用梯度检查点
                     img_size=img_size,  # 图像大小
                     patch_size=patch_size,  # patch大小
                     resi_connection=resi_connection  # 残差连接类型
                    )
        self.layers.append(layer)  # 将构建的层添加到层列表中
        self.norm = norm_layer(self.num_features)  # 为每层添加归一化操作

        # 构建深度特征提取中的最后一个卷积层
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)  # 单层卷积作为残差连接
        elif resi_connection == '3conv':
            # 为了节省参数和内存,使用三层卷积残差连接
            self.conv_after_body = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),  # 第一个卷积层,减少通道数
                nn.LeakyReLU(negative_slope=0.2, inplace=True),  # 激活函数LeakyReLU
                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),  # 第二个1x1卷积层,保持通道数
                nn.LeakyReLU(negative_slope=0.2, inplace=True),  # 激活函数LeakyReLU
                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)  # 第三个卷积层,恢复通道数
            )
高质量图像重建

图像恢复中的高质量图像重建,针对不同任务(如超分辨率、去噪和JPEG伪影去除)选择不同的上采样策略。

  • pixelshuffle 模式用于经典超分辨率任务,通过 Pixel Shuffle 技术上采样。
  • pixelshuffledirect 模式适用于轻量级超分辨率任务,减少参数量。
  • nearest+conv 模式则针对真实世界超分辨率,结合最近邻插值和卷积减少伪影。而对于去噪和伪影去除等任务,直接通过卷积层输出高质量图像。
架构代码
python 复制代码
if self.upsampler == 'pixelshuffle':
    # 针对经典的超分辨率(SR)
    # 在上采样之前的卷积层,3x3卷积用于特征处理,LeakyReLU用于激活
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    # 使用PixelShuffle进行上采样
    self.upsample = Upsample(upscale, num_feat)
    # 最后的卷积层,用于生成输出图像,3x3卷积
    self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

elif self.upsampler == 'pixelshuffledirect':
    # 针对轻量级超分辨率(SR),为了减少参数量
    # 使用一步到位的PixelShuffle直接上采样
    self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                    (patches_resolution[0], patches_resolution[1]))

elif self.upsampler == 'nearest+conv':
    # 针对真实世界超分辨率(SR),减少伪影
    # 在上采样之前的卷积层,3x3卷积处理特征,LeakyReLU用于激活
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    # 第一个上采样卷积层,使用3x3卷积
    self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
    
    if self.upscale == 4:
        # 如果上采样比例为4,则需要第二个卷积层
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # 高分辨率特征的卷积层,3x3卷积
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # 最后的卷积层,生成最终输出图像,3x3卷积
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        # LeakyReLU激活函数
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
    # 针对图像去噪和JPEG压缩伪影减少
    # 最后的卷积层,3x3卷积,用于生成输出图像
    self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
损失函数

不同任务下的损失函数根据具体需求有所调整,SR任务中使用L1像素损失,真实世界SR中增加了GAN和感知损失,去噪与压缩伪影去除任务中采用Charbonnier损失以增强稳定性。

图像超分辨率(SR)
图像去噪和JPEG压缩伪影去除

需要源码讲解可以联系我

相关推荐
handsomestWei1 分钟前
ISP图像处理简介
图像处理
深圳南柯电子6 分钟前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ7 分钟前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
biter008812 分钟前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
吃个糖糖18 分钟前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉
qq_5290252936 分钟前
Torch.gather
python·深度学习·机器学习
IT古董1 小时前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
凯哥是个大帅比1 小时前
人工智能ACA(五)--深度学习基础
人工智能·深度学习
m0_748232921 小时前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理