PyTorch基于注意力的目标检测模型DETR

【图书推荐】《PyTorch深度学习与计算机视觉实践》-CSDN博客

目标检测是计算机视觉领域的一个重要任务,它的目标是在图像或视频中识别并定位出特定的对象。在这个过程中,需要确定对象的位置和类别,以及可能存在的多个实例。

DETR模型通过端到端的方式进行目标检测,即从原始图像直接检测出目标的位置和类别,而不需要进行区域提议或特征金字塔等步骤。

DETR模型的核心思想是将目标检测任务转换为一个序列到序列的问题。它将输入图像视为一个序列,并使用Transformer编码器将其转换为一种可被解码器理解的形式。具体来说,DETR模型使用CNN来提取图像特征,然后将其输入Transformer编码器中进行处理。再使用一个Transformer解码器来逐步解码出目标的位置和类别。完整的DETR的架构如图13-11所示。

图13-11 完整的DETR模型架构

下面借用在13.2节中实现的DETR目标检测模型进行讲解。完整的DETR模型代码如下:

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
    def __init__(self,num_classes = 92,hidden_dim=256,nheads=8,num_encoder_layers=6,num_decoder_layers=6):
        super().__init__()
        #创建ResNet-50的骨干(backbone)网
        with torch.no_grad():
            self.backbone = resnet50()
            #清除ResNet-50骨干网最后的全连接层
            del self.backbone.fc
        #创建转换层,1×1的卷积,主要起到改变通道大小的作用
        self.conv = nn.Conv2d(2048,hidden_dim,1)
        #利用PyTorch内嵌的类创建Transformer实例
        self.transformer = nn.Transformer(hidden_dim,nheads,num_encoder_layers,num_decoder_layers)
        #预测头,多出的类别用于预测non-empty slots
        self.linear_class = nn.Linear(hidden_dim,num_classes)
        self.linear_bbox = nn.Linear(hidden_dim,4)
        # 输出检测槽编码(object queries)
        self.query_pos = nn.Parameter(torch.rand(100,hidden_dim))
        #可学习的位置编码,用于指导输入图形的坐标
        self.row_embed = nn.Parameter(torch.rand(50,hidden_dim//2))
        self.col_embed = nn.Parameter(torch.rand(50,hidden_dim//2))
        self._reset_parameters()

    def forward(self,inputs):
        #将ResNet-50网络作为backbone
        x = self.backbone.conv1(inputs)       
        x = self.backbone.bn1(x)                
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)      
        x = self.backbone.layer1(x)             
        x = self.backbone.layer2(x)             
        x = self.backbone.layer3(x)             
        x = self.backbone.layer4(x)     	#将ResNet-50网络作为backbone

        #从2048维度转换到可被Transformer接受的256维特征平面
        h = self.conv(x)                                        
        #(1,2048,25,34)->(1,hidden_dim,25,34)
        # 构建位置编码
        B,C,H,W = h.shape
        #创建一个可训练的与输入向量同样维度的位置向量,与原始的DETR的不同之处在于这里的位置向量是可训练的
        pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H,1,1),self. row_embed[:H].unsqueeze(1).repeat(1,W,1),],dim=-1).flatten(0,1).unsqueeze(1)
		
	   #将图像特征与位置信息进行合并
        src = pos+0.1*h.flatten(2).permute(2,0,1)
        #创建查询函数
        tgt = self.query_pos.unsqueeze(1).repeat(1,B,1)
        #通过Transformer继续前向传播
        #参数1:(h*w,batch_size,256),参数2:(100,batch_size,hidden_dim)
        #输出:(hidden_dim,100)-->(100,hidden_dim)
        h = self.transformer(src,tgt).transpose(0,1)
        #将Transformer的输出投影到分类标签及边界框
        return {'pred_logits':self.linear_class(h),'pred_boxes': self.linear_bbox(h).sigmoid()}

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

从上面模型架构的实现代码上来看,整体DETR设计较为简单,可以分为3个主要部分:backbone、Transfomer和FFN。

  1. backbone组件

backbone是DETR模型的第一部分,主要用于在图像上提取特征,生成特征图。这些特征图将作为输入传递给Transformer Encoder。backbone通常使用类似于ResNet或CNN模型来提取特征。

DETR将Resnet50作为backbone进行特征抽取,这样做的目的是可以直接使用PyTorch 2.0中提供的预训练模型和权重,从而节省了训练时间。

  1. Transformer构成

Transformer是DETR模型的第二部分,它是由编码器和解码器构成,如图13-12所示。

编码器用于对backbone输出的特征图进行编码。这个编码过程主要是通过多头自注意力机制实现的。在DETR模型中,每个多头自注意力之前都使用了位置编码,这种位置编码方式可以帮助模型更好地理解图像中的空间信息。

图13-12 DETR中的Transformer组件

  1. 分类器FFN

FFN一般使用两个全连接层作为分类器,其作用是对基于Transformer编码和查询后的特征向量进行分类计算,代码如下:

{'pred_logits':self.linear_class(h),'pred_boxes':self.linear_bbox(h).sigmoid()}

这里的self.linear_class和linear_bbox分别是对查询结果类别和位置的计算,分别用于预测分类和边界框回归。

以上就是对DETR模型的讲解。可以看到,DETR模型在架构设计上并没有太过于难懂的部分,可以认为是前面所学知识的集成。DETR在目标检测上的成功除了模型的设计外,还有一个重大创新就是开创性地提出了新的损失函数,目标检测中的损失函数通常由两部分组成:类别损失和边界框损失。对于类别损失,一般采用交叉熵损失函数,而在边界框损失方面,一般采用L1或L2损失函数。然而,DETR算法采用了不同的方式来计算类别损失和边界框损失。

DETR算法中的损失函数采用了基于二部图匹配的方式进行计算。具体来说,该算法首先将ground truth和预测的bounding box进行匹配,然后通过对比匹配结果和真实标签之间的差异来计算损失值。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

相关推荐
看星猩的柴狗26 分钟前
机器学习-高斯混合模型
人工智能·机器学习
power-辰南2 小时前
机器学习之数据分析及特征工程详细分析过程
人工智能·python·机器学习·大模型·特征
少说多想勤做2 小时前
【前沿 热点 顶会】AAAI 2025中与目标检测有关的论文
人工智能·深度学习·神经网络·目标检测·计算机视觉·目标跟踪·aaai
橙子小哥的代码世界4 小时前
【计算机视觉基础CV-图像分类】05 - 深入解析ResNet与GoogLeNet:从基础理论到实际应用
图像处理·人工智能·深度学习·神经网络·计算机视觉·分类·卷积神经网络
leigm1234 小时前
深度学习使用Anaconda打开Jupyter Notebook编码
人工智能·深度学习·jupyter
Aileen_0v06 小时前
【玩转OCR | 腾讯云智能结构化OCR在图像增强与发票识别中的应用实践】
android·java·人工智能·云计算·ocr·腾讯云·玩转腾讯云ocr
阿正的梦工坊7 小时前
深入理解 PyTorch 的 view() 函数:以多头注意力机制(Multi-Head Attention)为例 (中英双语)
人工智能·pytorch·python
Ainnle7 小时前
GPT-O3:简单介绍
人工智能
OceanBase数据库官方博客7 小时前
向量检索+大语言模型,免费搭建基于专属知识库的 RAG 智能助手
人工智能·oceanbase·分布式数据库·向量数据库·rag
测试者家园7 小时前
ChatGPT助力数据可视化与数据分析效率的提升(一)
软件测试·人工智能·信息可视化·chatgpt·数据挖掘·数据分析·用chatgpt做软件测试