x-restormer——restormer+SSA

github: GitHub - Andrew0613/X-Restormer: ECCV2024:A Comparative Study of Image Restoration Networks for General Backbone Network Design

和NAFnet一样,这篇文章也是在对比了当前的几种SOTA算法之后,以restormer为基准,博采众长,得到了新的结构,力求对各种任务有较好的鲁棒性。

主流结构

主流的结构有三种:U-shape encoder-decoder, plain residual-in-residual and multi-stage progressive。

Unet的优势是有下采样有上采样,可以适应不同的scale,并且通过下采样增加了感受野,代表性网络有Uformer [43], Restormer。

multi-stage architecture把网络分成了几个子网络,渐进地处理特征,代表性的网络有MPRNet [49] and HINet,主要用于derraining和deblurring。

plain residual-in-residual architecture用了很多残差结构,用于高频特征的重建,代表性有RCAN [57] and SwinIR

主流算子

从算子的角度,主要有convolution, spatial self-attention and transposed self-attention.。

卷积是最熟悉的,它具有平移不变性,以固定大小的filter在图中做平移,加权得到特征。

Spatial self-attention也是基于window的,但是会生成和图像内容相关content-aware的权重,所以有局部适应能力。

Transposed self-attention则是把attention放在了通道层面,所以直接处理的是global features,再结合depth-wise convolution,对重建任务有很好的表现。

Xrestormer vs restormer

Restormer相比于SwinIR,恢复高频细节的能力稍差(即便是重复纹理),原因一方面是使用Unet加大了高频信息的重建,一方面是使用depth-wise使得空间信息的提取能力下降。unet是indispensable的,所以为了提升restormer对spatial information的利用能力,把一半的transposed self-attention blocks (TSAB) 改为了 spatial self-attention blocks (SSAB):

所以和Restormer对比,整体结构是一样的,只不过是把FN+TA的组合重复了两次,并且在第二次中把TA换成了SA。并且在每个最小的模块中都使用了残差连接。

TSA

TSA就直接使用原来的Multi-Dconv Transpose Attention (MDTA)。

##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class ChannelAttention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(ChannelAttention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b,c,h,w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))  # 通道扩大3倍
        q,k,v = qkv.chunk(3, dim=1)         # 通道维度的拆分
        
        # q,k,v 都是b head c (h w)的形式
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature  # 内积只关注最后两维,所以k要交换维度,内积之后得到cxc的方阵
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

TSA避免了在空域做attention是通过rearrange把h,w放在了同一个维度下。

SSA

而对于SSA,使用的是HAT model中引入的Overlapping Cross-Attention (OCA)。

在实现上,和TSA一样,也是使用卷积,把通道扩充到3的整数倍,然后再chunk分成q,k,v。与TSA的区别是之后的操作:

##########################################################################
## Overlapping Cross-Attention (OCA)
class OCAB(nn.Module):
    def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):
        super(OCAB, self).__init__()
        self.num_spatial_heads = num_heads
        self.dim = dim
        self.window_size = window_size
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size
        self.dim_head = dim_head
        self.inner_dim = self.dim_head * self.num_spatial_heads
        self.scale = self.dim_head**-0.5

        self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size-window_size)//2)
        self.qkv = nn.Conv2d(self.dim, self.inner_dim*3, kernel_size=1, bias=bias)
        self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)
        self.rel_pos_emb = RelPosEmb(
            block_size = window_size,
            rel_size = window_size + (self.overlap_win_size - window_size),
            dim_head = self.dim_head
        )
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.qkv(x)   # 通道升高为3的倍数
        qs, ks, vs = qkv.chunk(3, dim=1)  # 在通道维度拆分

        # spatial attention
        qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = self.window_size, p2 = self.window_size)
        ks, vs = map(lambda t: self.unfold(t), (ks, vs))
        ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c = self.inner_dim), (ks, vs))

        # print(f'qs.shape:{qs.shape}, ks.shape:{ks.shape}, vs.shape:{vs.shape}')
        #split heads
        qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head = self.num_spatial_heads), (qs, ks, vs))

        # attention
        qs = qs * self.scale
        spatial_attn = (qs @ ks.transpose(-2, -1))
        spatial_attn += self.rel_pos_emb(qs)
        spatial_attn = spatial_attn.softmax(dim=-1)

        out = (spatial_attn @ vs)

        out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head = self.num_spatial_heads, h = h // self.window_size, w = w // self.window_size, p1 = self.window_size, p2 = self.window_size)

        # merge spatial and channel
        out = self.project_out(out)

        return out

unfold函数,是把空域的信息结果卷积时候覆盖的范围,框定好,再flaten到通道的维度。可以看这个例子。对k和v都做了unflod的操作。unfold的结果是3维的 :(N,C×∏(kernel_size),L),表示为(b, (c,j),i)

因为使用了unfold,是卷积滑窗的类型,所以K,V是有overlap的windows。而Q,由输入拆分成互不重叠的windows得到。

相同的是q和k,v都rearrange为c在最后一个维度,倒数第二个维度则是和空间window尺寸相关的:

        qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = self.window_size, p2 = self.window_size)
        ks, vs = map(lambda t: self.unfold(t), (ks, vs))
        ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c = self.inner_dim), (ks, vs))

内积之前把q,k,v都转为多头的,

out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head = self.num_spatial_heads, h = h // self.window_size, w = w // self.window_size, p1 = self.window_size, p2 = self.window_size)

内积还需要把k进行倒数第二和倒数第一维度的交换:

spatial_attn = (qs @ ks.transpose(-2, -1))

内积的结果再加上位置编码,结果softmax再作为权重:

        spatial_attn += self.rel_pos_emb(qs)
        spatial_attn = spatial_attn.softmax(dim=-1)

权重同样使用内积加在v上:

out = (spatial_attn @ vs)
相关推荐
@心都几秒前
机器学习数学基础:29.t检验
人工智能·机器学习
9命怪猫3 分钟前
DeepSeek底层揭秘——微调
人工智能·深度学习·神经网络·ai·大模型
kcarly1 小时前
KTransformers如何通过内核级优化、多GPU并行策略和稀疏注意力等技术显著加速大语言模型的推理速度?
人工智能·语言模型·自然语言处理
倒霉蛋小马3 小时前
【YOLOv8】损失函数
深度学习·yolo·机器学习
MinIO官方账号3 小时前
使用 AIStor 和 OpenSearch 增强搜索功能
人工智能
补三补四3 小时前
金融时间序列【量化理论】
机器学习·金融·数据分析·时间序列
江江江江江江江江江3 小时前
深度神经网络终极指南:从数学本质到工业级实现(附Keras版本代码)
人工智能·keras·dnn
Fansv5873 小时前
深度学习-2.机械学习基础
人工智能·经验分享·python·深度学习·算法·机器学习
小怪兽会微笑4 小时前
PyTorch Tensor 形状变化操作详解
人工智能·pytorch·python
Erekys4 小时前
视觉分析之边缘检测算法
人工智能·计算机视觉·音视频