【论文阅读笔记】Activating More Pixels in Image Super-Resolution Transformer

论文地址:https://arxiv.org/abs/2205.04437

代码位置:https://github.com/XPixelGroup/HAT

论文小结

本文方法是基于Transformer的方法,探索了Transformer在低级视觉任务(如SR)中的应用潜力。本文提升有效利用像素范围得到的网络,提出了一种混合注意力Transformer,命名为 HAT (Hybrid Attention Transformer)。

本文的混合注意力Transformer结合channel attention和基于窗口的self-attention方案,所以认为是结合了全局数据统计和局部拟合能力的互补优势。为了进一步聚合跨窗口的信息,作者使用了重叠交叉注意模块来增加相邻窗口特征之间的交流。

本文方法比当时最先进的方法领先PSNR指标1dB以上。

论文简介

使用LAM方法测试(可以得到选择区域哪些像素贡献了最多),得到结论:虽然swinIR的平均指标更高,但基于transformer的swinIR的信息利用范围并不比基于CNN的RCAN方法要大 ,如下图所示。有效信息范围较小,但指标高,可能可以得出SwinIR比CNN方法拥有更大的映射能力的结论。但与此同时,由于其利用像素区域的范围有限,可能会恢复出错误的纹理。所以本文设计网络的时候考虑了在使用近self-attention结构的时候利用更多的像素用于重构

查看swinIR的中间状态,有的样本的中间状态会有块效应。这表明移位窗口机制不能完美地实现跨窗口信息交互。所有本文加强了滑动窗口之间的连接,这样可以改善基于窗口的self-attention方法。

为了解决上述问题,释放Transformer在SR方向上的潜力,作者提出了HAT,混合注意力Transformer。HAT结合了channel attention和self-attention,意图使用前者的全局信息能力和后者强大的表达能力。此外还引入了重叠交叉注意模块来实现相邻窗口特征更直接的交互。

由于Transformer不像CNN那样具有归纳偏差,因此使用大规模数据预训练的模型参数对于释放此类模型的潜力非常重要。在本文中,作者提供了一种有效的同任务预训练策略。IPT使用多个恢复任务进行预训练。EDT使用多个退化级别进行预训练。作者直接使用大规模数据集对同一任务进行预训练,并相信大规模数据对于预训练来说才是真正重要的。以训练出来的指标作为验证,如下图所示,比其他方法高 0.1 0.1 0.1~ 1.2 1.2 1.2dB。

论文方法

网络架构

本文设计的网络结构如下图所示:和很多网络一样,整个网络由三个部分组成,包括浅层特征提取,深层特征提取和图像重构。

HAB

如图2所示,当采用channel attention 时,更多的像素会被激活,因为有一个全局信息作为权重(个别图的牵强解释?)。所以本文引入channel-attention为基础的conv block,加入到Transformer block中 来增强网络的表达能力。

如上图4所示,一个channel attention block(CAB)被塞到了标准的Swin Transformer block下的第一个LayerNorm(LN)层后面,与window-base multi-head self-attention(W-MSA)模块平行。

类似于SwinIR和Swin Transformer,在连续的HAB中,每隔一段时间就采用基于偏移窗口的自注意力(SW-MSA)。为了避免CAB和MSA在优化和视觉表示上可能的冲突,CAB的输出乘以一个小常数 α \alpha α。对于给定的输入,HAB的整个计算过程:其中 X N X_N XN和 X M X_M XM都是中间特征, Y Y Y是HAB的输出。
X N = L N ( X ) , X M = ( S ) W − M S A ( X N ) + α C A B ( X N ) + X , Y = M L P ( L N ( X M ) ) + X M , (1) \begin{equation}%输入公式的区域 \begin{aligned}%aligned命令对齐,在对齐的地方用"&" X_N&=LN(X), \\ X_M&=(S)W-MSA(X_N)+\alpha CAB(X_N)+X, \\ Y&=MLP(LN(X_M))+X_M, \end{aligned} \end{equation} \tag{1} XNXMY=LN(X),=(S)W−MSA(XN)+αCAB(XN)+X,=MLP(LN(XM))+XM,(1)

把每个像素看做是embedding的一个token (比如像SwinIR一样,将patch size设置1作为patch embedding)。MLP表示多层感知机。为了计算self-attention模块,将输入的特征( H × W × C H\times W\times C H×W×C)划分为 M × M M\times M M×M大小的 H W M 2 \frac{HW}{M^2} M2HW局部窗口。对于局部窗口特征 X W ∈ R M 2 × C X_W\in\mathbb{R}^{M^2\times C} XW∈RM2×C,query,keyvalue 矩阵都通过线性映射计算为 Q , K Q,K Q,K和 V V V。然后基于窗口的自注意力被表述为 A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T / d + B ) V , (2) \mathcal{Attention}(Q,K,V)=\mathcal{SoftMax}(QK^T/\sqrt{d}+B)V,\tag{2} Attention(Q,K,V)=SoftMax(QKT/d +B)V,(2)  其中 d d d表示为 q u e r y / k e y query/key query/key的维度;B表示相对位置编码,计算公式由Transformer论文(Attention is all you need)可得;需要注意的是,本文使用了较大的窗口来计算self-attention,因为作者发现这样可以显著扩大使用像素的范围 。除此之外,为了建立非重叠窗口邻居之间的联系,作者使用移位窗口划分方法(如Swin Transformer一样),设置偏移量为窗口大小的一半。

一个CAB模块 包含两个标准卷积层,一个GELU激活函数,一个channel Attention(CA)。由于基于Transformer的架构通常需要较大的channel数量来进行token embedding,所以直接使用恒定宽度的卷积会产生很大的计算成本。因此,作者将两个卷积层的通道数压缩一个常数 β \beta β。对于一个输入通道数为 C C C的特征,第一个卷积的输出特征的通道数会被压缩到 C β \frac{C}\beta βC,然后第二个卷积层会恢复到 C C C的通道数。接下来使用标准的CA模块来自适应调整通道特征。

OCAB

作者引入OCAB(Overlapping Cross-Attention Block)是为了建立跨窗口连接并增强窗口自注意力的能力。本文的OCAB由重叠的交叉注意(OCA)和MLP层组成,类似于标准的Swin Transformer block。

对于OCA结构,如下图所示,使用不同的窗口大小来划分投影特征。具体来说,对于输入特征 X X X的 X Q , X K , X V ∈ R H × W × C X_Q,X_K,X_V\in\mathbb{R}^{H\times W\times C} XQ,XK,XV∈RH×W×C, X Q X_Q XQ划分为 M × M M\times M M×M个 H W M 2 \frac{HW}{M^2} M2HW大小的非重叠窗口,而 X K , X V X_K,X_V XK,XV被展开为 M o × M o M_o\times M_o Mo×Mo个 H W M 2 \frac{HW}{M^2} M2HW大小的重叠窗口,其计算方法为 M o = ( 1 + γ ) × M M_o=(1+\gamma)\times M Mo=(1+γ)×M

其中, γ \gamma γ是控制重叠大小的常数。直观上来说,标准的窗口分区,可以认为是kernel size和stride都等于窗口大小 M M M的滑动窗口。重叠窗口分区,可以认为是kernel size为 M o M_o Mo,stride为 M M M的滑动窗口。使用大小为 γ M 2 \frac{\gamma M}2 2γM的zero-padding用以保证重叠窗口的大小一致性

使用公式(2)计算attention矩阵,位置bias B ∈ R M × M o B\in \mathbb{R}^{M\times M_o} B∈RM×Mo也应用。不像WSA一样,query/key/value都从同一个窗口特征计算,OCA从更大的感受野上计算key/value,在该字段中可以利用更多的信息进行查询(query)。

模型预训练

IPT使用其他低级任务(去噪、去雨、超分等)进行预训练;EDT使用同个数据集的不同退化等级进行预训练;而本文使用相同任务进行预训练,但只是在其他数据集上(比如大的图像数据集ImageNet)预训练。比如要训练 × 4 \times4 ×4的超分任务,就在ImageNet数据集上进行分 × 4 \times4 ×4的超分预训练,然后再在特定数据集(比如DF2K)上进行微调。所提出的策略,即相同任务预训练。作者认为这样的预训练会让训练更简单,同时带来更多的性能改进。

同时,作者也提出在预训练时有足够的训练迭代次数,以及微调时适当的小学习率对于预训练策略的有效性非常重要。作者认为这是因为Transformer需要更多的数据和迭代来学习任务的一般知识,但需要较小的学习率进行微调以避免对特定数据集的过度拟合。

论文实验

结构和参数设置

训练数据集采用DF2K(DIV2K+Flicker2K),因为作者发现只使用DIV2K会导致过拟合。在利用预训练时,采用ImageNet。使用Set5,Set14,BSD100,Urban100和Manga109评估方法。定量指标使用PSNR和SSIM,在Y通道上计算。损失函数采用 L 1 L_1 L1。

对于HAT的结构,作者采用与SwinIR相同的深度和宽度。具体而言,RHAG数量和HAB数量都设为 6 6 6。channel数量设置为 180 180 180。对于(S)W-MSA和OCA,将attention head数量和window 大小设置为 6 6 6和 16 16 16。HAB的权重因子超参数 α \alpha α设为 0.01 0.01 0.01,CAB的两个卷积之间的压缩因子 β \beta β设为 3 3 3,OCA的重叠率设为 0.5 0.5 0.5。

对于更大变体结构 HAT-L ,作者采用增加深度的方式,即将HAT中的RHAG数量从 6 6 6提升到了 12 12 12。

对于与SiwnIR类似计算量,且参数更少的架构 HAT-S ,channel 数量设置成 144 144 144,同时在CAB中使用depth-wise 卷积。

消融学习

窗口大小设置

该消融学习的实验,是在swinIR上进行了,为了避免本文新引入块的影响。定量分析如下表所示,窗口大小为 16 × 16 16\times 16 16×16的模型比 8 × 8 8 \times 8 8×8的模型要有更好的性能,特别是在Urban100上。

定性分析如下图所示,对于红色标记的patch,窗口大小为 16 16 16的模型比窗口大小为 8 8 8的模型能利用更多的输入像素。

定量分析和定性分析都能证明大窗口尺寸的有效性。基于这个结论,作者直接使用 16 16 16的窗口大小作为默认设置。

OCAB和CAB的有效性

OCAB和CAB的定量分析如下表所示,在Urban100数据集上都能带来性能增益,组合起来也能有组合增益。

如下图所示,使用OCAB的模型具有更大的利用像素范围,并能生成更好的重建结果。使用CAB时,所使用的像素甚至扩展到几乎整个图像。使用OCAB和CAB的结果获得了最高的DI,这意味着本文的方法利用了最多的输入像素。尽管使用OCAB和CAB的模型性能比w/OCAB稍低,但本文的方法获得了最高的SSIM并重建了最清晰的纹理

CAB的设计

channel attention带来的影响如下表所示,能带来 0.05 d B 0.05dB 0.05dB的性能增益。

作者还探索了CAB的权重因子 α \alpha α的有效性, α \alpha α就是CAB结果的输出权重, α = 0 \alpha=0 α=0就是不使用CAB的结果。如下表所示, α = 0.01 \alpha=0.01 α=0.01能获得最佳性能。这表明CAB和自注意力在优化方面可能存在潜在问题

OCAB中重叠率的影响

超参数 γ \gamma γ被用于控制OCAB重叠交叉注意力的重叠大小。下表展示了 γ \gamma γ从 0 0 0到 0.75 0.75 0.75的影响。

γ = 0 \gamma=0 γ=0表示的是标准的Transformer模块。从表中也能看到 g a m m a = 0.5 gamma=0.5 gamma=0.5能得到最佳性能。但当 γ = 0.25 \gamma=0.25 γ=0.25或者 γ = 0.75 \gamma=0.75 γ=0.75时,给模型带来的增益较小,甚至会出现性能下降。这表明,不合适的重叠率不利于相邻窗口的相互

实验结果

下表对比了与SOTA方法的量化性能对比,具有更少参数和类似计算的HAT-S也可以显著由于最先进的方法SwinIR。与IPT和EDT的对比,就是使用了ImageNet预训练的结果,可以看出预训练策略能让模型受益匪浅。

视觉对比效果如下图所示。纹理和字符等恢复的都是最好的。

预训练策略的消融学习如下:

在不同网络上,进行同任务预训练的对比如下:

相关推荐
鸭鸭梨吖41 分钟前
产品经理笔记
笔记·产品经理
齐 飞1 小时前
MongoDB笔记01-概念与安装
前端·数据库·笔记·后端·mongodb
丫头,冲鸭!!!2 小时前
B树(B-Tree)和B+树(B+ Tree)
笔记·算法
听忆.2 小时前
手机屏幕上进行OCR识别方案
笔记
Selina K3 小时前
shell脚本知识点记录
笔记·shell
3 小时前
开源竞争-数据驱动成长-11/05-大专生的思考
人工智能·笔记·学习·算法·机器学习
霍格沃兹测试开发学社测试人社区4 小时前
软件测试学习笔记丨Flask操作数据库-数据库和表的管理
软件测试·笔记·测试开发·学习·flask
幸运超级加倍~4 小时前
软件设计师-上午题-16 算法(4-5分)
笔记·算法
王俊山IT4 小时前
C++学习笔记----10、模块、头文件及各种主题(一)---- 模块(5)
开发语言·c++·笔记·学习
Yawesh_best5 小时前
思源笔记轻松连接本地Ollama大语言模型,开启AI写作新体验!
笔记·语言模型·ai写作