CrossKD: Cross-Head Knowledge Distillation for Object Detection

文章信息

贡献

  • 提出了一种新的预测模仿蒸馏方法 CrossKD,有效缓解了目标冲突问题。
  • 在 MS COCO 数据集上实现了显著的性能提升,超越了现有的 KD 方法。
  • 证明了 CrossKD 对不同检测器和异构骨干网络的适用性。
  • 提供了详细的实验分析,展示了 CrossKD 在处理目标冲突时的鲁棒性。

研究背景

知识蒸馏(Knowledge Distillation, KD)是一种用于将大型教师模型的知识迁移到小型学生模型的模型压缩技术。在目标检测领域,KD 方法主要分为两类:

  • 预测模仿(prediction mimicking):预测模仿是知识蒸馏中最直接的方法之一,其核心思想是让学生模型的输出尽可能接近教师模型的输出。具体来说,教师模型会生成关于目标检测任务的预测结果(如目标的类别概率和边界框位置),学生模型则通过模仿这些预测结果来学习教师模型的知识。例如,传统的软标签蒸馏方法会将教师模型的输出概率分布作为额外的监督信号,引导学生模型的训练。这种方法的优点是简单直接,但可能无法充分利用教师模型的中间层特征信息。
  • 特征模仿(feature imitation):特征模仿则关注于教师模型中间层的特征表示,通过让学生模型模仿教师模型的中间特征来实现知识迁移。这些中间特征通常包含了丰富的语义信息和空间结构信息,能够帮助学生模型更好地理解和学习教师模型的表示能力。例如,一些方法会设计特定的特征适配层或注意力机制,以对齐教师和学生模型的特征维度,并引导学生模型学习教师模型的关键特征。此外,特征模仿还可以结合预测信息进行优化,例如通过预测差异来指导特征模仿,从而更直接地提升学生模型的检测精度。

尽管预测模仿是 KD 的早期策略,但其效率一直低于特征模仿。最近的研究(如 LD 方法)通过转移定位知识提升了预测模仿的效果,但预测模仿仍存在优化目标冲突的问题。

研究动机

  • 目标冲突问题:在预测模仿中,学生模型需要同时模仿教师模型的预测和真实标签,但教师模型的预测与真实标签之间可能存在较大差异,导致学生模型的优化过程存在冲突。
  • 现有方法的局限性:现有方法通过选择特定区域或调整权重来缓解目标冲突,但这些方法忽略了那些具有高不确定性的区域,而这些区域可能包含对学生模型更有价值的信息。

CrossKD 方法

CrossKD 是一种新的预测模仿蒸馏方案,通过将学生检测头的中间特征传递到教师检测头中,生成交叉头预测(cross-head predictions),然后强制这些预测模仿教师模型的预测。这种方法的主要优点包括:

  • 缓解目标冲突:通过共享教师检测头的部分结构,交叉头预测与教师预测更加一致,减少了教师-学生对之间的差异,提高了预测模仿的稳定性。
  • 任务导向的信息传递:与特征模仿相比,CrossKD 更直接地传递任务相关的知识。

方法细节

  • 交叉头知识蒸馏框架:
    将学生检测头的中间特征传递到教师检测头的下一层,生成交叉头预测。
    在交叉头预测和教师预测之间计算蒸馏损失。
    通过冻结教师层的梯度,将蒸馏损失反向传播到学生模型的中间特征中。
  • 优化目标:
    总损失包括检测损失和蒸馏损失。
    在分类分支中,使用 Quality Focal Loss(QFL) 作为蒸馏损失。
    在回归分支中,根据不同的回归形式选择 GIoU 或 KL 散度 作为蒸馏损失。

实验

关键结论

  • 性能提升:
    CrossKD 在 GFL 上实现了 43.7 AP,比基线模型(40.2 AP)提高了 3.5 AP,超越了所有现有的 KD 方法。
    结合特征模仿方法(如 PKD),进一步提升到 43.9 AP。
    对不同检测器的适用性:
    在 RetinaNet、FCOS 和 ATSS 上应用 CrossKD,均取得了显著的性能提升。
    例如,RetinaNet 的 AP 从 37.4 提升到 39.7,FCOS 的 AP 从 38.5 提升到 41.3。
  • 对异构骨干网络的适用性:
    在使用不同骨干网络(如 Swin-T 和 MobileNetv2)的检测器上,CrossKD 也表现出色。
    例如,在 Swin-T 到 ResNet-50 的蒸馏中,CrossKD 实现了 38.0 AP,比 PKD 高 0.8 AP。
  • 对目标冲突的鲁棒性:
    即使在教师和学生使用不同标签分配器的情况下,CrossKD 仍能显著提升学生模型的性能。
    例如,使用 RetinaNet 作为教师时,CrossKD 将 GFL 的 AP 提升到 41.2,而传统预测模仿方法的 AP 仅为 30.3。

总结

CrossKD 是一种针对密集目标检测器的知识蒸馏方法,通过交叉头预测有效缓解了目标冲突问题,并在多个检测器和异构骨干网络上取得了 SOTA 性能。未来的工作可能会将 CrossKD 扩展到其他相关领域,如 3D 目标检测。

相关推荐
阿正的梦工坊10 分钟前
变分扩散模型 ELBO 重构推导详解
人工智能·深度学习·算法·机器学习
紫雾凌寒43 分钟前
计算机视觉|Swin Transformer:视觉 Transformer 的新方向
人工智能·深度学习·计算机视觉·transformer·vit·swintransformer·视频理解
国货崛起1 小时前
宇树科技再落一子!天羿科技落地深圳,加速机器人创世纪
人工智能·科技·机器人
EasyCVR1 小时前
安防监控/视频集中存储EasyCVR视频汇聚平台如何配置AI智能分析平台的接入?
人工智能·音视频·webrtc·rtsp·gb28181
weixin_519311741 小时前
3.多线程获取音频AI的PCM数据
人工智能·音视频·pcm
小机学AI大模型1 小时前
【手撕算法】支持向量机(SVM)从入门到实战:数学推导与核技巧揭秘
人工智能
胡耀超1 小时前
3.激活函数:神经网络中的非线性驱动器——大模型开发深度学习理论基础
人工智能·深度学习·神经网络·大模型
牛奶2 小时前
前端学AI:基于Node.js的LangChain开发-知识概念
前端·人工智能·aigc
扫地僧9852 小时前
基于提示驱动的潜在领域泛化的医学图像分类方法(Python实现代码和数据分析)
人工智能·分类·数据挖掘