MViTv2:Facebook出品,进一步优化的多尺度ViT | CVPR 2022

论文将Multiscale Vision Transformers (MViTv2) 作为图像和视频分类以及对象检测的统一架构进行研究,结合分解的相对位置编码和残差池化连接提出了MViT的改进版本

来源:晓飞的算法工程笔记 公众号

论文: MViTv2: Improved Multiscale Vision Transformers for Classification and Detection

Introduction


为不同的视觉识别任务设计架构一直很困难,而最广泛采用的架构是结合了简单性和有效性的架构,例如VGGNetResNet。最近,Vision Transformers(ViT) 已经显示出能够与卷积神经网络 (CNN) 相媲美的性能,涌现出大量将其应用于不同的视觉任务中的工作来。

虽然ViT在图像分类中很流行,但在高分辨率目标检测和视频理解任务中的应用仍然具有挑战性。视觉信号的密度对计算和内存要求提出了严峻的挑战,主要因为基于Transformer的模型的自注意力块的复杂度与输入长度呈二次方增长。目前有大量的研究来解决这个问题,比较主要的两个为:

  • 使用窗口注意力,在一个窗口内进行局部注意力计算以及对象检测,主要用于目标检测任务。
  • 使用池化注意力,在计算自注意力之前先聚合局部特征的,主要用于视频任务。

后者推动了Multiscale Vision Transformers(MViT)的研究,以简单的方式扩展ViT的架构。整个网络不再固定分辨率,而是构造从高分辨率到低分辨率的多个阶段的特征层次结构。

MViT专门为视频任务设计,具有最先进的性能。论文将MViT作为一个模型系列在图像分类、目标检测和视频分类中进行研究,从而了解它是否能够作为通用的视觉任务的主干网络。

根据研究结果,论文提出了改进的架构 (MViTv2),并包含以下内容:

  • 从两个方面来大幅提升池化注意力的性能:
    • 使用坐标分离的位置距离构造相对位置编码,在Transformer块中注入平移不变的位置信息。
    • 使用残差池化连接来补偿注意力计算中池化缩放带来的影响。
  • 根据标准的密集预测框架Mask R-CNN with Feature Pyramid Networks(FPN)改进MViT结构,并将其应用于目标检测和实例分割。实验表明,池化注意力比窗口注意力机制(例如Swin)更有效。另外,论文进一步开发了一种简单的混合池化注意力和窗口注意力的方案,可以实现更好的准确性/计算权衡。
  • 论文提供了五种尺寸的MViT2架构,只需很少的修改就能作为图像分类、对象检测和视频分类的通用视觉架构。实验表明,MViTImageNet分类的准确率为88.8%,COCO对象检测的APbox准确率为58.7%,Kinetics-400视频分类的准确率为86.1%。其中,在视频分类任务上的准确率是非常出色的。

Revisiting Multiscale Vision Transformers


MViTv1的关键思想是为低级和高级视觉建模构建不同的阶段,而不是像ViT那样全是单尺度块。MViTv1缓慢地扩展通道宽度\(D\),同时降低网络输入到输出阶段的序列长度\(\boldsymbol{\mathit{L}}\),具体可以看之前的文章 【MViT:性能杠杠的多尺度ViT | ICCV 2021】

为了在Transformer块内执行下采样,MViT引入了池化注意力(Pooling Attention)。具体来说,对于输入序列 \(X\in{\mathbb{R}}^{L\times D}\),分别对查询、键和值张量应用线性投影 \({W}{Q}\)、\({W}{K}\)、\({W}_{V}\in\mathbb{R}^{D\times D}\),以及池化运算符 (\({\mathcal{P}}\)):

\[Q={\mathcal P}{Q}\left(X W{Q}\right),\;K={\mathcal P}{K}\left(X W{K}\right),\;V={\mathcal P}{V}\left(X W{V}\right) \quad\quad(1) \]

其中,\({Q}\in\mathbb{R}^{\tilde{L}\times D}\) 的长度 \(\tilde{L}\) 可减小 \({\mathcal{P}}{Q}\) 倍,\(K\),\(V\) 的长度则可减少 \({\mathcal{P}}{K}\) 和 \({\mathcal{P}}_{V}\)倍。

随后,使用池化注意力计算任意长度 \({\widetilde{L}}\) 的输出序列 \(Z\in\mathbb{R}^{{\tilde{L}}\times D}\) 的输出序列:

\[Z:=\mathrm{Attn}(Q,K,V)=\mathrm{Softmax}\left(Q K^{\textsf{T}}/\sqrt{D}\right)V \quad\quad(2) \]

请注意,键和值张量的下采样因子 \({\mathcal{P}}{K}\) 和 \({\mathcal{P}}{V}\) 可能与应用于查询序列的 \({\cal P}_{Q}\) 不同。

池化注意力通过池化查询张量 \({Q}\) 来降低MViT不同阶段之间的分辨率,通过池化键 \(K\) 和值 \(V\) 张量来显着降低块内的计算和内存复杂性。

Improved Multiscale Vision Transformers


Improved Pooling Attention

Decomposed relative position embedding

虽然MViT已经显示出其在token之间建模交互方面的能力,但它们关注的是内容,而不是结构。完全依赖于绝对位置编码来提供位置信息的时空结构建模,导致MViT忽略了视觉中平移不变性的基本原理。也就是说,即使相对位置保持不变,MViT对两个补丁之间的交互进行建模的方式也会根据token在图像中的绝对位置而改变。为了解决这个问题,论文将相对位置编码加入到自注意力计算中。

论文将两个输入元素 \(i\) 和 \(j\) 之间的相对位置表示为位置编码 \(R_{p(i),p(j)}{\in}\mathbb{R}^{d}\) ,其中 \(p({t})\) 和 \(p({j})\) 表示元素 \(i\) 和 \(j\) 的空间位置,随后将相对位置编码嵌入到自注意力模块中:

\[\begin{array}{c} {{\mathrm{Attn}(Q,K,V)=\mathrm{Softmax}\left((Q K^{\textsf{T}}+E^{(\operatorname{rel})})/\sqrt{d}\right)V}} \\ {{\mathrm{where}}} \ \ \ \ E_{i j}^{\mathrm{(rel)}}=Q_{i}\cdot R_{p(i),p(j)} \end{array} \quad\quad(3) \]

但是,由于 \(R_{p(i),p(j)}\) 涉及时空坐标,存在的位置编码数量为 \({\cal O}(T W H)\),计算起来很复杂。为了降低复杂性,论文沿时空轴对元素 \(i\) 和 \(j\) 之间的距离进行分解计算:

\[R_{p(i),p(i)}=R_{h(i),h(i)}^{\mathrm{h}}+R_{u\iota(i),n\iota(i)}^{\mathrm{w}}+R_{t(i),t(i)}^{\mathrm{t}} \quad\quad(4) \]

其中 \(R^{\mathrm{h}},R^{\mathrm{w}},R^{\mathrm{t}}\) 是沿高度、宽度和时间轴的位置编码,\(h(i)\)、\(w(i)\) 和 \(t(i)\) 分别表示标记的垂直、水平和时间位置。需要注意的是,\(R^{t}\) 是可选的,主要是为了支持视频任务中的时间维度。相比之下,论文的位置分解将学习编码的数量减少到 \(\mathcal{O}(H+W+H)\) ,这对早期的高分辨率特征图有很大的帮助。

Residual pooling connection

正如MViTv1所描述的,池化注意力对于降低注意力块中的计算复杂度和内存消耗非常有效。 一般情况下,在 \(K\) 和 \(V\) 张量的缩放因子要大于 \(Q\) 张量的缩放因子,而且 \(Q\) 张量仅在跨阶段的分辨率发生变化时才被下采样。为了避免特征信息因为较大的缩放因子而丢失,论文添加了与池化的 \(Q\) 张量的残差池化连接,增加信息流并促进MViT中池化注意力块的训练。

如图 2 所示,论文在注意力块内引入了一个新的残差池化连接。具体来说,论文将池化查询张量添加到输出序列 \(Z\) 中:

\[Z:=\operatorname{Attn}\left(Q,K,V\right)+Q \quad\quad(5) \]

需要注意的是,输出序列 \(Z\) 与池化张量 \(Q\) 的长度相同。

这里的残差连接和池化操作都是必须的,在非跨阶段的块中也要补一个步幅为 1 的池化操作对 \(Q\) 进行处理。由于这种改进的计算增加很少,所以仍能维持池化注意力的低计算复杂度。

MViT for Object Detection

FPN integration

MViT的层次结构分四个阶段生成多尺度特征图,可以自然地对接到目标检测的特征金字塔网络(FPN)中,如图 3 所示。通过将FPNMViT主干结合使用,论文将其应用于不同的检测架构(例如Mask R-CNN)。

Hybrid window attention

Transformer中的自注意力对token数量具有二次方复杂度,这个问题在高分辨率输入的目标检测中更加严重。

池化注意力和窗口注意力都通过在计算自注意力时减少查询、键和值张量的大小来控制自注意力的复杂性,但它们的本质是不同的:

  • 池化注意力池通过局部聚合对特征进行下采样,但保持全局自注意力计算。
  • 窗口注意力保持张量的分辨率,将输入划分为非重叠窗口并在局部执行自注意力计算。

默认窗口注意力仅在窗口内执行本地自注意力,因此缺乏跨窗口的连接。与使用滑动窗口来缓解此问题的Swin不同,论文提出了一种简单的混合窗口注意力(Hwin)来添加跨窗口连接。

HwinFPN的最后三个阶段的所有块输入(最后阶段的最后一个块外)计算窗口注意力得到局部特征,最后一个块则进行全局自注意力(或池化注意力)计算使得FPN的输入特征图包含全局信息。

Positional embeddings in detection

与固定分辨率输入的ImageNet分类不同,目标检测的训练通常包含不同大小的输入。对MViT中的位置编码(绝对或相对),论文首先从ImageNet预训练权重初始化参数,对应 \(224\times224\) 输入的位置编码,然后将其插值到相应的大小以进行目标检测训练。

MViT for Video Recognition

由于升级的池化注意力可以泛化到时空域,MViT可以很简单地应用于视频识别任务(例如Kinetics)。虽然MViTv1仅关注Kinetics数据集上的预训练,MViTv2也会研究从ImageNet数据集上预训练的效果。

应用于图像的MViT与应用于视频的MViT对比有三个不同之处:

  • 1)主干中的映射层需要将输入投影到时空立方体而不是2D图像块。
  • 2)池化运算符需要池化时空特征图。
  • 3)相对位置编码加入时空位置。

由于 1) 和 2) 中的投影层和池化运算符默认由卷积层实现,因此论文使用inflation初始化,即将预训练模型中2D卷积层的权值除以T再复制T份形成3D卷积层(文章描述的做法跟参考文献有出入)。对于 3),论文基于提出的分离相对位置编码进行计算,简单地用预训练权值初始化空间位置编码,然后将时空位置编码初始化为零。

MViT Architecture Variants

如表 1 所示,论文构建了几个不同数量参数和计算量的MViT变体,以便与其他ViT模型进行比较。每个变体主要改变基础的通道尺寸、每个阶段的块数和块中的头数,一共有五个变体(TinySmallBaseLargeHuge)。需要注意,论文使用较少的头数量来改善运行时间,这对计算量和参数量是没有影响的。

遵循MViT的池化注意力设计,论文默认在所有池化注意力块中使用键和值池化,并且池化步幅在第一阶段设置为 4,并根据阶段的分辨率自适应地衰减跨阶段的步幅。

Experiments: Image Recognition


论文对ImageNet分类和COCO对象检测进行了实验。

Image Classification on ImageNet-1K

ImageNet-1K数据集上的性能如表 2 所示,论文改进的MViTv2具有更高的准确性,并且计算量和参数更少。

ImageNet-21K数据集上的性能如表 3 所示。

Object Detection on COCO

MViTv2作为不同检测框架的主干网络的性能如表 5 所示。默认采用Hwin注意力,窗口设置为 \([56, 28, 14, 7]\),使用ImageNet-1k预训练。

Ablations on ImageNet and COCO

不同注意力机制在图像分类和目标检测任务中的表现如表 4 所示。

不同位置编码方式在图像分类和目标检测任务中的表现如表 6 所示。

不同的残差池化策略在图像分类和目标检测任务中的表现如表 7 所示。

模型运行耗时在图像分类和目标检测任务中的表现如表 8 所示。

FPN对目标检测的影响如表 9 所示。

Experiments: Video Recognition


Main Results

在不同量级的Kinetics数据集上的视频任务性能对比。

SSv2数据集上的视频任务性能对比。

Ablations on Kinetics

Kinetics-400数据集上对比不同预训练模型的性能。

Conclusion


论文提出了改进的Multiscale Vision Transformer作为视觉识别的通用主干网络。通过实验表明,MViT在图像分类、目标检测、实例分割和视频识别等广泛使用的基准测试中达到了最很不错的精度。

如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】