End-to-End Object Detection with Transformers(论文解析)

End-to-End Object Detection with Transformers

    • 摘要
    • 介绍
    • 相关工作
      • [2.1 集合预测](#2.1 集合预测)
      • [2.2 transformer和并行解码](#2.2 transformer和并行解码)
      • [2.3 目标检测](#2.3 目标检测)
    • [3 DETR模型](#3 DETR模型)
      • [3.1 目标检测集设置预测损失](#3.1 目标检测集设置预测损失)
      • [3.2 DETR架构](#3.2 DETR架构)

摘要

我们提出了一种将目标检测视为直接集合预测问题的新方法。我们的方法简化了检测流程,有效地消除了许多手工设计的组件的需求,如显式编码我们关于任务的先验知识的非极大值抑制过程或锚点生成。新框架的主要要素,称为DEtection TRansformer或DETR,包括一个基于集合的全局损失,通过二分图匹配强制执行唯一的预测,以及一个Transformer编码器-解码器架构。给定一组固定的学习目标查询,DETR通过推理对象之间的关系和全局图像上下文,直接并行输出最终的预测。这个新模型在概念上很简单,不需要专门的库,与许多其他现代检测器不同。DETR在具有挑战性的COCO目标检测数据集上表现出与经过充分优化的Faster R-CNN基线相当的准确性和运行时性能。此外,DETR可以轻松推广为以统一的方式生成全景分割。我们展示它明显优于竞争基线。训练代码和预训练模型可在https://github.com/facebookresearch/detr获得。

介绍

目标检测的目标是为每个感兴趣的对象预测一组边界框和类别标签。现代检测器通过在大量的建议区域[37,5]、锚点[23]或窗口中心[53,46]上定义代理回归和分类问题,以间接的方式解决这个集合预测任务。它们的性能受到后处理步骤的显著影响,以折叠近似重复的预测,受到锚点集设计和将目标框分配给锚点的启发式方法的影响。为了简化这些流程,我们提出了一种直接的集合预测方法,绕过了代理任务。这种端到端的哲学已经在复杂的结构化预测任务中取得了重大进展,例如机器翻译或语音识别,但在目标检测领域尚未实现:以前的尝试[43,16,4,39]要么增加了其他形式的先验知识,要么在具有挑战性的基准测试上未能与强基线竞争。本文旨在弥合这一差距。

图1:DETR通过将常见的CNN与Transformer架构结合在一起,直接(并行地)预测最终的一组检测结果。在训练期间,二分图匹配将预测值唯一地分配给与真值框相匹配的情况。没有匹配的预测应该产生一个"无对象"(∅)的类别预测。

我们通过将目标检测视为直接的集合预测问题,简化了训练流程。我们采用了基于变换器(transformers)[47]的编码器-解码器架构,这是一种用于序列预测的流行架构。变换器的自注意机制明确地模拟了序列中所有元素之间的两两交互作用,使这些架构特别适用于集合预测的特定约束,如去除重复的预测。

我们的DEtection TRansformer(DETR,见图1)一次性预测所有对象,并通过一个集合损失函数进行端到端训练,该函数在预测对象和真值对象之间执行二分图匹配。DETR通过删除多个手工设计的组件,如空间锚点或非极大值抑制,来简化检测流程,这些组件用于编码先验知识。与大多数现有的检测方法不同,DETR不需要任何定制的层,因此可以在包含标准CNN和transformer类的任何框架中轻松复现。

与大多数以前的直接集合预测工作相比,DETR的主要特点是二分图匹配损失和具有(非自回归)并行解码的transformer结合[29,12,10,8]。相比之下,以前的工作侧重于使用RNN进行自回归解码[43,41,30,36,42]。我们的匹配损失函数将预测与真值对象唯一地分配给一个对象,并且不受预测对象排列的影响,因此我们可以并行地发出它们。

DETR的训练设置在多个方面与标准目标检测器不同。新模型需要更长的训练周期,并受益于变换器中辅助解码损失。我们深入探讨了哪些组件对所展示的性能至关重要,

DETR的设计理念很容易扩展到更复杂的任务。在我们的实验中,我们展示了在预训练的DETR之上训练的简单分割头在全景分割(Panoptic Segmentation)[19]上优于竞争基线的结果。全景分割是一项具有挑战性的像素级识别任务,最近变得越来越受欢迎。

相关工作

2.1 集合预测

目前没有一个经典的深度学习模型可以直接预测集合。基本的集合预测任务是多标签分类(请参见例如[40,33]中的参考文献,涉及计算机视觉领域),对于这种任务,基线方法------一对多(one-vs-rest)方法在检测等存在元素之间存在一定结构的问题中不适用(即,存在近似相同的边界框)。在这些任务中的第一个困难是避免近似重复。大多数当前的检测器使用后处理方法,如非极大值抑制,来解决这个问题,但直接集合预测是无需后处理的。它们需要全局推理方案,以模拟所有预测元素之间的交互,以避免冗余。对于固定大小的集合预测,密集全连接网络[9]足够,但成本较高。一种通用的方法是使用自回归序列模型,如循环神经网络[48]。在所有情况下,损失函数应该对预测的排列具有不变性。通常的解决方案是设计基于匈牙利算法[20]的损失函数,以找到真值和预测之间的二分图匹配。这强制执行排列不变性,并保证每个目标元素都有一个唯一的匹配。我们采用了二分图匹配损失方法。然而,与大多数以前的工作不同,我们放弃了自回归模型,而是使用了具有并行解码的变换器,我们将在下面进行描述。

2.2 transformer和并行解码

变换器(Transformers)是由Vaswani等人[47]引入的,作为一种新的基于注意力的机器翻译构建块。注意力机制[2]是神经网络层,可以从整个输入序列中汇总信息。变换器引入了自注意层,类似于非局部神经网络[49],它们会扫描序列中的每个元素,并通过汇总整个序列的信息来更新它。注意力模型的主要优势之一是其全局计算和完美记忆,这使它们在处理长序列时比循环神经网络更适用。在自然语言处理、语音处理和计算机视觉等领域,变换器现在正在取代循环神经网络,应用广泛[8,27,45,34,31]。

transformer首先用于自回归模型,遵循了早期的序列到序列模型[44],逐个生成输出标记。然而,由于推断成本过高(与输出长度成正比,难以批量处理),这导致了并行序列生成的发展,在音频[29]、机器翻译[12,10]、单词表示学习[8]以及更近期的语音识别[6]等领域进行了研究。我们还结合了transformer和并行解码,以在计算成本和执行集合预测所需的全局计算之间找到适当的折衷方案。

2.3 目标检测

大多数现代目标检测方法都相对于一些初始猜测进行预测。两阶段检测器[37,5]根据建议(proposals)预测边界框,而单阶段方法则根据锚点[23]或可能的物体中心网格[53,46]进行预测。最近的研究[52]表明,这些系统的最终性能在初始猜测的确切设置方式上具有很大的依赖性。在我们的模型中,我们能够通过直接预测与输入图像而不是锚点相关的一组检测结果,消除了这个手工制作的过程,并简化了检测过程。

基于集合的损失。一些目标检测器[9,25,35]使用了二分图匹配损失。然而,在这些早期的深度学习模型中,不同预测之间的关系仅使用卷积或全连接层来建模,而手动设计的非极大值抑制后处理可以提高它们的性能。更近期的检测器[37,23,53]在真值和预测之间使用了非唯一的分配规则,同时使用了非极大值抑制。

可学习的非极大值抑制方法[16,4]和关系网络[17]使用注意力明确建模了不同预测之间的关系。使用直接的集合损失,它们不需要任何后处理步骤。然而,这些方法使用额外的手工设计的上下文特征,如建议框坐标,以有效地建模检测之间的关系,而我们寻找减少模型中编码的先验知识的解决方案。

递归检测器。与我们的方法最接近的是用于目标检测[43]和实例分割[41,30,36,42]的端到端集合预测。与我们类似,它们使用基于CNN激活的编码器-解码器架构,使用二分图匹配损失直接生成一组边界框。然而,这些方法仅在小型数据集上进行了评估,而没有与现代基线模型进行比较。特别地,它们基于自回归模型(更精确地说是RNN),因此它们没有利用最近的具有并行解码的变换器模型。

3 DETR模型

在检测中进行直接集合预测需要两个关键因素:(1) 一种集合预测损失,它强制预测的边界框与真值边界框之间具有唯一匹配;(2) 一种体系结构,可以在单次传递中预测一组对象并建模它们之间的关系。我们在图2中详细描述了我们的体系结构。

3.1 目标检测集设置预测损失

DETR通过解码器单次推断出一个固定大小的N个预测,其中N被设置为明显大于图像中典型对象的数量。训练的主要困难之一是如何根据真值对预测的对象(类别、位置、大小)进行评分。我们的损失函数产生了预测对象和真值对象之间的最优二分图匹配,然后优化特定于对象的(边界框)损失。

让我们用y表示真值对象的集合,而ˆy = {ˆyi}N i=1表示N个预测的集合。假设N大于图像中的对象数量,我们也将y视为大小为N的集合,其中包括∅(表示没有对象的占位符)。为了在这两个集合之间找到一个二分图匹配,我们搜索一个具有最低成本的N个元素的排列σ ∈ SN:

其中Lmatch(yi, ˆyσ(i))是真值yi和索引σ(i)的预测之间的成本。这个最优分配是通过匈牙利算法高效计算的,这是根据之前的工作(例如[43])完成的。

匹配成本考虑了类别预测和预测框与真值框的相似性。真值集合的每个元素i可以看作是yi = (ci, bi),其中ci是目标类别标签(可能为∅),bi ∈ [0, 1]4是一个向量,定义了真值框的中心坐标以及相对于图像大小的高度和宽度。对于具有索引σ(i)的预测,我们将类别ci的概率定义为ˆpσ(i)(ci),并将预测框定义为ˆbσ(i)。使用这些符号,我们将Lmatch(yi, ˆyσ(i))定义为-1{ci=∅}ˆpσ(i)(ci) + 1{ci=∅}Lbox(bi, ˆbσ(i))。其中,1{ci=∅}是指示函数,如果ci不等于∅则为1,否则为0。这个成本函数综合考虑了类别匹配和框匹配。

这种找到匹配的过程在直接集合预测中起到了与现代检测器中用于将提议[37]或锚[22]与真值对象匹配的启发式分配规则相同的作用。主要区别在于,我们需要为直接集合预测找到不包含重复的一对一匹配。

第二步是计算损失函数,即在前一步中匹配的所有成对的匈牙利损失。我们将损失定义为类似于常见目标检测器的损失,即类别预测的负对数似然和稍后定义的框损失的线性组合:

其中ˆσ是第一步中计算的最优分配(1)。在实践中,当ci = ∅时,我们通过10倍的因子减小对数概率项的权重,以考虑类别不平衡。这类似于Faster R-CNN训练过程通过子采样平衡正样本/负样本提议[37]的方法。请注意,对象与∅之间的匹配成本不依赖于预测,这意味着在这种情况下成本是常数。在匹配成本中,我们使用概率ˆpˆσ(i)(ci)而不是对数概率。这使得类别预测项与Lbox(·, ·)(下文描述)具有可比性,并且我们观察到了更好的经验性能。

边界框损失。匹配成本和匈牙利损失的第二部分是Lbox(·),用于评分边界框。与许多检测器不同,它们根据与一些初始猜测的∆进行边界框预测,我们直接进行边界框预测。尽管这种方法简化了实现,但它在损失的相对缩放方面存在问题。最常用的1损失即使相对误差相似,对小框和大框也有不同的尺度。为了减轻这个问题,我们使用了1损失和广义IoU损失[38]的线性组合Liou(·, ·),这是尺度不变的。总的来说,我们的框损失是Lbox(bi, ˆbσ(i)),定义如下:

λiouLiou(bi, ˆbσ(i)) + λL1||bi − ˆbσ(i)||1,

其中λiou、λL1 ∈ R是超参数。这两个损失都被批次中的对象数量归一化。

3.2 DETR架构

DETR的总体架构出奇地简单,如图2所示。它包含三个主要组件,我们将在下面描述:一个CNN骨干网络用于提取紧凑的特征表示,一个编码器-解码器Transformer,以及一个简单的前馈网络(FFN)用于进行最终的检测预测。

与许多现代检测器不同,DETR可以在任何提供通用CNN骨干网络和Transformer架构实现的深度学习框架中实现,只需几百行代码。在PyTorch [32]中,可以使用不到50行代码实现DETR的推理代码。我们希望我们的方法的简单性能够吸引新的研究人员加入检测领域。

骨干网络。从初始图像ximg ∈ R3×H0×W0(具有3个颜色通道)开始,传统的CNN骨干网络会生成一个低分辨率的激活图f ∈ RC×H×W。我们通常使用的典型值为C = 2048和H,W = H0 32,W0 32。

Transformer编码器 首先,通过1x1卷积将高级别激活图f的通道维度从C减小到较小的维度d,创建一个新的特征图z0 ∈ Rd×H×W。编码器期望以序列作为输入,因此我们将z0的空间维度折叠成一个维度,得到一个d×HW的特征图。每个编码器层都具有标准的体系结构,包括多头自注意力模块和前馈网络(FFN)。由于Transformer架构是排列不变的,我们通过固定的位置编码[31,3]来补充它,这些编码被添加到每个注意层的输入中。我们将详细的架构定义放在了补充材料中,它遵循了[47]中描述的架构。
Transformer解码器。解码器遵循Transformer的标准架构,使用多头自注意力机制和编码器-解码器注意力机制来转换大小为d的N个嵌入。与原始的Transformer不同的是,我们的模型在每个解码器层上并行解码N个对象,而Vaswani等人[47]使用自回归模型,逐个元素地预测输出序列。不熟悉这些概念的读者可以参考补充材料。由于解码器也是排列不变的,因此N个输入嵌入必须不同以产生不同的结果。这些输入嵌入是学习的位置编码,我们称之为对象查询,与编码器类似,我们将它们添加到每个注意力层的输入中。N个对象查询通过解码器转换为输出嵌入。然后,它们通过前馈网络(在下一小节中描述)独立解码为边界框坐标和类标签,生成N个最终的预测。使用这些嵌入上的自注意力和编码器-解码器注意力,模型通过它们之间的成对关系全局推理所有对象,同时能够使用整个图像作为上下文。

预测前馈网络(FFNs) 。最终的预测由一个包含ReLU激活函数和隐藏维度d的3层感知器以及一个线性投影层计算。FFN预测了相对于输入图像的标准化中心坐标、高度和宽度,并且线性层使用softmax函数来预测类别标签。由于我们预测一个固定大小的N个边界框,其中N通常远大于图像中感兴趣的实际对象数量,因此额外的特殊类别标签∅ 用于表示某个槽内没有检测到对象。这个类别在标准目标检测方法中起着类似于"背景"类别的作用。
辅助解码损失。我们发现,在训练过程中使用辅助损失[1]特别有帮助,尤其是帮助模型输出每个类别的正确数量的对象。我们在每个解码器层之后添加了预测的前馈网络(FFNs)和匈牙利损失。所有预测的FFNs共享它们的参数。我们使用一个额外的共享层归一化来规范来自不同解码器层的预测FFNs的输入。

相关推荐
GocNeverGiveUp2 分钟前
机器学习2-NumPy
人工智能·机器学习·numpy
B站计算机毕业设计超人1 小时前
计算机毕业设计PySpark+Hadoop中国城市交通分析与预测 Python交通预测 Python交通可视化 客流量预测 交通大数据 机器学习 深度学习
大数据·人工智能·爬虫·python·机器学习·课程设计·数据可视化
学术头条1 小时前
清华、智谱团队:探索 RLHF 的 scaling laws
人工智能·深度学习·算法·机器学习·语言模型·计算语言学
18号房客1 小时前
一个简单的机器学习实战例程,使用Scikit-Learn库来完成一个常见的分类任务——**鸢尾花数据集(Iris Dataset)**的分类
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·sklearn
feifeikon1 小时前
机器学习DAY3 : 线性回归与最小二乘法与sklearn实现 (线性回归完)
人工智能·机器学习·线性回归
游客5201 小时前
opencv中的常用的100个API
图像处理·人工智能·python·opencv·计算机视觉
古希腊掌管学习的神1 小时前
[机器学习]sklearn入门指南(2)
人工智能·机器学习·sklearn
凡人的AI工具箱2 小时前
每天40分玩转Django:Django国际化
数据库·人工智能·后端·python·django·sqlite
咸鱼桨2 小时前
《庐山派从入门到...》PWM板载蜂鸣器
人工智能·windows·python·k230·庐山派
强哥之神2 小时前
Nexa AI发布OmniAudio-2.6B:一款快速的音频语言模型,专为边缘部署设计
人工智能·深度学习·机器学习·语言模型·自然语言处理·音视频·openai