dirving transformer详读

引入

写在前面

由于时代变了,几年前学习bevformer的时候 自己的笔记 查阅的资料 AI的问答 还有后续补充更新杂乱在一起,导致自己写的东西 几年后再看 也不读不进去了 写的太乱了

传统规控转型的我 感觉 对于新问题 知道问什么 和怎么问 之后 学习起来基本就靠ai带飞了

当下的学习,可以很方便的用IDE中 或者龙虾的AI全局阅读,然后给出框接,总结和细节。

文体从不知所云的散文 改变为对话式了

AI给出的框架和代码解读在前 细节作为代码备注直接在代码里 传在附件中

部分扩展的知识点、额外的提问,作为自我查阅的备忘放在最后面

todo

目前只看完了AI 对主要代码的解读

有待细看提出问题

文件阅读顺序

先看配置:drivetransformer_small.py 模型结构、数据集、loss query 数量。

再看训练入口:train.py 训练主入口 读配置、加载 plugin dataset、 model runner

再看模型调度层:drivetransformer.py 模型主体1

然后重点读统一任务头:drivetransformer_head.py 模型主体2

最后补 transformer 细节和数据流:drivetransformer_layers.py 模型主体3

数据怎么喂:bench2drive_drivetransformer_dataset.py

drivetransformer_small.py 解读

这个文件是做什么的

这是 DriveTransformer 的总装图

  • train.py 是启动器;

  • drivetransformer.py 是总控器;

  • drivetransformer_head.py 是核心大脑;

  • drivetransformer_layers.py 是大脑里的神经层;

  • bench2drive_drivetransformer_dataset.py 是数据供给系统;

  • 那么这个配置文件就是把这些模块拼到一起的蓝图

先读配置文件的意义在于:能先知道系统有哪些模块、每个模块的规模、输入输出是什么。

建议阅读顺序

  1. 全局超参数

  2. model

  3. train_cfg

  4. train_pipeline / test_pipeline

  5. data

  6. optimizer / lr_config / runner

这份配置里最重要的全局概念

1. 任务统一

DriveTransformer 不是只做一个任务,而是把四件事放在同一个模型里做:

  • 3D agent detection

  • agent motion prediction

  • online map prediction

  • ego planning

所以配置里会不断看到以下三类 query:

  • agent_query

  • map_query

  • ego query

它们最后在统一 transformer 中交互。

2. 记忆机制

memory_len_frame = 10

这表示模型不仅看当前帧,还会保留历史帧里的记忆 token。时序记忆是这篇论文实现 streaming processing 的核心。

3. query 数量

  • agent_query_num_vec = 900

  • map_query_num_vec = 100

  • agent_num_topk_sift = 900

  • map_num_topk_sift = 100

这些值决定了模型每帧愿意维护多少候选目标和地图元素。query 越多,表达能力越强,但计算量也越大。

model 段怎么读

1. type='DriveTransformer'

说明最外层模型类在 mmdet3d_plugin/ours/drivetransformer.py

2. img_backbone

这里用的是 ResNet-50,作用是把多相机图像提成高层视觉特征。

3. img_neck

这里用 FPN,把 backbone 的输出变成统一通道数 _dim_ = 128 的特征,方便后面 transformer 使用。

4. pts_bbox_head

这是全配置里最重要的模块。这里承载了:

  • detection

  • motion forecasting

  • map prediction

  • planning

也就是说,这个 head 是论文核心逻辑。

pts_bbox_head 里的关键参数

感知 / 预测 / 地图 / 规划统一维度

  • embed_dims=_dim_

  • num_reg_fcs=2

  • num_cls_fcs=2

这些定义了统一 token 的特征维度,以及每个输出分支的小 MLP 规模。

检测与运动预测

  • agent_num_query

  • agent_num_query_sifted

  • fut_mode

  • fut_ts

  • num_classes

含义可以理解为:

  • 用多少 agent token 找目标;

  • 对每个 agent 预测多少未来模式;

  • 每条未来轨迹包含多少时间点。

在线地图

  • map_num_query

  • map_num_pts_per_vec

  • map_num_classes

这里说明 map 不是当成 dense segmentation,而是当成polyline/vector prediction 来做。

规划

  • fut_ts_ego_fix_time

  • fut_ts_ego_fix_dist

  • fut_ego_fix_dist

说明 ego 规划同时支持:

  • 固定时间间隔轨迹

  • 固定距离间隔轨迹

这对闭环控制是很实用的,因为固定时间和固定距离各有优势。

transformer 结构怎么对应论文

agent_prep_decoder / map_prep_decoder

这两个是预解码层,用于让 query 在进入主 decoder 之前,先从图像特征里拿到初始信息。

transformer=dict(type='DriveTransformerWrapper', ...)

这里才是主干统一 transformer。

注意 operation_order

  • task_self_attn

  • temporal_cross_attn

  • sensor_cross_attn

  • ffn

这基本就是论文的三类 unified operation 在代码里的直接映射。

loss 配置怎么读

配置里 loss 大体分四组:

  1. detection:loss_clsloss_bbox

  2. motion:loss_trajloss_traj_cls

  3. map:loss_map_clsloss_map_ptsloss_map_dir

  4. planning:loss_plan_reg_fix_timeloss_plan_reg_fix_distloss_plan_cls

这说明训练时模型不是只优化一个目标,而是统一优化四类任务。

train_cfg 怎么理解

这里的重点是 assignment:

  • HungarianAssigner3D

  • MapHungarianAssigner3D

这意味着 detection 和 map prediction 都沿用了 DETR 系风格的 bipartite matching,而不是传统 anchor matching。

DriveTransformer 的 query 预测方式,本质上更像 DETR 系,而不是传统 BEV 检测器。

数据部分怎么读

dataset_type

B2D_DriveTransformer_Dataset 对应 Bench2Drive 数据封装。

train_pipeline

核心过程:

  1. 读多相机图像

  2. 图像增强

  3. 读 3D 标注

  4. 过滤无效目标

  5. 构造轨迹标签

  6. 格式化并收集字段

collect_keys

这很重要,它定义了哪些时序 / 几何字段会被送入模型,例如:

  • lidar2img

  • cam_intrinsic

  • ego_pose

  • ego_pose_inv

  • timestamp

这些字段是做 3D 位置编码和时序对齐的关键输入。

训练策略部分

优化器

AdamW 是当前 transformer 系方法的标准选择。

学习率策略

CosineAnnealing + warmup 也是比较标准的 transformer 训练配置。

runner

这里使用 IterBasedRunner,不是 epoch-based。说明作者更关心迭代数控制,而不是按完整数据轮次来表达训练进度。

总结

  1. 这个文件不是"参数表",而是论文设计的工程映射

  2. model.pts_bbox_head 是核心中的核心。

  3. query 数、memory 长度、trajectory mode 数量,是理解模型容量的关键。

  4. pipeline 和 collect_keys 决定了模型到底能看到哪些几何与时序信息。

读完这个文件后,下一步看什么

下一步建议看 train.py,目前已经知道"要搭什么",接下来应该看"怎么启动和组装"。

train.py 解读

这个文件是做什么的

这是训练入口。它不负责定义论文算法,而负责把整个工程真正跑起来。

如果把配置文件比作"建筑蓝图",那 train.py 就是"施工总包":

  • 读取蓝图

  • 导入自定义模块

  • 构建数据集

  • 构建模型

  • 构建优化器

  • 构建 runner

  • 开始训练 / 恢复训练 / 加载权重

顶层结构

这个文件主要由四部分组成:

  1. BN 融合工具函数:_fuse_conv_bn / fuse_conv_bn

  2. 参数解析:parse_args

  3. 主入口:main

  4. 启动语句:脚本入口

一、BN 融合工具函数

_fuse_conv_bn

这个函数把一个卷积层和其后的 BN 融成一个卷积层。

用途:

  • 推理时减少额外 BN 计算

  • 简化网络图

这不是 DriveTransformer 特有逻辑,更像一个部署 / 推理优化选项。

fuse_conv_bn

递归地在整个模块里查找 Conv2d + BN 组合,并替换成融合后的结果。

如果你做推理优化、导出模型、部署,这个函数值得关注。

二、parse_args

这个函数定义了训练脚本的命令行接口。

重要的参数有:

  • config:配置文件路径

  • --work-dir:日志和 checkpoint 输出目录

  • --resume-from:断点续训

  • --load-from:加载预训练权重

  • --launcher:是否分布式训练

  • --cfg-options:临时覆盖配置项

  • --compile_before / --compile_after:尝试用 torch.compile 加速

这说明作者在工程上考虑了:

  • 分布式训练

  • 断点恢复

  • 动态改配置

  • 性能优化

三、main() 主流程

这是整个文件最重要的部分。

第 1 步:读取配置

python 复制代码
cfg = Config.fromfile(args.config) 

这里把配置文件变成 MMCV 的 Config 对象。后面所有 build 逻辑都从这里读取信息。

第 2 步:导入 plugin

DriveTransformer 的自定义模块不在 MMCV 默认注册表里,所以必须先 import:

  • 自定义 dataset

  • 自定义 detector

  • 自定义 head

  • 自定义 transformer 层

如果这一步漏掉,后面的 build_model / build_dataset 会直接失败。

第 3 步:设置工作目录和日志

这里会:

  • 创建 work_dir

  • 备份当前配置

  • 创建日志文件

这是训练工程中非常重要的一步,因为它保证了可复现性。

第 4 步:初始化分布式环境

通过 args.launcher 判断是否需要分布式训练。

如果使用 pytorch launcher,会调用 init_dist 初始化多卡环境。

第 5 步:记录环境信息和随机种子

包括:

  • CUDA / PyTorch / MMCV 环境

  • 当前配置文本

  • 随机种子

这是实验管理里非常关键的"元数据记录"。

第 6 步:构建数据集与 dataloader

python 复制代码
datasets = [build_dataset(cfg.data.train)] 

data_loaders = [build_dataloader(...)] 

这里说明:

  • 数据集类型与 pipeline 在 config 中定义;

  • train.py 只负责调度,不写死具体数据逻辑。

这也是 MMCV 风格项目的一大特点:配置驱动构建

第 7 步:构建模型

python 复制代码
model = build_model(cfg.model, ...) 

model.init_weights() 

这一步会根据配置真正实例化 DriveTransformer

注意它还会打印参数量,这对模型规模评估很有用。

第 8 步:包裹并行训练接口

  • 分布式:DistributedDataParallel

  • 单机多卡 / 单卡:DataParallel

这说明模型实现本身是普通 PyTorch 模块,训练并行性由外层包装器负责。

第 9 步:构建优化器和优化钩子

这里既支持普通优化,也支持 fp16 hook。

如果你后续要做混合精度、梯度裁剪或 optimizer 替换,改动点通常就在这里附近。

第 10 步:构建 runner

MMCV 的 runner 负责:

  • 驱动训练循环

  • 注册 log hook

  • 注册 checkpoint hook

  • 注册 lr schedule

你可以把它理解为"训练主循环引擎"。

第 11 步:可选 torch.compile

作者提供了 compile 前后两个开关,尝试对 backbone、neck、head 的一些子模块做编译优化。

这是一个很工程化的设计点,表明作者在意吞吐和性能。

第 12 步:加载 checkpoint 或恢复训练

  • resume:恢复训练状态

  • load_checkpoint:只加载模型权重

二者区别很重要,新工程师经常混淆。

这个文件里你要建立的工程直觉

  1. 它不实现论文核心算法,它实现"把算法跑起来"。

  2. 它大量依赖 MMCV 的 registry + build 机制。

  3. 它是定位训练问题的第一现场:

  • 模块没注册

  • 配置路径错

  • DDP 没初始化

  • checkpoint 加载失败

  • dataloader 构建失败

调试建议

如果训练跑不起来,优先检查:

  1. config 是否指向正确文件

  2. plugin_dir 是否导入成功

  3. dataset 是否能 build

  4. model 是否能 init_weights

  5. checkpoint path 是否存在

读完这个文件后,下一步看什么

下一步看 mmdet3d_plugin/ours/drivetransformer.py。因为 train.py 已经告诉你"怎么启动",接下来要看"模型最外层怎么调度训练和推理"。

扩展

MMCV的机制

核心思路

这个项目对 MMCV 的用法,本质上是:用配置文件描述实验,用 registry 注册模块,用 builder 递归实例化,用 runner 驱动训练。

所以论文方法本身主要写在模型与数据模块里,而训练脚本 train.py 更像一个"装配与调度中心"。

  1. Config 机制

drivetransformer_small.py 里的 model=dict(...)、data=dict(...)、optimizer=dict(...) 不是普通参数表,而是 MMCV 的声明式实验配置。

Config.fromfile(...) 读入后,cfg.model、cfg.data.train、cfg.optimizer 都变成可递归访问的配置对象。

论文里的"统一 transformer + 多任务 head + 时序记忆 + map/vector prediction"并不是在 train.py 里硬编码拼出来的,而是先写成配置,再由 builder 去构建。

  1. Registry 机制

MMCV 的关键机制是 registry:模块先注册,builder 才能按名字找到它。

比如:

detector 在 drivetransformer.py 里通过 @DETECTORS.register_module()

head 在 drivetransformer_head.py 里通过 @HEADS.register_module()

transformer layer 在 drivetransformer_layers.py 里通过 @ATTENTION.register_module()、@TRANSFORMER_LAYER.register_module() 等

dataset 在 bench2drive_drivetransformer_dataset.py 里通过 @DATASETS.register_module()

这就是为什么 train.py 要先 import plugin:不 import,就还没注册,build_model / build_dataset 就找不到对应名字。

  1. Builder 机制

build_model(cfg.model, ...) 会读取配置里的 type='DriveTransformer',去 registry 里找到 DriveTransformer 类并实例化。

然后它会继续递归构建子模块:

img_backbone=dict(type='ResNet', ...)

img_neck=dict(type='FPN', ...)

pts_bbox_head=dict(type='DriveTransformerlHead', ...)

transformer=dict(type='DriveTransformerWrapper', ...)

也就是说,这篇论文的模型不是手写一长串 self.xxx = ... 在训练脚本里拼的,而是配置驱动 + 递归构建。

  1. Dataset / Pipeline 机制

build_dataset(cfg.data.train) 会根据 dataset_type = "B2D_DriveTransformer_Dataset" 去构建数据集对象。

然后 pipeline 也是配置驱动的,像:

LoadMultiViewImageFromFiles

ResizeCropFlipImage

TrajPreprocess

CustomCollect3D

这非常符合论文需求,因为这篇工作不是普通检测,它需要同时组织:

多相机图像

ego pose

command

agent future trajectory label

map polyline label

所以 MMCV 在这里的价值是:把复杂数据流拆成一串可配置的数据处理模块。

  1. Runner / Hook 机制

build_runner(cfg.runner, ...) 构建的是训练循环引擎,不是模型本身。

Runner 负责:

迭代 dataloader

调 model.forward

反向传播与 optimizer step

checkpoint 保存

日志输出

学习率调度

runner.register_training_hooks(...) 把 lr_config、optimizer_config、checkpoint_config、log_config 都接进去。

对这篇论文来说,MMCV 让作者不用自己手写完整训练循环,而把精力放在:

统一 query 设计

memory 机制

多任务损失

vector map / planning 输出

  1. 这篇论文里 MMCV 最"值钱"的地方

配置可组合:可以快速替换 backbone、decoder 层数、loss 权重。

模块可注册:论文自定义的 detector/head/layer/dataset 能无缝接入标准训练框架。

训练逻辑通用化:分布式、fp16、hook、runner 这些基础设施直接复用。

代码边界清晰:

论文创新写在模型文件

工程调度交给 MMCV

  1. 用这篇代码理解 MMCV 的最佳链路

先看配置怎么声明实验:drivetransformer_small.py

再看训练脚本怎么调用 builder:train.py

再看注册后的模块本体:

drivetransformer.py

drivetransformer_head.py

drivetransformer_layers.py

bench2drive_drivetransformer_dataset.py

一句话总结

对这篇论文来说,MMCV 不是"算法本身",而是把论文算法模块化、配置化、可训练化的基础设施。

drivetransformer.py 解读

这个文件是做什么的

这是 DriveTransformer 的最外层模型封装 。它继承了 MMDetection3D 的 MVXTwoStageDetector,但实际上并没有做传统 two-stage 检测器那套复杂流程,而是借用了 MMCV/MMDet3D 的标准接口来承载 DriveTransformer 的统一多任务 head。

你可以把这个文件理解成:

  • 上接训练框架

  • 下接核心 head

  • 左边接图像 backbone/neck

  • 右边接闭环推理输出

它在系统中的位置

调用链大致是:

train.pybuild_model(cfg.model)DriveTransformerpts_bbox_head

这个文件自己不实现论文里最复杂的 attention 和 query 交互,但决定了:

  • 输入如何进入模型

  • 图像特征如何提取

  • 训练输出如何变成 loss

  • 推理输出如何变成 benchmark/agent 可以消费的结构

核心成员变量

self.grid_mask

训练时可选的数据增强,主要用于图像特征鲁棒性。

self.prev_scene_token

用于推理时判断是否切换到了新的场景。一旦切换场景,就必须重置 memory。

self.position_level

决定从 img_feats 的哪一层取视觉特征。当前实现基本只用一个 level。

关键方法一:forward

这是最标准的 PyTorch/MMCV 入口:

  • return_loss=True → 训练路径

  • return_loss=False → 测试路径

这里最重要的理解是:训练和推理的数据组织方式不一样。训练通常是 batch tensor,推理常常带更复杂的嵌套结构。

关键方法二:extract_img_feat

这是图像特征提取入口。

它主要做四件事:

  1. 整理输入维度

  2. 可选执行 grid_mask

  3. 调用图像 backbone

  4. 调用 neck,并重新 reshape 回多相机/多时序形式

这个方法是理解多相机输入形状的关键点。

为什么要 reshape

backbone 通常希望输入是 [B*N_cam, C, H, W] 这种 4D tensor;

而自动驾驶系统原始输入更像 [B, T, N_cam, C, H, W]

所以这里的 reshape 本质上是在:

  • 让 backbone 易于复用

  • 让 head 还能保留多相机/时序结构信息

关键方法三:forward_train

这是训练时的主逻辑。

训练流程

  1. 如有必要,先重置 memory

  2. 提取图像特征

  3. 对输入字典做形状整理

  4. 调用 pts_bbox_head 得到多任务预测

  5. 调用 pts_bbox_head.loss 得到多任务 loss

  6. _parse_losses 汇总日志与标量 loss

这里体现出的设计思想

这个 detector 本身不做各任务的细节计算,只做统一调度。真正的多任务耦合都在 pts_bbox_head 里。

所以如果你要研究:

  • detection query 怎么来的

  • ego planning 怎么预测

  • map polyline 怎么输出

都不要停留在这个文件,要继续往 head 里看。

关键方法四:forward_test

这是闭环评测 / 推理时的重要入口。

这里最关键的工程点:scene memory

python 复制代码
if img_metas[0]['scene_token'] != self.prev_scene_token: 

self.pts_bbox_head.reset_memory() 

说明模型在推理时使用了场景级的历史记忆。如果新场景不重置 memory,旧场景信息会污染新场景。

推理流程

  1. 整理嵌套输入结构

  2. 判断是否切换 scene

  3. 提取图像特征

  4. 调用 pts_bbox_head

  5. 调用 get_bboxes 生成结构化结果

  6. 补充 ego planning 输出

_parse_losses

这是 MMCV 风格常见辅助函数,用来:

  • 汇总多个 loss 分量

  • 生成日志字典

  • 为分布式训练提供统一的 loss 结构

这个文件对新工程师最大的价值

它帮你建立"系统边界感":

  • 视觉特征到这里为止;

  • 多任务预测从这里进入 head;

  • 训练 / 推理分流在这里完成;

  • 闭环 memory 生命周期在这里被管理。

读完后下一步看什么

下一步一定要看 drivetransformer_head.py。因为真正的论文主体------query、memory、统一多任务输出、loss 聚合------都在那里。

drivetransformer_head.py 解读

先给结论

如果你只能花时间认真读一个实现文件,优先读这个。

这个文件是 DriveTransformer 的核心大脑。它把四类任务统一到一个 head 里:

  • agent detection

  • agent trajectory prediction

  • online map prediction

  • ego planning

为什么这个文件如此重要

很多传统项目里,head 只是 backbone 后面接一个小分支;但这里不是。

这里的 DriveTransformerlHead 同时负责:

  1. 定义 query 和 reference points

  2. 构建多任务输出分支

  3. 维护时序 memory

  4. 计算前向预测

  5. 生成训练目标

  6. 计算多任务 loss

  7. 生成推理结果

所以这个文件既是"模型主体",也是"训练目标聚合器"。

顶层结构

主类:DriveTransformerlHead

这是核心类。

辅助类:MLN

文件末尾的 MLN 是一个小型条件调制模块,用在一些特征调节场景中。对入门阅读来说优先级低于主类。

一、__init__:把论文超参数落成代码结构

构造函数非常长,但可以按功能分块理解。

1. query 与模型容量相关参数

  • agent_num_query

  • map_num_query

  • memory_len_frame

  • fut_mode

  • fut_ts

这些参数决定了:

  • 要追踪多少个 agent 候选

  • 要预测多少个 map vector

  • 保留多少时序记忆

  • 每个目标预测多少条未来轨迹模式

2. 3D 位置编码相关参数

  • position_range

  • depth_start

  • depth_step

  • depth_num

这一块说明模型不是简单地做 2D 图像 token attention,而是试图把多相机图像 token 提升到 3D-aware 的表示空间。

3. 损失函数构建

这个类里直接实例化了 detection / motion / map / planning 的所有 loss。

这意味着训练时的多任务平衡逻辑都汇聚在这里。

二、初始化部分:init_output_head / init_weights

init_output_head

这个方法负责组装多任务输出分支,例如:

  • agent 分类分支

  • agent 框回归分支

  • trajectory 分支

  • map 分支

  • ego trajectory 分支

可以把它理解成:"统一 transformer 表征之后,如何拆成各任务头"。

init_weights

这个方法不仅初始化常规线性层,还初始化 query/reference points。

这里要特别注意:当前文件中有一些 project / TODO / 替换此处代码 标记,说明这份代码可能还带有课程作业式占位实现。阅读时请把这些位置当作"待补全逻辑",不要误认为是论文正式版的完整实现。

三、forward:全文件最重要的方法

这是真正把四类任务串到一起的地方。

可以按以下顺序理解。

第 1 步:更新 memory

python 复制代码
self.pre_update_memory(data) 

模型会先根据当前 batch 的 prev_exists、ego pose、时间戳等信息更新内部记忆状态。

第 2 步:准备图像 token 与 3D 位置编码

核心方法:

  • prepare_location

  • img_3d_position_embedding

这一步把多相机视觉特征变成带有 3D 几何意识的 token 表示。

第 3 步:准备 agent/map query 与 reference points

这一步定义模型要"找什么东西"。

  • agent_query:候选动态目标 token

  • map_query:候选地图元素 token

对 DETR 系背景的工程师来说,这和 object query 很像;但这里是多任务统一版。

第 4 步:temporal alignment

python 复制代码
agent_query, map_query, ... = self.temporal_alignment(...) 

这一步把历史 memory 和当前 query 对齐起来,是 streaming 处理的关键。

第 5 步:预解码

通过 agent_prep_decoder / map_prep_decoder,让初始 query 先和视觉 token 做一次较轻量交互。

预解码之后,模型会生成初步预测:

  • 初始 agent class / bbox

  • 初始 map class / polyline

  • 初始 motion trajectory

  • 初始 ego planning trajectory

第 6 步:ego planning query 构造

ego query 来自多个信号拼接:

  • 自车动力学特征 ego_lcf_feat

  • 自车历史轨迹 ego_his_trajs

  • 规划命令 ego_fut_cmd

这一点非常值得新工程师记住:端到端规划并不是只看图像,它还明确使用了自车状态和导航命令。

第 7 步:进入统一 transformer 主干

python 复制代码
self.transformer(...) 

这是整篇论文"统一化"的核心落点:agent、map、ego 三类 query 在同一个主干里共同演化。

第 8 步:收集结果并写回 memory

通过 post_update_memory 把当前帧的重要表示写回历史记忆,供下一帧使用。

然后把每一层 decoder 的输出拼成:

  • all_cls_scores

  • all_bbox_preds

  • all_traj_preds

  • map_all_cls_scores

  • ego_fut_preds_fix_time

这就是后面 loss 计算和推理输出的基础。

四、memory 相关方法

reset_memory

场景切换时调用,清空历史记忆。

pre_update_memory

在当前帧计算前,先把 memory 做时序推进和有效位处理。

post_update_memory

在当前帧计算后,把新的 agent/map/ego 信息写入记忆池。

temporal_alignment

作用是把当前 query 和历史 memory 对齐,使不同帧之间能在统一参考系下交互。

五、目标分配与 loss

detection target

  • _get_target_single

  • get_targets

map target

  • _map_get_target_single

  • map_get_targets

planning loss

  • loss_planning

总 loss

  • loss

你可以把这个类理解成同时实现了:

  • DETR-style matching for detection

  • vector map matching

  • multimodal motion loss

  • ego planning regression/classification loss

六、推理输出

get_bboxes

这个方法把内部预测结果变成更接近 benchmark 输出的数据结构:

  • 3D boxes

  • scores

  • labels

  • trajectories

  • map vectors

阅读这个文件的重点问题

建议带着下面 5 个问题读:

  1. query 是怎么初始化的?

  2. 历史 memory 是怎么进入当前帧的?

  3. 图像 token 的 3D 位置编码怎么构造?

  4. 为什么 detection/map/planning 可以共享一个主干?

  5. 最终 loss 是怎样把多任务绑在一起训练的?

读完后下一步看什么

接下来去看 drivetransformer_layers.py。因为你已经知道 head 在调用统一 transformer,但还不知道每一层 transformer 究竟在做什么。

drivetransformer_layers.py 解读

这个文件是做什么的

这是 DriveTransformer 的统一 transformer 层定义文件

如果说 drivetransformer_head.py 是大脑总成,那么这个文件就是"大脑皮层的神经回路实现"。论文里提到的三类统一操作,主要都在这里落实:

  • task self-attention

  • temporal cross-attention

  • sensor cross-attention

核心类总览

AttentionLayer

最底层的通用注意力模块。

SwiGLULayer

FFN 模块,使用 SwiGLU 激活结构。

DriveTransformerPreDecoderLayer

预解码层,负责 query 的初步视觉对齐。

DriveTransformerDecoderLayer

主解码层,是论文 unified operation 的核心实现。

DriveTransformerDecoder

把多层 DriveTransformerDecoderLayer 串起来,并在每层输出中做迭代 refinement。

DriveTransformerWrapper

对 decoder 做一个更工程化的包装。

DriveTransformerPreDecoder

预解码 sequence 包装类。

一、AttentionLayer

这是这个文件里最底层也最关键的通用积木。

它做什么

输入:

  • query

  • key

  • value

  • 可选位置编码

  • 可选 attention mask

输出:

  • attention 更新后的 token

它的价值

同一套 AttentionLayer 可以被不同场景复用:

  • query 看 query:task self-attention

  • query 看 memory:temporal cross-attention

  • query 看 image token:sensor cross-attention

这也是为什么论文能强调"统一操作"。

二、SwiGLULayer

这是 transformer 中的 FFN 层变体。

相比传统 ReLU FFN,SwiGLU 通常更适合大模型和 transformer,表达能力更好。

三、DriveTransformerPreDecoderLayer

预解码层的目标是:

  • 在正式多层联合推理之前

  • 让初始 query 先从视觉 token 中吸收一轮信息

它更像"query 热启动器"。

四、DriveTransformerDecoderLayer:整篇论文最重要的层级实现

这是这个文件里最值得认真读的类。

它的输入是什么

它同时接收:

  • 当前 agent/map/ego query

  • 当前图像 token

  • 历史 agent/map/ego memory

  • 各类位置编码

这已经体现出 DriveTransformer 的核心野心:把多任务 + 多传感器 + 多时间放进同一个 layer 里。

operation_order

这个字段非常关键。当前配置中大致是:

  1. task_self_attn

  2. norm

  3. temporal_cross_attn

  4. norm

  5. sensor_cross_attn

  6. norm

  7. ffn

  8. norm

这就是论文思路的工程化顺序:

先让任务 token 彼此交互,再让当前帧对历史记忆交互,再让 query 去读取图像特征,最后过 FFN。

task_self_attn

这里的重点不是普通 self-attention,而是不同任务 query 之间能否互相看见

代码里通过 attention mask 控制:

  • agent 看哪些 token

  • map 看哪些 token

  • ego 看哪些 token

这一步决定了"统一但不混乱"的信息交换边界。

temporal_cross_attentions

这里把当前 query 分成三组:

  • agent

  • map

  • ego

然后分别和自己的历史 memory 做 cross-attention。

这非常符合自动驾驶实际需求,因为:

  • agent 历史用于目标跟踪与轨迹连续性

  • map 历史用于拓扑与稳定性

  • ego 历史用于规划连续性

sensor_cross_attn

这一步让 query 去读取图像 token。

本质上这是当前帧感知信息进入统一 query 表征的关键接口。

ffn

最后对 agent/map/ego 各自过独立 FFN。说明作者在共享 attention 的同时,保留了任务内部的独立非线性变换能力。

五、DriveTransformerDecoder

这个类把多层 decoder layer 串起来,并负责每层迭代 refinement。

它做的核心事情

  1. 根据当前 reference points 构造位置编码

  2. 调用每一层 decoder layer

  3. 用各类输出头更新:

  • detection box

  • class

  • map point/polyline

  • motion trajectory

  • ego planning trajectory

  1. 保存每层中间结果

这本质上和 DETR 系方法的 iterative refinement 很像,但这里是多任务版本。

六、DriveTransformerWrapper

这是一个轻包装器,主要作用是:

  • 包装 decoder

  • 初始化权重

  • 暴露统一 forward 接口

从阅读优先级来说,它低于 DriveTransformerDecoderLayerDriveTransformerDecoder

新工程师读这个文件时最容易卡住的点

1. 为什么会有这么多位置编码

因为这里不是单一检测任务,而是同时处理:

  • agent box 位置

  • map polyline 位置

  • ego trajectory 位置

  • 图像 token 位置

  • 历史 memory 位置

2. 为什么 query 要拆成 agent/map/ego 三段

因为三类任务虽然共享主干,但各自语义不同、reference form 不同、输出头不同。

3. 为什么同时要有 self-attn 和两种 cross-attn

因为 DriveTransformer 要同时解决:

  • 当前任务间协同

  • 当前与历史对齐

  • 当前与视觉观测对齐

读完后下一步看什么

下一步看 dataset:bench2drive_drivetransformer_dataset.py。因为你已经知道模型如何处理 query 了,接下来应该知道这些 query 的监督和输入到底从哪里来。

bench2drive_drivetransformer_dataset.py 解读

这个文件是做什么的

这是 DriveTransformer 在 Bench2Drive 数据集上的数据适配器

它的任务远不止"读取图片和标注",还包括:

  • 组织多相机输入

  • 构造 ego 历史 / 未来轨迹

  • 构造 agent 未来属性标签

  • 提取在线地图标注

  • 维护 clip 级缓存

  • 把时序样本打包成模型需要的格式

为什么这个文件很重要

端到端自动驾驶项目里,dataset 往往决定了模型最终"看到了什么监督"。

很多论文读起来只说:输入图像、输出规划;但工程里真正困难的是把:

  • 多相机几何

  • 自车状态

  • 未来轨迹

  • 地图折线

  • 路径命令

全部拼成统一训练样本。这个文件就在做这件事。

顶层类:B2D_DriveTransformer_Dataset

它继承自 Custom3DDataset,但实际已经做了大量自动驾驶场景特化。

一、__init__

构造函数主要做四类事:

  1. 保存配置参数

  2. 加载地图信息

  3. 初始化 route/clip 缓存

  4. 生成 sequence grouping flag

为什么要做 clip 缓存

Bench2Drive 数据按 route/clip 组织。如果每次取样本都直接重新读整段 pkl,I/O 会很慢。

所以这里实现了按 route 的缓存机制,减少重复读盘。

这对大规模时序训练非常重要。

二、序列分组:_set_sequence_group_flag*

这两个函数负责给每个样本打"属于哪个序列"的标记。

为什么需要这个?

因为时序训练时,相邻样本之间是否连续、是否来自同一 clip,非常重要。

三、缓存读取:get_data_by_index

这是 dataset 工程优化的关键函数。

它会:

  • 先检查当前 route cache

  • 再检查历史缓存

  • 如果都没命中,再从磁盘读取 route pkl

  • 用 LRU 风格策略更新缓存

这个函数很有"工业界自动驾驶训练数据系统"的味道,而不是普通学术 demo。

四、样本准备主入口

prepare_train_data

训练样本生成主入口:

  1. get_data_info(index) 获取结构化原始输入

  2. 执行 pipeline

  3. 补充 map 标注

  4. 统一打包成训练输入

prepare_test_data

测试样本版本,与训练类似,但不会做训练阶段的空样本回退逻辑。

五、union2one

这个函数很关键,它把队列样本合成当前模型需要的单条输入格式。

它尤其做了两件重要事情:

  1. 把多帧图像 stack 起来

  2. can_bus 等时序元信息转成相对位姿

这一步相当于把"连续帧序列"整理成"当前帧可消费的时序包"。

六、get_data_info

这是全文件最值得重点读的方法之一。

它会构造模型真正使用的输入字典,包括:

  • 图像路径

  • lidar2img

  • cam_intrinsic

  • cam_extrinsic

  • ego_pose

  • ego_pose_inv

  • can_bus

  • ego_lcf_feat

  • ego_fut_cmd

  • ego 历史与未来轨迹

这里体现了几个非常重要的工程思想

1. 统一几何坐标系

代码里反复在 world / ego / lidar / camera 坐标系之间转换。这是自动驾驶工程的基础功。

2. 规划命令不是离散 ID,而是结构化表示

command_far / command_near 不仅做 one-hot,还把目标点坐标做位置编码。

这比只给一个左转/右转标签更强,因为带入了几何位置信息。

3. ego 特征是显式输入

模型不只看图像,还显式使用:

  • 速度

  • 加速度

  • 角速度

  • 车身尺寸

  • 转向

这对规划任务非常合理。

七、地图标注:get_map_info

这个方法把城镇地图中的 lane / trigger volume 转换成当前局部坐标系下的 polyline 标注。

核心过程

  1. 先粗筛一定范围内的地图元素

  2. 再投到当前 lidar 坐标系

  3. 裁剪到有效感兴趣区域

  4. 组织成 LiDARInstanceLines

这说明作者把地图任务建模为 vector map prediction,而不是像素级栅格地图分割。

八、目标框标注:get_ann_info

这里会:

  • 过滤空点目标

  • 做类别映射

  • 生成 LiDARInstance3DBoxes

  • 附加 agent future attribute labels

这里的 attr_labels 很重要,因为它不只是检测框标签,还包含运动相关监督。

九、ego 轨迹相关函数

get_ego_past_trajs

生成自车历史轨迹的相对位移序列。

get_ego_future_trajs

生成固定时间间隔的未来轨迹。

get_ego_future_trajs_fix_dis

生成固定距离间隔的未来轨迹,适合规划任务。

这两个版本一起存在,说明作者希望规划输出兼顾时间一致性和空间一致性。

十、agent 未来属性:get_box_attr_labels

这是很容易被忽视、但非常有价值的函数。

它会对每个当前帧目标生成未来运动监督,包括:

  • future offset trajectory

  • valid mask

  • goal direction class

  • 局部几何特征

  • yaw offset

这一步把普通检测框扩展成了"可做运动预测的目标表示"。

十一、__getitem__

普通 PyTorch dataset 入口。

如果是训练模式:

  • 随机增强

  • 准备训练样本

  • 如果样本无效,重新采样

如果是测试模式:

  • 直接准备测试样本

新工程师最应该从这个文件学到什么

  1. 自动驾驶 dataset 远不止读图片和标注。

  2. 坐标系转换是系统核心能力。

  3. ego 规划、agent motion、map prediction 的监督都要在 dataset 层仔细构造。

  4. 工业级时序训练必须考虑缓存、route 连续性和样本打包方式。

读完整条链路之后,你对项目应该形成的整体图景

  1. 配置文件定义系统蓝图

  2. train.py 负责启动工程

  3. drivetransformer.py 负责最外层训练/推理调度

  4. drivetransformer_head.py 负责统一多任务建模

  5. drivetransformer_layers.py 负责 unified transformer 细节

  6. bench2drive_drivetransformer_dataset.py 负责把数据和监督喂给模型

到这里,你就已经掌握了这个项目最关键的主链路。

扩展

相关推荐
大龄程序员狗哥1 小时前
第34篇:自动化机器学习(AutoML)初探——让AI来设计AI(概念入门)
人工智能·机器学习·自动化
一几文1 小时前
什么是硅基时间?什么是碳基时间?为何两者总是同时被提起?
人工智能·机器学习·ai·大模型·算力·碳基·硅基
seasonsyy1 小时前
机器学习领域三大顶会简介
人工智能·机器学习
数智化精益手记局1 小时前
拆解红牌作战的步骤:掌握红牌作战的步骤,解决现场管理难题
大数据·数据结构·人工智能·制造·精益工程
小仙女的小稀罕1 小时前
政务行业政务服务标准化专属解决方案
人工智能·政务
wuyoula1 小时前
尹之盾企业版网络验证
服务器·开发语言·javascript·c++·人工智能·ui·c#
好家伙VCC1 小时前
上市公司产学研合作及专利数据(1998-2022年)
人工智能·python
happyprince2 小时前
[推理]vLLM-2026年第二季度路线图
人工智能
自动驾驶小学生2 小时前
Transformer和LLM前沿内容(4):Long-Context LLM
人工智能·深度学习·transformer