Transformer实战-系列教程15:DETR 源码解读2(整体架构:DETR类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)

4、DETR类

位置:models/detr.py/DETR类

4.1 构造函数

python 复制代码
class DETR(nn.Module):
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss
  1. DETR类继承torch nn.Module
  2. 构造函数,传入5个参数:
    • backbone:CNN骨架网络,用于特征提取
    • transformer:Transformer模型,用于处理序列数据
    • num_classes:目标类别的数量
    • num_queries:解码器初始化生成的100个向量的个数,num_queries=100
    • aux_loss:一个布尔值,指示是否使用辅助损失来帮助训练
  3. 初始化
  4. num_queries
  5. transformer
  6. hidden_dim ,Transformer中的隐藏层维度
  7. class_embed ,类别预测的输出层,这个全连接层是接Transformer的输出,类别加1是额外的无类别对象
  8. bbox_embed,一个MLP,也是接Transformer的输出,边界框的四个坐标的回归
  9. query_embed ,解码器的初始100个向量
  10. input_proj ,一个1x1的二维卷积,使得backbone的输出通道数映射到与Transformer隐藏层维度相同
  11. backbone,一个预训练的卷积神经网络,主要作用是提取图像的特征,它的输出经过input_proj 处理后作为Transformer的输入
  12. aux_loss,保存是否使用辅助损失的标志

这里包含了几个自定义函数和类:

nested_tensor_from_tensor_list函数:将不同尺寸处理的图像Tensor转换为一个嵌套Tensor

MLP类:边界框的四个坐标的回归

transformer类:构建transformer架构

backbone:用于提取图像特征的CNN

4.2 前向传播

python 复制代码
    def forward(self, samples: NestedTensor):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)
        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out    
  1. 前向传播函数,输入为samples=NestedTensor{mask={Tensor(2,771,911)},tensors={Tensor(2,3,771,911)}}
  2. 检查samples是否为列表或Tensor类型
  3. samples ,如果,使用nested_tensor_from_tensor_list函数转换为NestedTensor
  4. features, pos,图像特征图对应的位置编码,backbone实际上是一个resnet,features和pos是一个list结构,保存了各层的输出
  5. src, mask,解构最后一层的特征,获取源数据和掩码,src:torch.Size([2, 2048, 21, 18]),mask torch.Size([2, 21, 18]),2是batch,2048是特征维度,后面两个是图像长宽,这里的features[-1]表示在backbone中有多层都有输出,features保存了各层的输出,这里-1就表示最后的输出
  6. 确保掩码不为空
  7. 将数据通过Transformer处理,获取序列输出,torch.Size([6, 2, 100, 256]),6是Transformer的堆叠层数,2是batch,100是生成100个目标预测,256是每个目标预测的维度,Transformer模块有两个返回值,只取第一个返回值
  8. outputs_class ,获取类别预测
  9. outputs_coord ,获取边界框坐标预测,并使用sigmoid函数将输出值限制在0到1之间
  10. out ,将类别预测结果和 边界框坐标预测结果做成一个字典
  11. 如果启用了辅助损失
  12. 通过辅助函数_set_aux_loss计算辅助损失
  13. 返回out

4.3 辅助函数_set_aux_loss()

python 复制代码
@torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  1. @torch.jit.unused:一个装饰器,指示当使用TorchScript编译模型时,该方法不应被编译。这是因为辅助损失的计算可能不兼容TorchScript的静态图特性
  2. 定义函数,接收类别预测和边界框坐标作为输入
  3. 返回一个列表,将每一个类别预测和边界框坐标都封装成一个字典,这样,训练过程中可以计算每一层的损失,从而实现辅助损失的目的

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类/ConvertCocoPolysToMask类)
DETR 源码解读2(DETR类)
DETR 源码解读3(位置编码:Joiner类/PositionEmbeddingSine类)
DETR 源码解读4(BackboneBase类/Backbone类)
DETR 源码解读5(Transformer类)
DETR 源码解读6(编码器:TransformerEncoder类/TransformerEncoderLayer类)
DETR 源码解读7(解码器:TransformerDecoder类/TransformerDecoderLayer类)
DETR 源码解读8(训练函数/损失函数)

相关推荐
程序猿小D3 分钟前
【完整源码+数据集+部署教程】【零售和消费品&存货】价格标签检测系统源码&数据集全套:改进yolo11-RFAConv
前端·yolo·计算机视觉·目标跟踪·数据集·yolo11·价格标签检测系统源码
滑水滑成滑头6 分钟前
**点云处理:发散创新,探索前沿技术**随着科技的飞速发展,点云处理技术在计算机视觉、自动驾驶、虚拟现实等领域的应用愈发广
java·python·科技·计算机视觉·自动驾驶
用户34216749055239 分钟前
Java高手速成--吃透源码+手写组件+定制开发教程
前端·深度学习
CoovallyAIHub1 小时前
超越“识别”:下一代机器视觉如何破解具身智能落地难题?
深度学习·算法·计算机视觉
Jump 不二1 小时前
百度 PaddleOCR 3.0 深度测评:与 MinerU 的复杂表格识别对决
人工智能·深度学习·百度·ocr
孤廖1 小时前
C++ 模板再升级:非类型参数、特化技巧(含全特化与偏特化)、分离编译破解
linux·服务器·开发语言·c++·人工智能·后端·深度学习
CoovallyAIHub2 小时前
全球OCR新标杆!百度0.9B小模型斩获四项SOTA,读懂复杂文档像人一样自然
深度学习·算法·计算机视觉
Francek Chen2 小时前
【深度学习计算机视觉】14:实战Kaggle比赛:狗的品种识别(ImageNet Dogs)
人工智能·pytorch·深度学习·计算机视觉·kaggle·imagenet dogs
渡我白衣2 小时前
《未来的 AI 操作系统(四)——AgentOS 的内核设计:调度、记忆与自我反思机制》
人工智能·深度学习·机器学习·语言模型·数据挖掘·人机交互·语音识别
MoRanzhi12033 小时前
Pillow 基础图像操作与数据预处理
图像处理·python·深度学习·机器学习·numpy·pillow·数据预处理