RT-DETR代码详解(官方pytorch版)——参数配置(1)

前言

RT-DETR虽然是DETR系列,但是它的代码结构和之前的DETR系列代码不一样。

它是通过很多的yaml文件进行参数配置,和之前在train.py的parser = argparse.ArgumentParser()去配置所有参数不同,所以刚开始不熟悉代码的时候可能不知道在哪儿修改参数。

RT-DETR有官方版和ultralytics版两个版本代码,可以参考以下链接,分别使用两种方法对代码进行复现:
详解RT-DETR网络结构/数据集获取/环境搭建/训练/推理/验证/导出/部署_rt-dert-CSDN博客

下述内容主要是针对参数配置的代码实现进行解读,因为刚开始我拿着代码都不知道是怎么运行的,模型在哪儿加载参数都找不到

一、train.py文件

在RT-DETR中,train.py文件需要配置的内容很少,因为需要的参数配置全都放在了rtdetr_rxxvd_6x_coco.yml(骨干网络可选)文件中。在这个文件中又包含了其他所有的文件,可以依需修改:

左边是可以选择的backbone骨干网络,后续以ResNet18为例。

二、rtdetr_r18vd_6x_coco.yaml文件

python 复制代码
__include__: [
  '../dataset/coco_detection.yml',  # 数据集
  '../runtime.yml', # 运行参数配置
  './include/dataloader.yml', # 定义数据加载器参数
  './include/optimizer.yml', # 定义优化器通用设置
  './include/rtdetr_r50vd.yml', # 定义 RT-DETR 模型的结构参数(如 backbone 和解码器层数等
]


output_dir: ./output/rtdetr_r18vd_6x_coco  # 输出的文件地址

PResNet:
  depth: 18
  freeze_at: -1 # 不冻结任何层(如果设置为正数,则冻结 ResNet 的前几层)
  freeze_norm: False # 不冻结归一化层(如 BatchNorm)
  pretrained: True # 加载预训练权重(通常是基于 ImageNet 数据集的权重)

HybridEncoder:
  in_channels: [128, 256, 512] # 编码器的输入特征通道数,分别对应 ResNet-18 不同尺度的特征图输出
  hidden_dim: 256
  expansion: 0.5 # 特征通道扩展比例


RTDETRTransformer:
  eval_idx: -1 # 指定在哪一层解码器输出进行评估(-1 表示最后一层)
  num_decoder_layers: 3 # 解码器的层数
  num_denoising: 100  # 去噪查询的数量



optimizer:
  type: AdamW # 该优化器改进了 Adam,支持权重衰减以减轻过拟合
  params:  # 参数分组,针对不同模块的参数设置不同的学习率和权重衰减
    - 
      params: '^(?=.*backbone)(?=.*norm).*$'      # 匹配骨干网络中的归一化层参数,设置较低学习率和无权重衰减
      lr: 0.00001
      weight_decay: 0.
    - 
      params: '^(?=.*backbone)(?!.*norm).*$'      # 匹配骨干网络中非归一化参数
      lr: 0.00001
    - 
      params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bias)).*$'   # 匹配 Transformer 中归一化层或偏置参数
      weight_decay: 0.

  lr: 0.0001
  betas: [0.9, 0.999] # Adam 优化器的 beta 参数
  weight_decay: 0.0001 # 权重衰减值

上面的注释只是为了解释各行代码意思,但是运行代码过程中,yaml文件不能有注释,否则会报错:

三、yaml_config.py文件

在train.py文件中,实际是通过YAMLConfig()这个类读取rtdetr_r18vd_6x_coco.yaml中的配置信息。通过加载 YAML 配置文件,将不同的模型、优化器、数据加载器等组件以模块化的方式创建

主要功能

1. 动态加载 YAML 配置文件

  • 使用 load_config 函数加载 YAML 文件,读取其中的配置数据。
  • 支持通过 merge_dict 将命令行或其他来源的参数覆盖 YAML 文件中的默认配置。

2. 组件动态创建

  • 根据 YAML 文件的配置,动态创建模型(model)、损失函数(criterion)、优化器(optimizer)、学习率调度器(lr_scheduler)和数据加载器(dataloader)等。

3. 参数分组和正则匹配

  • 支持为优化器指定不同模块的参数组,并通过正则表达式选择分组的参数。

4. 支持扩展功能

  • 支持 EMA(Exponential Moving Average,指数滑动平均)AMP(Automatic Mixed Precision,自动混合精度)
  • 自动处理模型参数的冻结、梯度裁剪等功能。

5. 模块化设计

  • 配置组件通过 create 函数动态实例化,便于扩展和自定义。

3.1 类初始化与加载配置

python 复制代码
class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()
        cfg = load_config(cfg_path)  # 加载 YAML 配置文件
        merge_dict(cfg, kwargs)  # 合并外部输入的参数(高优先级)

        self.yaml_cfg = cfg  # 保存解析后的 YAML 配置

        # 一些常见配置的提取
        self.log_step = cfg.get('log_step', 100)
        self.checkpoint_step = cfg.get('checkpoint_step', 1)
        self.epoches = cfg.get('epoches', -1)
        self.resume = cfg.get('resume', '')
        self.tuning = cfg.get('tuning', '')
        self.sync_bn = cfg.get('sync_bn', False)
        self.output_dir = cfg.get('output_dir', None)
        self.use_ema = cfg.get('use_ema', False)
        self.use_amp = cfg.get('use_amp', False)
        self.autocast = cfg.get('autocast', dict())
        self.find_unused_parameters = cfg.get('find_unused_parameters', None)
        self.clip_max_norm = cfg.get('clip_max_norm', 0.0)
  • 功能
    • 从 YAML 配置文件中加载配置,初始化训练流程中常用的参数。
    • cfg_path:YAML 配置文件路径。
    • kwargs:支持通过外部传入参数(如命令行参数)覆盖 YAML 中的默认配置。
    • 使用 get 方法设置默认值,避免配置文件缺失某些字段时程序报错。

3.1.1 yaml_config.py文件

通过cfg = load_config(cfg_path)已经将所有的配置信息传递给cfg了

尽管传入的只有一个rtdetr_r18vd_6x_coco.yaml文件,但它里面包含了其他的配置文件地址:

load_config()函数在yaml_utils.py文件中

python 复制代码
def load_config(file_path, cfg=dict()):
    """
    加载 YAML 配置文件,并支持递归加载包含的其他 YAML 文件。
    Args:
        file_path (str): 要加载的 YAML 文件路径。
        cfg (dict): 全局配置字典,默认为空字典。
    Returns:
        dict: 加载并合并后的配置字典。
    """
    # 获取文件扩展名并确保是 YAML 文件
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "仅支持 YAML 文件(.yml 或 .yaml)"

    # 打开并加载 YAML 文件
    with open(file_path, 'r') as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            return {}  # 如果文件为空,则返回空字典

    # 检查是否需要加载包含的 YAML 配置(递归加载)
    if INCLUDE_KEY in file_cfg:
        # 提取 'include' 键的值,通常是其他 YAML 文件路径的列表
        base_yamls = list(file_cfg[INCLUDE_KEY])
        for base_yaml in base_yamls:
            # 将路径展开为完整路径(支持用户目录 ~ 和相对路径)
            if base_yaml.startswith('~'):
                base_yaml = os.path.expanduser(base_yaml)
            if not base_yaml.startswith('/'):  # 如果是相对路径
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)

            # 递归加载被包含的 YAML 文件
            base_cfg = load_config(base_yaml, cfg)
            # 合并当前加载的配置到全局配置中
            merge_config(base_cfg, cfg)

    # 最终合并当前文件的配置到全局配置中
    return merge_config(file_cfg, cfg)
  • 通过 include 字段,可以将配置拆分成多个 YAML 文件,便于管理和维护。
  • 支持递归加载多个 YAML 文件 ,并通过 merge_config 实现配置合并,确保最终配置完整。

3.2 动态加载组件(如模型、优化器等)

过 @property 装饰器延迟加载组件,仅在实际使用时创建对象

@property装饰器

是 Python 的一个内置装饰器,常用于定义一个类的方法,并将其伪装成"属性"。

  1. 保护类的封装特性
  2. 让开发者可以使用"对象.属性"的方式操作操作类属性

通过 @property 装饰器,可以直接通过方法名来访问方法,不需要在方法名后添加一对"()"小括号。

语法格式:

python 复制代码
@property
def 方法名(self)
    代码块

更多@property装饰器内容可看,其中包含延时加载的应用:@property装饰器-CSDN博客

3.2.1 模型加载

python 复制代码
@property
def model(self) -> torch.nn.Module:
    if self._model is None and 'model' in self.yaml_cfg:
        merge_config(self.yaml_cfg)  # 合并全局配置
        self._model = create(self.yaml_cfg['model'])  # 动态创建模型
    return self._model
  • 检查 _model 是否已经创建,若未创建且配置中包含 model 字段,则动态创建模型。++(self.yaml_cfg已经存储了所有的配置信息,见3.1.1 图,提取model键的值)++
  • 使用 create 函数按照 yaml_cfg['model'] 中的定义实例化模型。

在rtdetr_r18vd_6x_coco.yml--->./include/rtdetr_r50vd.yml中 :

3.2.2 优化器延迟加载

python 复制代码
@property
def optimizer(self):
    if self._optimizer is None and 'optimizer' in self.yaml_cfg:
        merge_config(self.yaml_cfg)  # 合并全局配置
        params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)  # 获取参数分组
        self._optimizer = create('optimizer', params=params)  # 动态创建优化器
    return self._optimizer
  • 获取优化器参数分组(get_optim_params),根据配置动态创建优化器实例。

3.2.3 学习率调度器加载

python 复制代码
@property
def lr_scheduler(self):
    if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
        merge_config(self.yaml_cfg)
        self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer)
        print('Initial lr: ', self._lr_scheduler.get_last_lr())
    return self._lr_scheduler
  • 动态创建学习率调度器对象,并与优化器绑定

在rtdetr_r18vd_6x_coco.yml--->./include/optimizer.yml中 :

基于MultiStepLR生成对应的学习率调度器

  • MultiStepLR 是 PyTorch 中 torch.optim.lr_scheduler 提供的一种学习率调度器
  • 它会在指定的训练步骤(milestones)调整学习率

根据配置,初始学习率为 0.1在第 1000 步时,学习率会乘以 gamma=0.1 ,变为 0.01。输出如下:

python 复制代码
Step 0: Learning Rate = 0.1
Step 500: Learning Rate = 0.1
Step 1000: Learning Rate = 0.01
Step 1500: Learning Rate = 0.01

3.3 数据加载器

python 复制代码
@property
def train_dataloader(self):
    if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
        merge_config(self.yaml_cfg)
        self._train_dataloader = create('train_dataloader')
        self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False)
    return self._train_dataloader
  • 动态加载训练数据加载器,并根据配置调整 shuffle 参数

3.4 参数分组(正则表达式匹配)

python 复制代码
@staticmethod
def get_optim_params(cfg: dict, model: nn.Module):
    '''
    E.g.:
        ^(?=.*a)(?=.*b).*$         means including a and b
        ^((?!b.)*a((?!b).)*$       means including a but not b
        ^((?!b|c).)*a((?!b|c).)*$  means including a but not (b | c)
    '''
    assert 'type' in cfg, ''
    cfg = copy.deepcopy(cfg)

    if 'params' not in cfg:
        return model.parameters()  # 如果未定义参数分组,返回默认模型参数

    assert isinstance(cfg['params'], list), ''

    param_groups = []
    visited = []
    for pg in cfg['params']:
        pattern = pg['params']
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
        pg['params'] = params.values()
        param_groups.append(pg)
        visited.extend(list(params.keys()))

    names = [k for k, v in model.named_parameters() if v.requires_grad]

    if len(visited) < len(names):
        unseen = set(names) - set(visited)
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
        param_groups.append({'params': params.values()})
        visited.extend(list(params.keys()))

    assert len(visited) == len(names), ''
    return param_groups
  • 根据正则表达式匹配模型中的参数(named_parameters 方法返回 <参数名, 参数> 的映射)。
  • 支持按模块或特定规则分组优化器参数(如设置不同学习率、权重衰减)。
  • 未匹配的参数会自动归为默认组。
  • ^(?=.*backbone)(?=.*norm).*$:匹配键名中包含 backbonenorm 的参数。
  • ^(?=.*encoder)(?!.*bias).*$:匹配键名中包含 encoder 且不包含 bias 的参数。
相关推荐
007tg2 小时前
从ChatGPT家长控制功能看AI合规与技术应对策略
人工智能·chatgpt·企业数据安全
Memene摸鱼日报2 小时前
「Memene 摸鱼日报 2025.9.11」腾讯推出命令行编程工具 CodeBuddy Code, ChatGPT 开发者模式迎来 MCP 全面支持
人工智能·chatgpt·agi
linjoe992 小时前
【Deep Learning】Ubuntu配置深度学习环境
人工智能·深度学习·ubuntu
范纹杉想快点毕业3 小时前
ZYNQ PS 端 UART 接收数据数据帧(初学者友好版)嵌入式编程 C语言 c++ 软件开发
c语言·笔记·stm32·单片机·嵌入式硬件·mcu·51单片机
先做个垃圾出来………3 小时前
残差连接的概念与作用
人工智能·算法·机器学习·语言模型·自然语言处理
AI小书房4 小时前
【人工智能通识专栏】第十三讲:图像处理
人工智能
fanstuck4 小时前
基于大模型的个性化推荐系统实现探索与应用
大数据·人工智能·语言模型·数据挖掘
多看书少吃饭5 小时前
基于 OpenCV 的眼球识别算法以及青光眼算法识别
人工智能·opencv·计算机视觉
IT学长编程5 小时前
计算机毕业设计 基于大数据技术的医疗数据分析与研究 Python 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试】
大数据·hadoop·机器学习·数据分析·毕业设计·毕业论文·医疗数据分析