Anchor DETR:Transformer-Based目标检测的Query设计

写在前面

文中指出之前DETR-like算法存在以下问题:

  • 之前DETR-liked检测算法里,object query是一组可学习的嵌入表示(就是一组256-d的向量),缺乏明确的物理意义,不能解释它们会关注什么地方。
  • 每个object query 预测的位置没有一个特定的模式(specific mode),即每个object query不会关注特定的区域。

PS:第二点所谓"预测位置没有一个特定模式"这个结论是怎么得出来的呢?作者援引了DETR论文中的一幅图像(如上图所示)进行说明。该图像中每个子图上都有很多点,每个子图代表了一个object query在验证集所有图像上得到的预测框的中心点坐标(经过归一化后的),绿色代表小的预测框,红色代表水平方向比较大的预测框,蓝色代表垂直方向比较大的预测框。通过上图可知,即使同一个object query,在不同图像上得到的预测框其位置和大小都是不固定的,所以说没有特定模式,而这使得object query难以优化。


为解决上述问题:

  • 本文基于anchor point(在CNN-based检测算法中被广泛使用)设计object query,每个object query关注anchor point附近的目标;
  • 本文object query的设计可以预测一个位置的多个目标;
  • 设计了一种注意力变体,减少显存占用。

论文的贡献或方法都可以转化成相应的问题,然后从文中逐一寻找答案,寻找答案的过程也是理解论文的过程,现在我们可以提出以下问题:

  • anchor point怎么来的?
  • 如何基于anchor point设计object query?
  • 为什么本文object query的设计可以预测一个位置的多个目标?
  • 注意力变体是怎么样的,为什么可以减少显存占用?

在阅读论文时带着问题,有目的的阅读,边阅读边思考,通常效果会好很多,也更容易理解作者想表达的意思。

接下来让我们从文中method部分寻找问题的答案。

一、Method

1. Anchor Points

Q1:anchor point怎么来的?

A1:如下图所示,文中采用两种方式获得anchor point。一种是网格均匀采样,anchor point被固定图像中均匀的网格点;为另一种是可学习的点,这些点根据满足0~1均匀分布随机初始化并作为可学习参数进行更新。

有了anchor point,就可以把回归头的输出当作对于anchor point的偏移量(参考了Deformable DETR的做法),将预测框中心点坐标加到对应的anchor point上。

对Deformable DETR不了解的朋友可以查看我的博客:Deformable DETR:结合多尺度特征、可变形卷积机制的DETR

作者通过对比实验(如下图所示),采用了可学习的anchor point(但综合看起来两者好像差别不显著= = 、)。

2. Attention Formulation

在回答第二个问题之前,我们首先需要了解一下论文中的一些符号表示。论文在该部分讲解了DETR-like检测算法中的注意力机制的建模方式(比较容易理解,不过多赘述),其中涉及的一些符号表示对我们理解文章的后续内容是有帮助的。

注意力机制建模方式如下:

其中表示维度,表示内容信息,表示位置信息。

decoder中包含自注意力交叉注意力

自注意力中是相同的,是相同的,表示decoder前一层的输出,对于decoder第一层而言可以设置为常数向量,也可以设置为可学习的嵌入表示。query位置部分在DETR中通常用一组可学习的嵌入向量表示,其中表示query的数量:

交叉注意力的讲解略过,不难理解。

接下来我们可以继续寻找下一个问题的答案。

3. Anchor Points to Object Query

Q2:如何基于anchor point设计object query?

A2:anchor point可表示为,其中表示点的个数。根据anchor point获得object query只需要确定一种编码方式即可,即。一种很自然的想法就是继续使用位置编码函数进行编码,但作者采用了一个额外的MLP网络对位置编码结果进行微调。

PS:为什么要额外添加一个MLP微调位置编码结果?文中没有进行相应的消融实验,原因未知。

4. Multiple Predictions for Each Anchor Point

Q3:既然作者的想法是说,通过anchor point得到object query,使得每个object query能够关注某个特定的区域。那如果一个位置有多个目标,但这个位置只有一个object query关注这里,即只会有一个预测框,那怎么办?

A3:简单来说,就是让这个地方可以有多个预测框。作者重新回顾了decoder第一层query的内容部分,每个object query只有一种模式(pattern ),即。为了使得一个anchor point可以预测多个目标,作者将多模式嵌入(multiple pattern embedding)整合到了每个object query中,以适应一个位置存在多个目标的情况。其中多模式嵌入表示为:

其中表示模式的数量(文中)。

PS:如何理解pattern呢? 我个人理解这里的pattern主要指的是预测框的位置和大小。通过增加pattern的数量,可以增加在某个位置预测框的数量,进而实现一个位置多个目标的检测。但具体如何将多个pattern整合到一个object query中,还没太理解。

文中还提到,由于平移不变性,所有object query都共享这些模式(个人理解写在下面)。因此进一步可以得到,即object query的数量

PS:所谓平移不变性是什么意思呢? 举个例子,对于一个检测模型来说,无论目标是在图像中间还是边缘,都应该检测到目标。而图像中每个位置都有可能出现多个目标的情况,所以所有object query应该共享这些模式。

模型预测框可视化结果如下图所示:

每一列表示一个object query在所有图像中预测框的中心点分布情况,其中最后一行的黑点表示anchor point,前三行表示每个pattern对应预测框中心点的分布情况,明显可以看出预测框都是在anchor point附近。

再结合代码进行理解:

python 复制代码
# transformer的init方法中初始化模式嵌入,维度为:3,256
self.pattern = nn.Embedding(self.num_pattern, d_model)

# transformer的forward方法中,为每个object query分配3个模式嵌入
# 调整维度为(bs, 3*300, 256)
tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, c).repeat(bs, 1, self.num_position, 1).reshape(
            bs, self.num_pattern * self.num_position, c)

5. Row-Column Decoupled Attention(RCDA)

作者先说明了现有注意力机制的一些缺点:

  • transformer架构计算量较大,会占用较多的显存。
  • Deformable DETR虽然能降低显存,但会导致内存的随机访问,对硬件不友好(好吧,不懂硬件,说啥是啥)。
  • 其他注意力变体作者实验发现不适用于DETR-like的检测器。

所以,作者提出了一种新的注意力机制变体------行列解耦注意力,以降低显存要求,同时能媲美甚至超越DETR中标准的注意力机制。

大致思路跟深度可分离卷积好像差不多,就是对x和y分别进行计算,最后整合起来。具体的没仔细看,算法复杂度、降低内存啥的这类内容本能的排斥(主要是太菜了看不懂)。主要模型相关的内容已经介绍完了,后续有机会再把这部分内容补上吧。

二、实验结果

文中实验结果都比较好理解,后续补充对实验结果的个人思考。

三、总结

最后做个总结(也是回顾),Anchor DETR的主要贡献是:

  • 根据anchor point得到object query,使其具有明确物理意义,每个object query关注特定区域;
  • 针对第一点可能面临的"一个区域多个目标"的挑战,进一步将多个pattern整合到了一个object query,可实现一个位置多个目标的检测;
  • 提出行列解耦注意力机制,在降低显存使用的同时,性能可媲美甚至超过标准注意力机制。

上述改进使得模型收敛速度提高了10倍,性能也有较为显著提升。

相关推荐
MJ绘画中文版1 分钟前
灵动AI:艺术与科技的融合
人工智能·ai·ai视频
zyhomepage7 分钟前
科技的成就(六十四)
开发语言·人工智能·科技·算法·内容运营
百流10 分钟前
Pyspark中pyspark.sql.functions常用方法(4)
1024程序员节
qq210846295314 分钟前
【Ubuntu】Ubuntu22双网卡指定网关
1024程序员节
YueTann31 分钟前
APS开源源码解读: 排程工具 optaplanner II
1024程序员节
kinlon.liu40 分钟前
安全日志记录的重要性
服务器·网络·安全·安全架构·1024程序员节
挽安学长41 分钟前
油猴脚本-GPT问题导航侧边栏增强版
人工智能·chatgpt
爱编程— 的小李44 分钟前
开关灯问题(c语言)
c语言·算法·1024程序员节
戴着眼镜看不清1 小时前
国内对接使用GPT解决方案——API中转
人工智能·gpt·claude·通义千问·api中转
YRr YRr1 小时前
深度学习:正则化(Regularization)详细解释
人工智能·深度学习