Anchor DETR:Transformer-Based目标检测的Query设计

写在前面

文中指出之前DETR-like算法存在以下问题:

  • 之前DETR-liked检测算法里,object query是一组可学习的嵌入表示(就是一组256-d的向量),缺乏明确的物理意义,不能解释它们会关注什么地方。
  • 每个object query 预测的位置没有一个特定的模式(specific mode),即每个object query不会关注特定的区域。

PS:第二点所谓"预测位置没有一个特定模式"这个结论是怎么得出来的呢?作者援引了DETR论文中的一幅图像(如上图所示)进行说明。该图像中每个子图上都有很多点,每个子图代表了一个object query在验证集所有图像上得到的预测框的中心点坐标(经过归一化后的),绿色代表小的预测框,红色代表水平方向比较大的预测框,蓝色代表垂直方向比较大的预测框。通过上图可知,即使同一个object query,在不同图像上得到的预测框其位置和大小都是不固定的,所以说没有特定模式,而这使得object query难以优化。


为解决上述问题:

  • 本文基于anchor point(在CNN-based检测算法中被广泛使用)设计object query,每个object query关注anchor point附近的目标;
  • 本文object query的设计可以预测一个位置的多个目标;
  • 设计了一种注意力变体,减少显存占用。

论文的贡献或方法都可以转化成相应的问题,然后从文中逐一寻找答案,寻找答案的过程也是理解论文的过程,现在我们可以提出以下问题:

  • anchor point怎么来的?
  • 如何基于anchor point设计object query?
  • 为什么本文object query的设计可以预测一个位置的多个目标?
  • 注意力变体是怎么样的,为什么可以减少显存占用?

在阅读论文时带着问题,有目的的阅读,边阅读边思考,通常效果会好很多,也更容易理解作者想表达的意思。

接下来让我们从文中method部分寻找问题的答案。

一、Method

1. Anchor Points

Q1:anchor point怎么来的?

A1:如下图所示,文中采用两种方式获得anchor point。一种是网格均匀采样,anchor point被固定图像中均匀的网格点;为另一种是可学习的点,这些点根据满足0~1均匀分布随机初始化并作为可学习参数进行更新。

有了anchor point,就可以把回归头的输出当作对于anchor point的偏移量(参考了Deformable DETR的做法),将预测框中心点坐标加到对应的anchor point上。

对Deformable DETR不了解的朋友可以查看我的博客:Deformable DETR:结合多尺度特征、可变形卷积机制的DETR

作者通过对比实验(如下图所示),采用了可学习的anchor point(但综合看起来两者好像差别不显著= = 、)。

2. Attention Formulation

在回答第二个问题之前,我们首先需要了解一下论文中的一些符号表示。论文在该部分讲解了DETR-like检测算法中的注意力机制的建模方式(比较容易理解,不过多赘述),其中涉及的一些符号表示对我们理解文章的后续内容是有帮助的。

注意力机制建模方式如下:

其中表示维度,表示内容信息,表示位置信息。

decoder中包含自注意力交叉注意力

自注意力中是相同的,是相同的,表示decoder前一层的输出,对于decoder第一层而言可以设置为常数向量,也可以设置为可学习的嵌入表示。query位置部分在DETR中通常用一组可学习的嵌入向量表示,其中表示query的数量:

交叉注意力的讲解略过,不难理解。

接下来我们可以继续寻找下一个问题的答案。

3. Anchor Points to Object Query

Q2:如何基于anchor point设计object query?

A2:anchor point可表示为,其中表示点的个数。根据anchor point获得object query只需要确定一种编码方式即可,即。一种很自然的想法就是继续使用位置编码函数进行编码,但作者采用了一个额外的MLP网络对位置编码结果进行微调。

PS:为什么要额外添加一个MLP微调位置编码结果?文中没有进行相应的消融实验,原因未知。

4. Multiple Predictions for Each Anchor Point

Q3:既然作者的想法是说,通过anchor point得到object query,使得每个object query能够关注某个特定的区域。那如果一个位置有多个目标,但这个位置只有一个object query关注这里,即只会有一个预测框,那怎么办?

A3:简单来说,就是让这个地方可以有多个预测框。作者重新回顾了decoder第一层query的内容部分,每个object query只有一种模式(pattern ),即。为了使得一个anchor point可以预测多个目标,作者将多模式嵌入(multiple pattern embedding)整合到了每个object query中,以适应一个位置存在多个目标的情况。其中多模式嵌入表示为:

其中表示模式的数量(文中)。

PS:如何理解pattern呢? 我个人理解这里的pattern主要指的是预测框的位置和大小。通过增加pattern的数量,可以增加在某个位置预测框的数量,进而实现一个位置多个目标的检测。但具体如何将多个pattern整合到一个object query中,还没太理解。

文中还提到,由于平移不变性,所有object query都共享这些模式(个人理解写在下面)。因此进一步可以得到,即object query的数量

PS:所谓平移不变性是什么意思呢? 举个例子,对于一个检测模型来说,无论目标是在图像中间还是边缘,都应该检测到目标。而图像中每个位置都有可能出现多个目标的情况,所以所有object query应该共享这些模式。

模型预测框可视化结果如下图所示:

每一列表示一个object query在所有图像中预测框的中心点分布情况,其中最后一行的黑点表示anchor point,前三行表示每个pattern对应预测框中心点的分布情况,明显可以看出预测框都是在anchor point附近。

再结合代码进行理解:

python 复制代码
# transformer的init方法中初始化模式嵌入,维度为:3,256
self.pattern = nn.Embedding(self.num_pattern, d_model)

# transformer的forward方法中,为每个object query分配3个模式嵌入
# 调整维度为(bs, 3*300, 256)
tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, c).repeat(bs, 1, self.num_position, 1).reshape(
            bs, self.num_pattern * self.num_position, c)

5. Row-Column Decoupled Attention(RCDA)

作者先说明了现有注意力机制的一些缺点:

  • transformer架构计算量较大,会占用较多的显存。
  • Deformable DETR虽然能降低显存,但会导致内存的随机访问,对硬件不友好(好吧,不懂硬件,说啥是啥)。
  • 其他注意力变体作者实验发现不适用于DETR-like的检测器。

所以,作者提出了一种新的注意力机制变体------行列解耦注意力,以降低显存要求,同时能媲美甚至超越DETR中标准的注意力机制。

大致思路跟深度可分离卷积好像差不多,就是对x和y分别进行计算,最后整合起来。具体的没仔细看,算法复杂度、降低内存啥的这类内容本能的排斥(主要是太菜了看不懂)。主要模型相关的内容已经介绍完了,后续有机会再把这部分内容补上吧。

二、实验结果

文中实验结果都比较好理解,后续补充对实验结果的个人思考。

三、总结

最后做个总结(也是回顾),Anchor DETR的主要贡献是:

  • 根据anchor point得到object query,使其具有明确物理意义,每个object query关注特定区域;
  • 针对第一点可能面临的"一个区域多个目标"的挑战,进一步将多个pattern整合到了一个object query,可实现一个位置多个目标的检测;
  • 提出行列解耦注意力机制,在降低显存使用的同时,性能可媲美甚至超过标准注意力机制。

上述改进使得模型收敛速度提高了10倍,性能也有较为显著提升。

相关推荐
EQUINOX128 分钟前
3b1b线性代数基础
人工智能·线性代数·机器学习
Kacey Huang1 小时前
YOLOv1、YOLOv2、YOLOv3目标检测算法原理与实战第十三天|YOLOv3实战、安装Typora
人工智能·算法·yolo·目标检测·计算机视觉
加德霍克1 小时前
【机器学习】使用scikit-learn中的KNN包实现对鸢尾花数据集或者自定义数据集的的预测
人工智能·python·学习·机器学习·作业
漂亮_大男孩1 小时前
深度学习|表示学习|卷积神经网络|局部链接是什么?|06
深度学习·学习·cnn
Light Gao1 小时前
AI赋能未来:Agent能力与AI中间件平台对行业的深远影响
人工智能·ai·中间件·大模型
骇客野人1 小时前
【人工智能】循环神经网络学习
人工智能·rnn·学习
eguid_11 小时前
JavaScript图像处理,常用图像边缘检测算法简单介绍说明
javascript·图像处理·算法·计算机视觉
lly_csdn1232 小时前
【Image Captioning】DynRefer
python·深度学习·ai·图像分类·多模态·字幕生成·属性识别
速融云3 小时前
汽车制造行业案例 | 发动机在制造品管理全解析(附解决方案模板)
大数据·人工智能·自动化·汽车·制造
AI明说3 小时前
什么是稀疏 MoE?Doubao-1.5-pro 如何以少胜多?
人工智能·大模型·moe·豆包