TOOD Task-aligned One-stage Object Detection 论文学习

1. 解决了什么问题?

目标检测通过多任务学习的方式,协同优化目标的分类和定位。分类任务会学习目标的判别特征,关注于目标的显著性或关键区域,而定位任务则学习准确地定位目标的边界。因为定位和分类的学习机制不同,这俩任务学到的特征分布也不同。当这两个分支做预测时,会产生一定的错位现象。

  • 分类和定位是独立的。两个独立的分支并行地做目标分类和定位,任务之间缺乏交流,会造成预测结果不一致。如下图红色格子,ATSS 识别的是餐桌,但定位的是披萨饼。
  • Task-agnostic 样本分配。分类和定位的最佳 anchors 通常是不一致的,根据目标的形状和特性差异可能很大。常用的样本分配方法都是 task-agnostic,因此这俩任务很难作出准确而一致的预测。如下图中的绿色格子代表了最佳的定位 anchor,它不是目标的中心点,而且与分类的最佳 anchor(红色格子)没有对齐。这会造成 NMS 时,一个准确定位的边框被抑制掉。

如下图,上面一行是 ATSS 预测的分类得分和定位得分的空间分布,下面一行是 TOOD 预测的分类得分和定位得分的空间分布。黄框是 ground truth,红色格子是分类的最佳 anchor,绿色格子是定位的最佳 anchor,若二者重叠则只显示红色格子。红框和绿框分别是红色格子和绿色格子的 anchor 预测出的边框。白色箭头表示最佳 anchor 偏离目标中心点的主要方向。

2. 提出了什么方法?

TOOD 学习对齐这两个任务。作者首先设计了一个 task-aligned head (T-Head),增强两个任务间的交流,在学习 task-interactive 和 task-specific 特征方面取得更好的平衡,更准确地对齐预测结果。其次,提出了 task alignment learning (TAL),通过样本分配机制和 task-aligned loss 训练,拉近两个任务各自最优的 anchors 的距离。

2.1 Overview

TOOD 的整体流程采用了"主干-FPN-Head"的结构。TOOD 在每个位置使用一个 anchor,对于 anchor-free 检测器它就是一个 anchor point,对于 anchor-based 检测器它就是一个 anchor box。T-head 首先基于 FPN 特征预测分类和定位。然后 TAL 基于新提出的 task alignment metric 计算任务对齐信号,表示分类和定位预测的对齐程度。最后根据 TAL 反向传播的学习信号, T-head 自动地调节各类别的概率和边框位置。对得最齐的 anchor 会得到更高的分类得分,并学习偏移量使预测框更加准确。

2.2 Task-aligned Head

为了使 head 更加高效,作者从两个方面出发:

  • 增加两个任务间的交流;
  • 增强检测器学习对齐的能力。

T-head 如下图 (b) 所示,包括一个简单的特征提取器和两个 task-aligned predictors (TAP)。

为了增强分类和定位任务间的交流,使用一个特征提取器(多个卷积层)学习一组 task-interactive 特征,即上图(b) 蓝色部分。该设计不仅能提高任务间的交流,也能为这俩任务提供具有多尺度感受野的多层级特征。 X f p n ∈ R H × W × C X^{fpn}\in \mathbb{R}^{H\times W\times C} Xfpn∈RH×W×C代表 FPN 特征,其中 H , W , C H,W,C H,W,C分别是 FPN 特征的高度、宽度和通道数。特征提取器使用 N N N个卷积层和激活函数来计算 task-interactive 特征:
X k i n t e r = { δ ( conv k ( X f p n ) ) , k = 1 δ ( conv k ( X k − 1 i n t e r ) ) , k > 1 ∀ k ∈ { 1 , 2 , . . . , N } X_k^{inter}=\left\{ \begin{array}{ll} \delta(\text{conv}_k(X^{fpn})), \quad k=1 \\ \delta(\text{conv}k(X^{inter}{k-1})),\quad k>1 \end{array} \right. \forall k\in \lbrace 1,2,...,N\rbrace Xkinter={δ(convk(Xfpn)),k=1δ(convk(Xk−1inter)),k>1∀k∈{1,2,...,N}

其中 conv k \text{conv}_k convk和 δ \delta δ表示第 k k k个卷积层和 relu \text{relu} relu函数。在 head 里使用单分支结构提取丰富的多尺度特征,然后计算得到的 task-interactive 特征会输入进两个 TAP,学习对齐分类和定位。

2.2.1 TAP

在 task-interactive 特征上进行分类和定位,这两个任务能互相感知到对方的状态。但由于是单分支设计,task-interactive 特征难免会因任务不同而存在一定的特征冲突。作者提出了一个层注意力(layer attention)机制,动态地计算 task-specific 特征,使任务解耦。下图展示了 TAP,分别计算分类或定位的 task-specific 特征:

X k t a s k = w k ⋅ X k i n t e r , ∀ k ∈ { 1 , 2 , . . . , N } X_k^{task}=\boldsymbol{w}_k \cdot X_k^{inter},\forall k\in \lbrace1,2,...,N\rbrace Xktask=wk⋅Xkinter,∀k∈{1,2,...,N}

w k \boldsymbol{w}_k wk是学习得到的层注意力 w ∈ R N \boldsymbol{w}\in \mathbb{R}^N w∈RN的第 k k k个元素。 w \boldsymbol{w} w是从 task-interactive 特征计算而来的,能够获取 X i n t e r X^{inter} Xinter不同 layer 的依赖关系:

w = σ ( f c 2 ( δ ( f c 1 ( x i n t e r ) ) ) ) \boldsymbol{w}=\sigma(fc_2(\delta(fc_1(\boldsymbol{x}^{inter})))) w=σ(fc2(δ(fc1(xinter))))

f c 1 , f c 2 fc_1, fc_2 fc1,fc2代表两个全连接层, σ \sigma σ是 sigmoid \text{sigmoid} sigmoid函数, δ \delta δ是 relu \text{relu} relu函数。将 X k i n t e r X_k^{inter} Xkinterconcat 起来得到 X i n t e r X^{inter} Xinter,然后使用全局平均池化得到 x i n t e r \boldsymbol{x}^{inter} xinter。最后,分类或定位的结果由每个 X t a s k X^{task} Xtask预测得到:
Z t a s k = conv 2 ( δ ( conv 1 ( X t a s k ) ) ) Z^{task}=\text{conv}_2(\delta(\text{conv}_1(X^{task}))) Ztask=conv2(δ(conv1(Xtask)))

其中 X t a s k X^{task} Xtask是将 X k t a s k X_k^{task} Xktask特征 concat 起来, conv 1 \text{conv}_1 conv1是 1 × 1 1\times 1 1×1卷积,用于降维。 Z t a s k Z^{task} Ztask然后使用 sigmoid \text{sigmoid} sigmoid函数转换为分类得分 P ∈ R H × W × 80 P\in \mathbb{R}^{H\times W\times 80} P∈RH×W×80,或者用 distance-to-bbox \text{distance-to-bbox} distance-to-bbox转换为目标框 B ∈ R H × W × 4 B\in \mathbb{R}^{H\times W\times 4} B∈RH×W×4。

2.2.2 Prediction alignment

预测时,通过调节两个预测( P P P和 B B B)的空间分布,进一步对齐两个任务。以前的方法使用 center-ness 分支或 IoU 分支,基于类别特征或位置特征来调节类别预测。而本文则通过 task-interactive 特征综合考虑了两个任务,然后对齐这两个预测结果。如上图,作者使用一个空间概率图 M ∈ R H × W × 1 M\in \mathbb{R}^{H\times W\times 1} M∈RH×W×1调节类别的预测:

P a l i g n = P × M P^{align}=\sqrt{P\times M} Palign=P×M

M M M由 interactive 特征计算而来,学习每个空间位置上两个任务的对齐程度。

同时,为了对齐位置预测,从 interactive 特征学习一个空间偏移图 O ∈ R H × W × 8 O\in \mathbb{R}^{H\times W\times 8} O∈RH×W×8,调节每个位置的预测框坐标。该偏移量使对得最齐的 anchor point 能识别到它附近最优的边界预测:

B a l i g n ( i , j , c ) = B ( i + O ( i , j , 2 × c ) , j + O ( i , j , 2 × c + 1 ) , c ) B^{align}(i,j,c)=B(i+O(i,j,2\times c), j+O(i,j,2\times c+1), c) Balign(i,j,c)=B(i+O(i,j,2×c),j+O(i,j,2×c+1),c)

( i , j , c ) (i,j,c) (i,j,c)表示张量中第 c c c个通道的第 ( i , j ) (i,j) (i,j)个位置。上式通过双线性插值实现,因为 B B B的通道数不大,所以计算成本很低。注意,每个通道都会独立地学习偏移量,即目标的每条边都有一个偏移量。因为每条边都能学习它附近最准确的 anchor point,预测的四条边就能更加准确。因此,本文方法不仅能对齐定位和分类任务,也能通过识别每条边精确的 anchor point 来提升定位的准确率。

M M M和 O O O从 interactive 特征中自动地学习:
M = σ ( conv 2 ( δ ( conv 1 ( X i n t e r ) ) ) ) M=\sigma(\text{conv}_2(\delta(\text{conv}_1(X^{inter})))) M=σ(conv2(δ(conv1(Xinter))))
O = conv 4 ( δ ( conv 3 ( X i n t e r ) ) ) O=\text{conv}_4(\delta(\text{conv}_3(X^{inter}))) O=conv4(δ(conv3(Xinter)))

conv 1 \text{conv}_1 conv1和 conv 3 \text{conv}_3 conv3是 2 个 1 × 1 1\times 1 1×1的卷积层,用于通道降维。 M M M和 O O O通过 TAL 学习。T-head 是一个独立的模块,可以不需要 TAL。

2.3 Task Alignment Learning

使用 TAL 进一步指导 T-head 学习对齐分类和定位的预测。TAL 包括一个样本分配策略和一个新的损失函数。它从任务对齐的角度出发,动态地选取高质量 anchors,同时考虑 anchors 的分配和加权。

2.3.1 Task-aligned Sample Assignment

为了使用 NMS,anchor 分配应该满足以下两个条件:

  • 对齐的 anchor 能同时预测出较高的类别得分和准确的边框位置;
  • 错位的 anchor 的类别得分应该很低,会被抑制掉。

于是作者提出了一个新的 anchor 对齐度量,计算任务对齐的程度。在样本分配和损失函数中加入该度量,动态优化每个 anchor 的预测。

Anchor alignment metric. 类别得分和预测框与目标框间的 IoU 分别代表分类和定位任务的预测质量,作者将类别得分和 IoU 结合,表示俩任务的对齐程度。使用下式计算每个实例的各 anchor 的对齐度量:
t = s α × u β t=s^{\alpha}\times u^{\beta} t=sα×uβ

s s s表示类别得分, u u u表示 IoU。 α , β \alpha,\beta α,β用于控制两项任务施加的影响。 t t t在任务对齐优化的过程中扮演重要角色,使网络动态地关注于高质量(任务对齐)anchors。

Training sample assignment. 关注于任务对齐的 anchors,采用简单的分配规则来选取训练样本:对于每个实例,选取 m m m个 t t t值最大 anchors 作为正样本,其余的 anchors 作为负样本,然后计算损失。

2.4 Task-aligned Loss

分类目标函数

为了抬高对齐的 anchors 的类别得分,降低那些没对齐的 anchors ( t t t值偏小)的类别得分,训练时对 t t t做归一化得到 t ^ \hat{t} t^,代替 positive anchor 的二值标签。根据下面两个性质来归一化 t ^ \hat{t} t^:

  • 保证能有效地学习难例(即 t t t值较小的 positive anchors);
  • 保留实例之间关于预测框精度的排序。

于是作者采用归一化方法来调节 t ^ \hat{t} t^,对于每个实例, t ^ \hat{t} t^的最大值等于各 anchor 的 IoU ( u ) (u) (u)的最大值。对于分类任务,使用二元交叉熵来计算损失,将 t ^ \hat{t} t^作为正样本的 ground-truth 标签,而非 1

L c l s _ p o s = ∑ i = 1 N p o s B C E ( s i , t ^ i ) L_{cls\pos}=\sum{i=1}^{N_{pos}}BCE(s_i,\hat{t}_i) Lcls_pos=i=1∑NposBCE(si,t^i)
B C E ( s i , t ^ i ) = t ^ i ⋅ log ⁡ ( s i ) + ( 1 − t ^ i ) ⋅ log ⁡ ( 1 − s i ) BCE(s_i,\hat{t}_i)=\hat{t}_i \cdot \log(s_i) + (1 - \hat{t}_i) \cdot \log(1 - s_i) BCE(si,t^i)=t^i⋅log(si)+(1−t^i)⋅log(1−si)

i i i表示某个实例的 N p o s N_{pos} Npos个 positive anchors 中的第 i i i个 anchor。本文还用了 Focal Loss 缓解正负样本不均衡的问题。总的分类损失如下:
L c l s = ∑ i = 1 N p o s ∣ t ^ i − s i ∣ γ ⋅ B C E ( s i , t ^ i ) + ∑ j = 1 N n e g s j γ ⋅ B C E ( s j , 0 ) L_{cls}=\sum_{i=1}^{N_{pos}}|\hat{t}i-s_i|^{\gamma}\cdot BCE(s_i,\hat{t}i)+\sum{j=1}^{N{neg}}s_j^{\gamma}\cdot BCE(s_j,0) Lcls=i=1∑Npos∣t^i−si∣γ⋅BCE(si,t^i)+j=1∑Nnegsjγ⋅BCE(sj,0)

j j j表示 N n e g N_{neg} Nneg个 negative anchors 的第 j j j个 anchor。 γ \gamma γ和 Focal Loss 一文的含义相同,是调节系数。

定位目标函数

对齐的 anchors 预测出的边框通常置信度更高、边框更准确,才能在 NMS 时保留下来。训练时, t t t通过加权损失来提升高质量 anchor 的影响,降低低质量 anchor 的影响。高质量边框对模型有好处,而低质量边框则会产生大量冗余、无意义的信息。作者用 t t t值计算边框的质量。利用 t ^ \hat{t} t^对每个 anchor 的回归损失做加权

L r e g = ∑ i = 1 N p o s t ^ i L G I o U ( b i , b ‾ i ) L_{reg}=\sum_{i=1}^{N_{pos}}\hat{t}i L{GIoU}(b_i, \overline{b}i) Lreg=i=1∑Npost^iLGIoU(bi,bi)
b , b ‾ b,\overline{b} b,b分别是预测框和目标框。总的 TAL 训练损失是 L r e g L
{reg} Lreg和 L c l s L_{cls} Lcls之和。

相关推荐
佚明zj43 分钟前
全卷积和全连接
人工智能·深度学习
并不会2 小时前
常见 CSS 选择器用法
前端·css·学习·html·前端开发·css选择器
龙鸣丿2 小时前
Linux基础学习笔记
linux·笔记·学习
qzhqbb3 小时前
基于统计方法的语言模型
人工智能·语言模型·easyui
冷眼看人间恩怨4 小时前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发
2401_883041084 小时前
新锐品牌电商代运营公司都有哪些?
大数据·人工智能
Nu11PointerException4 小时前
JAVA笔记 | ResponseBodyEmitter等异步流式接口快速学习
笔记·学习
AI极客菌5 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭5 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^5 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt