ultralytics框架进行RT-DETR目标检测训练

DETR提出以来,其采用匈牙利匹配方式真正的实现了端到端检测效果,避免了NMS等后处理过程,同时,相较CNN的局部特征提取,其凭借着Transformer强大的全局特征提取能力,在目标检测领域可谓大杀四方,基于Transformer的目标检测方法因此层出不穷。

然后,尽管DETR类目标检测方法具有较好的数据拟合能力,但Transformer本身的计算复杂度较高,这使其很难完成实时检测任务,而今天我们则要介绍的便是百度提出的实时DETR目标检测方法,这个方法我已经在先前的博客中有过介绍,当时是基于RT-DETR的源码进行介绍的,今天我们则要介绍的是ultralytics中的RT-DETR模型。

模型结构

RT-DETR模型结构如下:

python 复制代码
                   from  n    params  module                                       arguments                     
  0                  -1  1      9536  ultralytics.nn.modules.block.ResNetLayer     [3, 64, 1, True, 1]           
  1                  -1  1    215808  ultralytics.nn.modules.block.ResNetLayer     [64, 64, 1, False, 3]         
  2                  -1  1   1219584  ultralytics.nn.modules.block.ResNetLayer     [256, 128, 2, False, 4]       
  3                  -1  1   7098368  ultralytics.nn.modules.block.ResNetLayer     [512, 256, 2, False, 6]       
  4                  -1  1  14964736  ultralytics.nn.modules.block.ResNetLayer     [1024, 512, 2, False, 3]      
  5                  -1  1    524800  ultralytics.nn.modules.conv.Conv             [2048, 256, 1, 1, None, 1, 1, False]
  6                  -1  1    789760  ultralytics.nn.modules.transformer.AIFI      [256, 1024, 8]                
  7                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
  8                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
  9                   3  1    262656  ultralytics.nn.modules.conv.Conv             [1024, 256, 1, 1, None, 1, 1, False]
 10            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 11                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 12                  -1  1     66048  ultralytics.nn.modules.conv.Conv             [256, 256, 1, 1]              
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 14                   2  1    131584  ultralytics.nn.modules.conv.Conv             [512, 256, 1, 1, None, 1, 1, False]
 15            [-2, -1]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 16                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 17                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 18            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 19                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 20                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 21             [-1, 7]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 22                  -1  3   2232320  ultralytics.nn.modules.block.RepC3           [512, 256, 3]                 
 23        [16, 19, 22]  1   7310072  ultralytics.nn.modules.head.RTDETRDecoder    [4, [256, 256, 256]]          
rtdetr-resnet50 summary: 402 layers, 42,768,952 parameters, 42,768,952 gradients

训练代码

这里我们使用ResNet50作为我们的特征提取网络,从论文的结果来看,RT-DETR只需要24epoch便能达到一个较好的效果。

python 复制代码
from ultralytics import RTDETR
# 加载预训练模型
model = RTDETR("rtdetr-resnet50.yaml")
# 开始训练
results = model.train(
    data="others.yaml",
    epochs=24,
    batch=6,       # 根据GPU显存调整(T4建议batch=8)
    imgsz=640,
    device="0",     # 指定GPU ID
    optimizer="AdamW",
    lr0=1e-4,
    warmup_epochs=4,
    label_smoothing=0.1,
    amp=True
)

相较于YOLO系列的目标检测方法,其训练速度要慢很多,其需要一个epoch需要的时间大概是6分钟,不过从结果来看,其只需要20 epoch便能达到一个较好的效果。

从模型文件大小来看,RT-DETR的文件大小约为245MB,而YOLO模型的大小多在1020MB

最终结果如下:


相关推荐
拓端研究室TRL3 分钟前
消费者网络购物意向分析:调优逻辑回归LR与决策树模型在电商用户购买预测中的应用及特征重要性优化
人工智能·算法·决策树·机器学习·逻辑回归
南玖yy6 分钟前
C++ 类模板三参数深度解析:从链表迭代器看类型推导与实例化(为什么迭代器类模版使用三参数?实例化又会是怎样?)
开发语言·数据结构·c++·人工智能·windows·科技·链表
学术-张老师8 分钟前
PABD 2025:大数据与智慧城市管理的融合之道
大数据·论文阅读·人工智能·智慧城市·论文笔记
技术吧15 分钟前
Spark-TTS: AI语音合成的“变声大师“
大数据·人工智能·spark
掘金酱31 分钟前
创作者训练营:老友带新+新人冲榜,全员参与,双倍快乐!
前端·人工智能·后端
我就是全世界43 分钟前
Magentic-UI:人机协作的网页自动化革命
运维·人工智能·ui·自动化
CoovallyAIHub1 小时前
基于YOLO-NAS-Pose的无人机象群姿态估计:群体行为分析的突破
深度学习·算法·计算机视觉
一个热爱生活的普通人1 小时前
不需要apikey认证的大模型api如何在cline上配置(结尾附带cline系统提示词)
人工智能·aigc
Xyz_Overlord1 小时前
机器学习----决策树
人工智能·决策树·机器学习
codegarfield1 小时前
神经网络中的梯度消失与梯度爆炸
人工智能·神经网络·resnet·梯度