图像修复-SwinIR: Image Restoration Using Swin Transformer
SwinIR 是一个专门用于图像修复任务的基线模型,它基于Swin Transformer 架构。相比于基于卷积神经网络的传统方法,SwinIR利用了Transformer在高层次视觉任务中的优异表现。
文章目录
-
-
- [图像修复-SwinIR: Image Restoration Using Swin Transformer](#图像修复-SwinIR: Image Restoration Using Swin Transformer)
-
在阅读本篇文章之前,必须对Swin Transformer架构有一定了解,可以查看Swin Transformer详情
SwinIR架构图
三部分组成 :浅层特征提取 (shallow feature extraction)、深层特征提取 (deep feature extraction)、深层特征提取(deep feature extraction)
浅层特征提取(shallow feature extraction):负责从输入的低质量图像中提取初始特征。
深层特征提取 (deep feature extraction):由多个残差Swin Transformer块(Residual Swin Transformer Blocks,RSTB)组成,每个块内部包含多个Swin Transformer层,结合了残差连接。这一模块用于提取深层次的图像特征。
高质量图像重建(high-quality image reconstruction):将深层次提取的特征转换为高质量的输出图像。
浅层特征提取
首先,对于给定的低质量图像LQ∈H×W×C(其中 H 为高度,W 为宽度,C为输入通道数),通过一个 3 × 3 卷积层 HSF 提取浅层特征 F0。卷积层通过对早期视觉处理的良好效果,稳定了优化过程并提升了结果,同时将输入从图像空间映射到高维特征空间。卷积层通过对早期视觉处理的良好效果,稳定了优化过程并提升了结果,同时将输入从图像空间映射到高维特征空间。
架构代码
python
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) # 就是简单的卷积,下采样
深层特征提取
由上图中的(a)和(b)作出详细解释,6个STL之后连接一个Conv,再加上残差构成一个RSTB,深层则由6个RSTB组成,后面连接一Conv和残差构成深层特征提取。
STL也就是Swin Transformer Layer,简单来说由两部分组成,就是图中MSA(实际上是一对W-MSA和SW-MSA),然后连接归一化和MPL,同时加上残差连接。
Swin Transformer详情具体可以看:
架构代码
python
self.num_layers = len(depths) # 设置网络层数,基于给定的 depths 列表长度
self.embed_dim = embed_dim # 设置嵌入维度
self.ape = ape # 是否使用绝对位置编码(absolute position embedding)
self.patch_norm = patch_norm # 是否对patch进行归一化
self.num_features = embed_dim # 特征通道数等于嵌入维度
self.mlp_ratio = mlp_ratio # MLP比例系数
# 将图像划分为不重叠的patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None) # 设置patch嵌入模块
num_patches = self.patch_embed.num_patches # 获取patch数量
patches_resolution = self.patch_embed.patches_resolution # 获取patches的分辨率
self.patches_resolution = patches_resolution # 保存patches的分辨率
# 将不重叠的patches合并回图像
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None) # 设置patch解嵌入模块
# 绝对位置编码
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # 初始化绝对位置编码参数
trunc_normal_(self.absolute_pos_embed, std=.02) # 对位置编码进行截断正态分布初始化
self.pos_drop = nn.Dropout(p=drop_rate) # 位置编码的dropout层
# 随机深度(Stochastic Depth)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # 随机深度的递减规则
# 构建残差Swin Transformer块(RSTB)
self.layers = nn.ModuleList() # 保存网络层
for i_layer in range(self.num_layers):
# 初始化每层的RSTB模块
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]), # 输入分辨率
depth=depths[i_layer], # 当前层的深度
num_heads=num_heads[i_layer], # 多头注意力机制的头数
window_size=window_size, # 窗口大小
mlp_ratio=self.mlp_ratio, # MLP的比例系数
qkv_bias=qkv_bias, qk_scale=qk_scale, # QKV相关参数
drop=drop_rate, attn_drop=attn_drop_rate, # dropout参数
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # 随机深度的drop路径
norm_layer=norm_layer, # 归一化层
downsample=None, # 不进行下采样
use_checkpoint=use_checkpoint, # 是否使用梯度检查点
img_size=img_size, # 图像大小
patch_size=patch_size, # patch大小
resi_connection=resi_connection # 残差连接类型
)
self.layers.append(layer) # 将构建的层添加到层列表中
self.norm = norm_layer(self.num_features) # 为每层添加归一化操作
# 构建深度特征提取中的最后一个卷积层
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) # 单层卷积作为残差连接
elif resi_connection == '3conv':
# 为了节省参数和内存,使用三层卷积残差连接
self.conv_after_body = nn.Sequential(
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), # 第一个卷积层,减少通道数
nn.LeakyReLU(negative_slope=0.2, inplace=True), # 激活函数LeakyReLU
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), # 第二个1x1卷积层,保持通道数
nn.LeakyReLU(negative_slope=0.2, inplace=True), # 激活函数LeakyReLU
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1) # 第三个卷积层,恢复通道数
)
高质量图像重建
图像恢复中的高质量图像重建,针对不同任务(如超分辨率、去噪和JPEG伪影去除)选择不同的上采样策略。
pixelshuffle
模式用于经典超分辨率任务,通过 Pixel Shuffle 技术上采样。pixelshuffledirect
模式适用于轻量级超分辨率任务,减少参数量。nearest+conv
模式则针对真实世界超分辨率,结合最近邻插值和卷积减少伪影。而对于去噪和伪影去除等任务,直接通过卷积层输出高质量图像。
架构代码
python
if self.upsampler == 'pixelshuffle':
# 针对经典的超分辨率(SR)
# 在上采样之前的卷积层,3x3卷积用于特征处理,LeakyReLU用于激活
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)
)
# 使用PixelShuffle进行上采样
self.upsample = Upsample(upscale, num_feat)
# 最后的卷积层,用于生成输出图像,3x3卷积
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# 针对轻量级超分辨率(SR),为了减少参数量
# 使用一步到位的PixelShuffle直接上采样
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# 针对真实世界超分辨率(SR),减少伪影
# 在上采样之前的卷积层,3x3卷积处理特征,LeakyReLU用于激活
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)
)
# 第一个上采样卷积层,使用3x3卷积
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
# 如果上采样比例为4,则需要第二个卷积层
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# 高分辨率特征的卷积层,3x3卷积
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# 最后的卷积层,生成最终输出图像,3x3卷积
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
# LeakyReLU激活函数
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# 针对图像去噪和JPEG压缩伪影减少
# 最后的卷积层,3x3卷积,用于生成输出图像
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
损失函数
不同任务下的损失函数根据具体需求有所调整,SR任务中使用L1像素损失,真实世界SR中增加了GAN和感知损失,去噪与压缩伪影去除任务中采用Charbonnier损失以增强稳定性。
图像超分辨率(SR)
图像去噪和JPEG压缩伪影去除
需要源码讲解可以联系我