PyTorch基于注意力的目标检测模型DETR

【图书推荐】《PyTorch深度学习与计算机视觉实践》-CSDN博客

目标检测是计算机视觉领域的一个重要任务,它的目标是在图像或视频中识别并定位出特定的对象。在这个过程中,需要确定对象的位置和类别,以及可能存在的多个实例。

DETR模型通过端到端的方式进行目标检测,即从原始图像直接检测出目标的位置和类别,而不需要进行区域提议或特征金字塔等步骤。

DETR模型的核心思想是将目标检测任务转换为一个序列到序列的问题。它将输入图像视为一个序列,并使用Transformer编码器将其转换为一种可被解码器理解的形式。具体来说,DETR模型使用CNN来提取图像特征,然后将其输入Transformer编码器中进行处理。再使用一个Transformer解码器来逐步解码出目标的位置和类别。完整的DETR的架构如图13-11所示。

图13-11 完整的DETR模型架构

下面借用在13.2节中实现的DETR目标检测模型进行讲解。完整的DETR模型代码如下:

复制代码
import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
    def __init__(self,num_classes = 92,hidden_dim=256,nheads=8,num_encoder_layers=6,num_decoder_layers=6):
        super().__init__()
        #创建ResNet-50的骨干(backbone)网
        with torch.no_grad():
            self.backbone = resnet50()
            #清除ResNet-50骨干网最后的全连接层
            del self.backbone.fc
        #创建转换层,1×1的卷积,主要起到改变通道大小的作用
        self.conv = nn.Conv2d(2048,hidden_dim,1)
        #利用PyTorch内嵌的类创建Transformer实例
        self.transformer = nn.Transformer(hidden_dim,nheads,num_encoder_layers,num_decoder_layers)
        #预测头,多出的类别用于预测non-empty slots
        self.linear_class = nn.Linear(hidden_dim,num_classes)
        self.linear_bbox = nn.Linear(hidden_dim,4)
        # 输出检测槽编码(object queries)
        self.query_pos = nn.Parameter(torch.rand(100,hidden_dim))
        #可学习的位置编码,用于指导输入图形的坐标
        self.row_embed = nn.Parameter(torch.rand(50,hidden_dim//2))
        self.col_embed = nn.Parameter(torch.rand(50,hidden_dim//2))
        self._reset_parameters()

    def forward(self,inputs):
        #将ResNet-50网络作为backbone
        x = self.backbone.conv1(inputs)       
        x = self.backbone.bn1(x)                
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)      
        x = self.backbone.layer1(x)             
        x = self.backbone.layer2(x)             
        x = self.backbone.layer3(x)             
        x = self.backbone.layer4(x)     	#将ResNet-50网络作为backbone

        #从2048维度转换到可被Transformer接受的256维特征平面
        h = self.conv(x)                                        
        #(1,2048,25,34)->(1,hidden_dim,25,34)
        # 构建位置编码
        B,C,H,W = h.shape
        #创建一个可训练的与输入向量同样维度的位置向量,与原始的DETR的不同之处在于这里的位置向量是可训练的
        pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H,1,1),self. row_embed[:H].unsqueeze(1).repeat(1,W,1),],dim=-1).flatten(0,1).unsqueeze(1)
		
	   #将图像特征与位置信息进行合并
        src = pos+0.1*h.flatten(2).permute(2,0,1)
        #创建查询函数
        tgt = self.query_pos.unsqueeze(1).repeat(1,B,1)
        #通过Transformer继续前向传播
        #参数1:(h*w,batch_size,256),参数2:(100,batch_size,hidden_dim)
        #输出:(hidden_dim,100)-->(100,hidden_dim)
        h = self.transformer(src,tgt).transpose(0,1)
        #将Transformer的输出投影到分类标签及边界框
        return {'pred_logits':self.linear_class(h),'pred_boxes': self.linear_bbox(h).sigmoid()}

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

从上面模型架构的实现代码上来看,整体DETR设计较为简单,可以分为3个主要部分:backbone、Transfomer和FFN。

  1. backbone组件

backbone是DETR模型的第一部分,主要用于在图像上提取特征,生成特征图。这些特征图将作为输入传递给Transformer Encoder。backbone通常使用类似于ResNet或CNN模型来提取特征。

DETR将Resnet50作为backbone进行特征抽取,这样做的目的是可以直接使用PyTorch 2.0中提供的预训练模型和权重,从而节省了训练时间。

  1. Transformer构成

Transformer是DETR模型的第二部分,它是由编码器和解码器构成,如图13-12所示。

编码器用于对backbone输出的特征图进行编码。这个编码过程主要是通过多头自注意力机制实现的。在DETR模型中,每个多头自注意力之前都使用了位置编码,这种位置编码方式可以帮助模型更好地理解图像中的空间信息。

图13-12 DETR中的Transformer组件

  1. 分类器FFN

FFN一般使用两个全连接层作为分类器,其作用是对基于Transformer编码和查询后的特征向量进行分类计算,代码如下:

{'pred_logits':self.linear_class(h),'pred_boxes':self.linear_bbox(h).sigmoid()}

这里的self.linear_class和linear_bbox分别是对查询结果类别和位置的计算,分别用于预测分类和边界框回归。

以上就是对DETR模型的讲解。可以看到,DETR模型在架构设计上并没有太过于难懂的部分,可以认为是前面所学知识的集成。DETR在目标检测上的成功除了模型的设计外,还有一个重大创新就是开创性地提出了新的损失函数,目标检测中的损失函数通常由两部分组成:类别损失和边界框损失。对于类别损失,一般采用交叉熵损失函数,而在边界框损失方面,一般采用L1或L2损失函数。然而,DETR算法采用了不同的方式来计算类别损失和边界框损失。

DETR算法中的损失函数采用了基于二部图匹配的方式进行计算。具体来说,该算法首先将ground truth和预测的bounding box进行匹配,然后通过对比匹配结果和真实标签之间的差异来计算损失值。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

相关推荐
翼龙云_cloud几秒前
阿里云代理商:轻量服务器部署 OpenClaw 集成钉钉实现自动化办公
服务器·人工智能·阿里云·钉钉·openclaw
LilySesy6 分钟前
【案例总结】震撼巨作——SAP连接钉钉WEBHOOK
运维·人工智能·ai·钉钉·sap·abap·webhook
星空6 分钟前
从LLM到Agent Skill学习笔记
人工智能
新缸中之脑10 分钟前
12个最佳AI演示文稿(PPT)制作工具
人工智能·powerpoint
火山引擎开发者社区13 分钟前
从“内容苦力”到“高效创作者”,你只差一个 ArkClaw
人工智能
开开心心就好13 分钟前
电子教材下载工具,支持多链接批量下载
windows·随机森林·计算机视觉·pdf·计算机外设·逻辑回归·excel
weiyvyy18 分钟前
嵌入式硬件接口开发的流程
人工智能·驱动开发·单片机·嵌入式硬件·硬件架构·硬件工程
罗小罗同学19 分钟前
首个病理AI领域的扩散基础模型CytoSyn开源,可生成高度逼真、符合生物学规律的H&E染色病理切片
人工智能·开源·医学图像处理·医工交叉·医学ai
code_pgf27 分钟前
Jetson Orin NX 16G设备上配置AI服务自动启动的方案,包括Ollama、llama-server和OpenClaw Gateway三个组件
数据库·人工智能·安全·gateway·边缘计算·llama
前端付豪29 分钟前
实现 AI 回复支持 Markdown 渲染
前端·人工智能·markdown