RT-DETR全解析:超越 YOLO,实时物体检测更快更精准(附代码)

引言

近年来,物体检测技术在许多领域取得了显著进展,如视频监控、自动驾驶、智能家居等。传统的物体检测模型,如YOLO(You Only Look Once)系列,已经在实时检测任务中得到了广泛应用,但仍然存在一些限制,比如速度与精度的平衡问题。

RT-DETR (Real-Time Detection Transformer) 是第一个实时端到端目标检测模型。它是一种基于Transformer架构的物体检测模型,旨在提供比YOLO更高效、更快速的实时物体检测性能。RT-DETR利用Transformer在计算机视觉任务中的能力,RT-DETR为实时目标检测带来了新的性能水平。甚至被称为"YOLO终结者",那它是否能终结YOLO,让我们一探究竟。

模型介绍

RT-DETR是基于DETR(DEtection TRansformer)模型进行改进的。DETR的核心思想是通过Transformer结构进行物体检测,摒弃了传统的卷积神经网络(CNN)中对物体位置的预测和边界框回归,而是采用了全局自注意力机制来直接生成检测结果。

然而,DETR在实时性方面表现欠佳,主要由于其较长的推理时间和计算复杂度。为了弥补这一缺点,RT-DETR进行了优化,使得其不仅继承了Transformer的优势,还能够在保持较高精度的同时,大大提高了推理速度。

RT-DETR的改进创新

  • 去除NMS,端到端训练

传统的物体检测模型,如YOLO和Faster R-CNN,通常依赖于NMS(非最大抑制)来从多个重叠的边界框中选择最优框。NMS过程虽然有效,但会增加额外的计算开销,降低检测速度。RT-DETR的最大创新之一是完全去除了NMS步骤,采用了端到端的Transformer架构,直接在输出中生成最终的物体检测结果。通过这种方式,RT-DETR减少了计算复杂度,提高了推理速度。

端到端训练:RT-DETR采用端到端的训练方法,使得模型的输入到输出完全一体化,不需要复杂的后处理步骤。这样,训练过程更加高效,推理速度得以加快。

无需NMS:通过创新的查询机制,RT-DETR能够有效识别每个物体,并直接生成边界框位置和类别,而无需依赖于传统的NMS算法来去除冗余框。

RT-DETR中的无NMS实现:

ini 复制代码
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import requests

# 加载预训练模型
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# 加载测试图片
url = "https://path_to_your_image.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# 对图像进行预处理
inputs = processor(images=image, return_tensors="pt")

# 推理并获取检测结果
with torch.no_grad():
    outputs = model(**inputs)

# 处理检测结果(无需NMS)
target_sizes = torch.tensor([image.size[::-1]])  # 图像的宽高
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

# 结果包含最终的检测框(无需NMS)
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(f"Score: {score.item()} - Label: {label.item()} - Box: {box.tolist()}")
  • 混合编码器架构

RT-DETR在DETR的基础上引入了混合编码器(Hybrid Encoder)来提高对不同尺度物体的处理能力。传统的卷积神经网络(CNN)往往在处理不同尺度的物体时表现较差,尤其是小物体的检测。而混合编码器能够有效融合多层次的特征,帮助模型捕捉图像中的细节信息,从而提高物体检测的精度。

多尺度特征融合:混合编码器采用了多尺度特征融合策略,结合低层次的细节信息和高层次的语义信息,帮助模型更好地理解图像中的物体。

提高小物体检测能力:通过优化不同尺度特征的处理,RT-DETR在小物体的检测上表现尤为突出,能够较好地捕捉到图像中的小物体,减少漏检情况。

RT-DETR中的混合编码器:

ini 复制代码
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection

# 加载预训练模型
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# 图像预处理
inputs = processor(images=image, return_tensors="pt")

# 推理并生成特征
outputs = model(**inputs)

# 输出注意力层的特征
attentions = outputs.attentions  # Transformer的注意力层输出

这里,outputs.attentions会返回Transformer的注意力机制的输出,混合了卷积网络提取的特征和Transformer的自注意力特性,从而在多个尺度上处理物体信息。

  • 查询选择器(Query Selector)

DETR模型通过将物体检测问题转化为集合匹配问题,使用查询(queries)来生成检测结果。RT-DETR对这一机制进行了进一步的优化,使用了查询选择器来提高查询的选择效率。传统的DETR模型需要通过学习与图像中物体相关的查询,但这种方法在大规模数据集上训练时可能不够高效。

优化的查询选择器:RT-DETR通过改进查询选择策略,使得模型能够在较短的时间内更精确地选择物体,从而提高了推理速度。

快速物体检测:查询选择器不仅提升了精度,还加快了检测过程,减少了模型对每个物体的搜索时间,从而大幅提高了实时性。

  • 改进的解码器结构

RT-DETR的解码器结构经过优化,使其能够更高效地生成物体的位置和类别。在传统的DETR中,解码器通常需要大量的计算来匹配物体和查询,但RT-DETR在此基础上进行了改进,通过更高效的解码机制加速了计算过程。

高效解码器:RT-DETR优化了解码器的结构,引入了多尺度可变形注意力,使其能够更快地处理图像中的所有物体,并减少计算负担。

减少计算复杂度:解码器的高效设计不仅降低了计算时间,还减少了模型的内存消耗,使其更加适合嵌入式设备和实时检测任务。

  • 轻量化设计

RT-DETR通过在模型设计中引入轻量化技术,进一步提高了模型的推理速度。相比于传统的大型卷积神经网络(CNN),RT-DETR使用了更加高效的Transformer架构,减少了参数量和计算量。

轻量化模型结构:通过减少不必要的层次和模块,RT-DETR在确保检测精度的同时,降低了计算开销,使得模型更适合嵌入式设备和边缘计算。

RT-DETR架构

  • 骨干

主干是初始阶段,在此阶段处理输入图像以提取特征。这通常涉及卷积神经网络 (CNN),该网络在增加深度的同时减小图像的空间维度,并在各个阶段 (S3、S4、S5) 生成特征图。

  • 高效混合编码器

该模块处理来自主干的特征图。它包含:

AIFI(自适应交互融合集成):它融合了来自不同级别的主干(S3、S4、S5)的特征,以创建更丰富的表示。python实现代码:

ini 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class AIFI(nn.Module):
    def __init__(self, in_channels_s3, in_channels_s4, in_channels_s5, out_channels):
        super(AIFI, self).__init__()
        # 假设S3, S4, S5的通道数分别为in_channels_s3, in_channels_s4, in_channels_s5
        # 使用1x1卷积来调整通道数
        self.conv_s3 = nn.Conv2d(in_channels_s3, out_channels, kernel_size=1)
        self.conv_s4 = nn.Conv2d(in_channels_s4, out_channels, kernel_size=1)
        self.conv_s5 = nn.Conv2d(in_channels_s5, out_channels, kernel_size=1)

        # 自适应加权模块,利用全连接层进行权重计算
        self.fc = nn.Linear(out_channels * 3, 1)

    def forward(self, s3, s4, s5):
        # 对不同层次特征进行处理
        s3_out = self.conv_s3(s3)
        s4_out = self.conv_s4(s4)
        s5_out = self.conv_s5(s5)

        # 融合特征(可以通过拼接或者加权)
        fused_features = torch.cat((s3_out, s4_out, s5_out), dim=1)  # 拼接通道

        # 计算加权系数
        attention_weights = self.fc(fused_features.view(fused_features.size(0), -1))  # 平展后输入全连接层
        attention_weights = F.sigmoid(attention_weights)  # 使用Sigmoid激活函数确保权重在0-1之间

        # 加权融合后的特征
        weighted_fusion = fused_features * attention_weights.unsqueeze(-1).unsqueeze(-1)
        return weighted_fusion.sum(dim=1)  # 对通道维度进行求和,得到最终融合特征

CCFF(跨尺度通道融合):该模块进行多层次的融合,结合来自不同尺度的特征,以保持高级语义信息和低级细节特征之间的平衡。

CCFF模块通常采用跨通道注意力机制或简单的加权平均来完成这一任务:

ini 复制代码
class CCFF(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super(CCFF, self).__init__()
        self.convs = nn.ModuleList([nn.Conv2d(in_channels, out_channels, kernel_size=1) for in_channels in in_channels_list])
        self.attn_fc = nn.Linear(out_channels * len(in_channels_list), 1)

    def forward(self, features):
        # features是一个包含多个尺度特征的列表
        processed_features = [conv(f) for conv, f in zip(self.convs, features)]  # 每个尺度的特征通过1x1卷积处理
        
        # 融合多个尺度的特征
        fused_features = torch.cat(processed_features, dim=1)  # 拼接各尺度特征
        
        # 计算注意力权重
        attention_weights = self.attn_fc(fused_features.view(fused_features.size(0), -1))  # 使用全连接层计算加权系数
        attention_weights = F.sigmoid(attention_weights)  # 使用sigmoid激活函数
        
        # 对融合特征进行加权
        weighted_fusion = fused_features * attention_weights.unsqueeze(-1).unsqueeze(-1)  # 加权融合
        return weighted_fusion.sum(dim=1)  # 对通道维度进行求和,返回融合后的特征

F5:融合后的最终特征图,输入到下一阶段。

  • 查询选择器

该组件根据查询的不确定性进行选择。它确保为检测任务选择最具信息量的查询,从而减少冗余并提高效率。

所选编码器特征的分类和 IoU 分数。紫色和绿色点分别表示使用不确定性最小查询选择和原始查询选择训练的模型中选择的特征。

  • 解码器

解码器和头部模块处理选定的查询以产生最终的检测输出。

位置嵌入:向查询添加空间信息,以帮助模型理解图像中物体的位置。

图像特征:这里使用了来自编码器的精炼特征。

对象查询:这些是可学习的嵌入,可帮助模型关注图像的不同部分以进行对象检测。

  • 检测头

最终检测使用两种类型的卷积层:

Conv1x1 s1:1x1卷积,然后进行批量标准化(BN)和 SiLU 激活。

Conv3x3 s2:3x3卷积,后跟批量标准化 (BN) 和SiLU激活。

RT-DETR应用代码:

import 复制代码
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import requests
import matplotlib.pyplot as plt

# 下载预训练模型
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# 加载测试图片
url = "https://path_to_your_image.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# 对图片进行预处理
inputs = processor(images=image, return_tensors="pt")

# 模型推理
with torch.no_grad():
    outputs = model(**inputs)

# 获取检测结果
target_sizes = torch.tensor([image.size[::-1]])  # 图像的宽高
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

# 绘制检测结果
fig, ax = plt.subplots(1, figsize=(12, 9))
ax.imshow(image)

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    ax.add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                               fill=False, color="red", linewidth=3))
    ax.text(box[0], box[1], f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}",
            color="white", fontsize=12, weight="bold")
plt.show()

RT-DETR与YOLO对比

传统上,YOLO(You Only Look Once)凭借其轻量级和直接的设计,一直是实时物体检测的首选模型。

然而,RT-DETR引入了与YOLO不同的创新。尽管YOLO因其在速度和准确性之间的平衡而广受欢迎,但它也面临着挑战,尤其是在非最大抑制 (NMS) 方面。

  • 实验比较

RT-DETR-R50/R101在COCO上实现了53.1%/54.3%AP,在T4 GPU上实现了108/74FPS,在速度和准确度方面均优于之前先进的YOLO。此外,RT-DETR-R50的准确度比DINO-R50高出2.2%AP,FPS高出约21倍。在使用Objects365进行预训练后,RTDETR-R50/R101实现了55.3%/56.2% AP。

RT-DETR在速度和准确率方面均优于所有具有相同主干的DETR。与DINO-Deformable-DETR-R50相比,RT-DETR-R50的准确率提高了2.2%AP,速度提高了21倍 (108 FPSvs 5 FPS),两项都有显著提升。

  • 推理速度

RT-DETR通过优化模型架构和解码器,使得推理速度相比DETR更快,接近YOLO系列的水平,适合在实时应用中使用。

  • 精度

在精度方面,RT-DETR在许多基准数据集上与YOLO相当,尤其是在较小物体的检测上,RT-DETR的表现略优。

  • 计算开销

RT-DETR相较于传统的YOLO模型,计算开销有所增加,但其速度提升和精度优化在大多数实际场景中提供了更好的折衷。

总结

RT-DETR是物体检测领域的一项重要创新,它摒弃了传统物体检测模型中的NMS后处理步骤,通过优化Transformer架构,实现在实时应用中仍能保持高精度和高速度。RT-DETR不仅在计算效率上优于YOLO等传统模型,还在精度上具有一定优势,尤其在复杂场景中的表现尤为突出。

通过端到端的训练、混合编码器和查询选择器等技术,RT-DETR为实时物体检测任务提供了一个新的解决方案。一个模型的优秀不仅仅只看它的速度和精度,它出彩的设计理念让它真正被人们记住,至于算不算"YOLO终结者",我认为看完整篇文章,你应该有自己的看法了,欢迎在评论区讨论交流哦!

相关推荐
夏末秋也凉30 分钟前
力扣-回溯-491 非递减子序列
数据结构·算法·leetcode
penguin_bark32 分钟前
三、动规_子数组系列
算法·leetcode
kyle~40 分钟前
thread---基本使用和常见错误
开发语言·c++·算法
曲奇是块小饼干_1 小时前
leetcode刷题记录(一百零八)——322. 零钱兑换
java·算法·leetcode·职场和发展
小wanga1 小时前
【leetcode】滑动窗口
算法·leetcode·职场和发展
少年芒1 小时前
Leetcode 490 迷宫
android·算法·leetcode
BingLin-Liu2 小时前
蓝桥杯备考:搜索算法之枚举子集
算法·蓝桥杯·深度优先
码农诗人2 小时前
调用openssl实现加解密算法
算法·openssl·ecdh算法
IT猿手2 小时前
2025最新智能优化算法:鲸鱼迁徙算法(Whale Migration Algorithm,WMA)求解23个经典函数测试集,MATLAB
android·数据库·人工智能·算法·机器学习·matlab·无人机
别NULL2 小时前
机试题——编辑器
c++·算法