论文阅读笔记:Semi-DETR: Semi-Supervised Object Detection with Detection Transformers
- [1 背景](#1 背景)
-
- [1.1 动机](#1.1 动机)
- [1.2 问题](#1.2 问题)
- [2 创新点](#2 创新点)
- [3 方法](#3 方法)
- [4 模块](#4 模块)
-
- [4.1 分阶段混合匹配](#4.1 分阶段混合匹配)
- [4.2 跨视图查询一致性](#4.2 跨视图查询一致性)
- [4.3 基于代价的伪标签挖掘](#4.3 基于代价的伪标签挖掘)
- [4.4 总损失](#4.4 总损失)
- 效果
-
- [5.1 和SOTA方法对比](#5.1 和SOTA方法对比)
- [5.2 消融实验](#5.2 消融实验)
论文:https://arxiv.org/pdf/2307.08095v1.pdf
代码:https://github.com/JCZ404/Semi-DETR
1 背景
1.1 动机
虽然DETR-based方法在全监督目标检测中实现了SOTA性能,但一个可行的DETR-based半监督目标检测(SSOD)框架仍然有待探索。
1.2 问题
问题1:1对1的分配策略具有NMS-free端到端检测的优点,在半监督场景的效率较低。
如果直接用检测器对未标记图像进行伪标注,当伪包围框不准确时,一对一分配策略会将单个不准确的提议匹配为正样本,而降其他潜在正确的提议匹配为负样本,从而噪声学习效率低下。
问题2:1对多的分配策略获得了质量更好的候选建议集吗,使得检测器优化效率更高,但会引入重复预测。
问题3:SSOD中常用的一致性正则化方法在DTER-based SSOD方法中不可行。
因为DETR-based检测器通过注意力机制不断更新query特征,随着query特征的变化,预测结果也会发生变化,即输入对象查询与其输出预测结果之间不存在确定的对应关系,这使得一致性正则化无法应用于DETR-based检测器中。
2 创新点
作者在TeacherStudent架构的基础上提出了一个新的基于DETR的SSOD框架Semi-DETR。如图1(b)所示。主要是
(1)提出了一个分阶段混合匹配模块 ,分别使用1对多分配 和1对1分配两个阶段训练。第一个阶段旨在通过1对多分配策略提高训练效率,从而为第二个阶段的1对1训练提供高质量的伪标签。
(2)引入了一个跨视图查询一致性模块,该模块构建了跨视图对象查询,以消除对象查询确定性对应的要求,并帮助检测器在两个增强试图之间学习对象查询的语义不变特征。
(3)基于高斯混合模型设计了一个基于代价的伪标签挖掘模块,该模块根据匹配代价分布动态的挖掘用于一致性学习的可靠伪框。
提出的方法效果如图2。
3 方法
提出的Semi-DETR的整体框架如图3所示。根据SSOD流行的教师学生模型,作者提出的Semi-DETR采用了一对具有完全相同网络结构的教师和学生模型(论文里采用的是DINO)。在每次训练迭代中,弱增强和强增强的未标记图像分别反馈给教师和学生网络。然后将教师生成的置信度大于 τ s \tau_s τs 的伪标签作为训练学生网络的监督。学生的参数参数通过反向传播更新,教师模型参数是学生模型的EMA。
4 模块
4.1 分阶段混合匹配
在学生的预测和教师生成的伪标注之间执行匈牙利匹配,可以得到一个最优的1对1分配 σ o 2 o \sigma_{o2o} σo2o:
其中 ξ N \xi_N ξN 是 N个元素的置换构成的集合, C m a t c h ( y ^ i t , y ^ σ ( i ) s ) C_{match}(\hat{y}^t_i,\hat{y}^s_{\sigma(i)}) Cmatch(y^it,y^σ(i)s) 伪标签 y ^ i t \hat{y}_i^t y^it 和学生模型的第 σ ( i ) \sigma(i) σ(i) 个预测之间的匹配代价。
由于在SSOD训练的早期阶段,教师生成的伪标注通常是不准确和不可靠的,这使得在1对1分配策略下生成稀疏和低质量建议的风险很高。为了利用多个正查询来实现高效的半监督,作者提出使用1对多的分配代替1对1的分配:
其中 C N M C_N^M CNM 是 M 和 N 的组合,即 M 个提议的子集被分配给每个伪框 y ^ i t \hat{y}_i^t y^it 中。使用分类得分 s s s 和 IoU值 u u u 的高阶组合作为匹配代价度量:
其中 α \alpha α 和 β \beta β 是分类得分和IoU的影响因子,论文中设 α = 1 , β = 6 \alpha=1,\beta=6 α=1,β=6。通过1对多分配,选择 m m m 值最大的 M 个提案作为正样本,其余为负样本。
分类损失和回归损失也做了相应修改:
其中 γ \gamma γ 设置为2。通过为每个伪标签分配多个正建议,潜在的高质量正建议也获得了被优化的机会,大大提高了收敛速度,进而获得更好的伪标签。然而每个伪标签的多个正建议会导致重复的预测,为了缓解这一问题,在第二阶段切换回1对1的分配训练。通过这样做,在第一阶段训练后享受高质量的伪标签,并逐步减少重复预测,以在第二阶段通过1对1分配训练出NMS-free的检测器。该阶段的损失为:
教师网络的结果会采用NMS去重。
4.2 跨视图查询一致性
在传统的非DETR-based的SSOD框架中,给定相同的输入 x x x 并采用不同的随机增广,一致性正则化通过最小化教师 f θ f_\theta fθ 和学生 f θ ′ f'_\theta fθ′ 的输出之差来监督模型:
然而对于 DETR-based 框架,由于输入对象查询与输出预测结果之间没有明确的对应关系,因此进行一致性正则化变得不可行 。
图4展示了提出的跨视图查询一致性模块。具体来说,对于每一幅未标图像,给定一组伪边框 b b b,用若干个 MLP 处理 RoI Align 提取的 ROI 特征:
其中, F t F_t Ft 和 F S F_S FS 分别是教师和学生的骨干特征。随后, c t c_t ct 和 c s c_s cs 被视为跨视图查询嵌入,和另一个视图中的原始对象查询合并,作为解码器的输入:
其中 q . q_. q. 和 E . E_. E. 表示原始对象查询和编码特征, o ^ . \hat{o}. o^. 和 o . o. o. 分别表示跨视图查询和原始对象查询的解码特征。下标 t t t 和 s s s 分别表示教师和学生,为了避免信息泄露,还使用了注意力掩膜 A A A。
在跨视图查询嵌入的语义引导下,解码特征的对应关系可以自然的得到保证,一致性损失如下:
4.3 基于代价的伪标签挖掘
为了在跨视图查询一致性学习中挖掘出更多具有有意义语义内容的伪框,作者提出了一种基于代价的伪标签挖掘伪框模块,动态地在伪标注数据中挖掘出可靠的伪框。具体来说,在初始过滤的伪框和预测建议之间进行额外的二分匹配,并利用匹配代价来描述伪框的可靠性:
其中 p i p_i pi, b i b_i bi 表示第 i i i 个建议预测的分类和回归, p ^ j \hat{p}_j p^j, b ^ j \hat{b}_j b^j 表示第 j j j 个伪标签的类别和框坐标。
最后,在每个训练批次中,通过拟合高斯混合模型的匹配代价分布,将初始伪框类分为两种状态,如图5所示,匹配代价和伪框的质量非常吻合。作者进一步将可靠聚类中心的代价值设置为阈值,并收集所有代价低于阈值的伪框用于跨视图查询一致性计算。
先通过教师模型预测的每幅图像的所有建议框置信度的均值假方差获得图像级的置信度阈值,使用阈值过滤得到的初始伪标签,如图(b)所示。
代码如https://github.com/JCZ404/Semi-DETR/blob/main/detr_ssod/models/dino_detr_ssod.py#L921:
pythonavg_score = torch.mean(proposal_box[:, -1]) std_score = torch.std(proposal_box[:, -1]) pseudo_thr = avg_score + std_score # filter the pseudo bbox valid_inds = torch.nonzero(proposal_box[:, -1] >= pseudo_thr, as_tuple=False).squeeze().unique()
然后对学生模型预测的结果和伪标签将进行匈牙利匹配,计算每一批次内每个边界框的匹配代价,用GMM模型拟合,如图(a)所示。作者认为成本较低的伪框更可能是可靠的伪框,因此从GMM模型中取较低的阈值来再次过滤伪标签,得到(d)中呈现的可靠伪框。最终会用人为设定的阈值过滤出的伪框计算无监督损失,并将GMM模型过滤的伪框和人为阈值过滤的伪框合并,用于计算一致性损失。
代码如https://github.com/JCZ404/Semi-DETR/blob/main/detr_ssod/models/dino_detr_ssod.py#L332:
pythonvalid_inds = torch.nonzero(match_gt_cost <= thr_, as_tuple=False).squeeze().unique() valid_gt_inds_1 = match_gt_inds[valid_inds] valid_gt_inds_2 = torch.nonzero(gt_scores >= base_thr, as_tuple=False).squeeze().unique() valid_gt_inds = torch.cat((valid_gt_inds_1.to(imgs.device), valid_gt_inds_2.to(imgs.device))).unique() gt_bboxes_list.append(gt_bboxes[valid_gt_inds_2, :4]) gt_labels_list.append(gt_labels[valid_gt_inds_2]) gt_scores_list.append(gt_scores[valid_gt_inds_2]) # ==== High recall pseudo labels for consistency ==== unsup_bboxes_gmm_list.append(gt_bboxes[valid_gt_inds, :4]) unsup_labels_gmm_list.append(gt_labels[valid_gt_inds]) unsup_scores_gmm_list.append(gt_scores[valid_gt_inds])
4.4 总损失
总损失函数如下:
其中 w u = 4 , w c = 1 w_u=4,w_c=1 wu=4,wc=1。 T 1 T_1 T1 是SHM模块的第一个阶段,后面实验中测试最佳轮次为60K, t t t 是当前训练轮次。
效果
5.1 和SOTA方法对比
5.2 消融实验
本文提出的各模块的消融实验。
比较不同方法生成的方法来为CQC生成伪标签,其中本文提出的基于Cost的GMM阈值过滤效果最好。
第一阶段1对多分配策略的消融实验。
第一阶段的训练轮数的消融实验。
伪标签阈值 τ s \tau_s τs 的消融实验。