引入
写在前面
由于时代变了,几年前学习bevformer的时候 自己的笔记 查阅的资料 AI的问答 还有后续补充更新杂乱在一起,导致自己写的东西 几年后再看 也不读不进去了 写的太乱了
传统规控转型的我 感觉 对于新问题 知道问什么 和怎么问 之后 学习起来基本就靠ai带飞了
当下的学习,可以很方便的用IDE中 或者龙虾的AI全局阅读,然后给出框接,总结和细节。
文体从不知所云的散文 改变为对话式了
AI给出的框架和代码解读在前 细节作为代码备注直接在代码里 传在附件中
部分扩展的知识点、额外的提问,作为自我查阅的备忘放在最后面
todo
目前只看完了AI 对主要代码的解读
有待细看提出问题
文件阅读顺序
先看配置:drivetransformer_small.py 模型结构、数据集、loss query 数量。
再看训练入口:train.py 训练主入口 读配置、加载 plugin dataset、 model runner
再看模型调度层:drivetransformer.py 模型主体1
然后重点读统一任务头:drivetransformer_head.py 模型主体2
最后补 transformer 细节和数据流:drivetransformer_layers.py 模型主体3
数据怎么喂:bench2drive_drivetransformer_dataset.py
drivetransformer_small.py 解读
这个文件是做什么的
这是 DriveTransformer 的总装图。
-
train.py是启动器; -
drivetransformer.py是总控器; -
drivetransformer_head.py是核心大脑; -
drivetransformer_layers.py是大脑里的神经层; -
bench2drive_drivetransformer_dataset.py是数据供给系统; -
那么这个配置文件就是把这些模块拼到一起的蓝图。
先读配置文件的意义在于:能先知道系统有哪些模块、每个模块的规模、输入输出是什么。
建议阅读顺序
-
全局超参数
-
model段 -
train_cfg -
train_pipeline/test_pipeline -
data -
optimizer/lr_config/runner
这份配置里最重要的全局概念
1. 任务统一
DriveTransformer 不是只做一个任务,而是把四件事放在同一个模型里做:
-
3D agent detection
-
agent motion prediction
-
online map prediction
-
ego planning
所以配置里会不断看到以下三类 query:
-
agent_query -
map_query -
ego query
它们最后在统一 transformer 中交互。
2. 记忆机制
memory_len_frame = 10
这表示模型不仅看当前帧,还会保留历史帧里的记忆 token。时序记忆是这篇论文实现 streaming processing 的核心。
3. query 数量
-
agent_query_num_vec = 900 -
map_query_num_vec = 100 -
agent_num_topk_sift = 900 -
map_num_topk_sift = 100
这些值决定了模型每帧愿意维护多少候选目标和地图元素。query 越多,表达能力越强,但计算量也越大。
model 段怎么读
1. type='DriveTransformer'
说明最外层模型类在 mmdet3d_plugin/ours/drivetransformer.py。
2. img_backbone
这里用的是 ResNet-50,作用是把多相机图像提成高层视觉特征。
3. img_neck
这里用 FPN,把 backbone 的输出变成统一通道数 _dim_ = 128 的特征,方便后面 transformer 使用。
4. pts_bbox_head
这是全配置里最重要的模块。这里承载了:
-
detection
-
motion forecasting
-
map prediction
-
planning
也就是说,这个 head 是论文核心逻辑。
pts_bbox_head 里的关键参数
感知 / 预测 / 地图 / 规划统一维度
-
embed_dims=_dim_ -
num_reg_fcs=2 -
num_cls_fcs=2
这些定义了统一 token 的特征维度,以及每个输出分支的小 MLP 规模。
检测与运动预测
-
agent_num_query -
agent_num_query_sifted -
fut_mode -
fut_ts -
num_classes
含义可以理解为:
-
用多少 agent token 找目标;
-
对每个 agent 预测多少未来模式;
-
每条未来轨迹包含多少时间点。
在线地图
-
map_num_query -
map_num_pts_per_vec -
map_num_classes
这里说明 map 不是当成 dense segmentation,而是当成polyline/vector prediction 来做。
规划
-
fut_ts_ego_fix_time -
fut_ts_ego_fix_dist -
fut_ego_fix_dist
说明 ego 规划同时支持:
-
固定时间间隔轨迹
-
固定距离间隔轨迹
这对闭环控制是很实用的,因为固定时间和固定距离各有优势。
transformer 结构怎么对应论文
agent_prep_decoder / map_prep_decoder
这两个是预解码层,用于让 query 在进入主 decoder 之前,先从图像特征里拿到初始信息。
transformer=dict(type='DriveTransformerWrapper', ...)
这里才是主干统一 transformer。
注意 operation_order:
-
task_self_attn -
temporal_cross_attn -
sensor_cross_attn -
ffn
这基本就是论文的三类 unified operation 在代码里的直接映射。
loss 配置怎么读
配置里 loss 大体分四组:
-
detection:
loss_cls、loss_bbox -
motion:
loss_traj、loss_traj_cls -
map:
loss_map_cls、loss_map_pts、loss_map_dir -
planning:
loss_plan_reg_fix_time、loss_plan_reg_fix_dist、loss_plan_cls
这说明训练时模型不是只优化一个目标,而是统一优化四类任务。
train_cfg 怎么理解
这里的重点是 assignment:
-
HungarianAssigner3D -
MapHungarianAssigner3D
这意味着 detection 和 map prediction 都沿用了 DETR 系风格的 bipartite matching,而不是传统 anchor matching。
DriveTransformer 的 query 预测方式,本质上更像 DETR 系,而不是传统 BEV 检测器。
数据部分怎么读
dataset_type
B2D_DriveTransformer_Dataset 对应 Bench2Drive 数据封装。
train_pipeline
核心过程:
-
读多相机图像
-
图像增强
-
读 3D 标注
-
过滤无效目标
-
构造轨迹标签
-
格式化并收集字段
collect_keys
这很重要,它定义了哪些时序 / 几何字段会被送入模型,例如:
-
lidar2img -
cam_intrinsic -
ego_pose -
ego_pose_inv -
timestamp
这些字段是做 3D 位置编码和时序对齐的关键输入。
训练策略部分
优化器
AdamW 是当前 transformer 系方法的标准选择。
学习率策略
CosineAnnealing + warmup 也是比较标准的 transformer 训练配置。
runner
这里使用 IterBasedRunner,不是 epoch-based。说明作者更关心迭代数控制,而不是按完整数据轮次来表达训练进度。
总结
-
这个文件不是"参数表",而是论文设计的工程映射。
-
model.pts_bbox_head是核心中的核心。 -
query 数、memory 长度、trajectory mode 数量,是理解模型容量的关键。
-
pipeline 和
collect_keys决定了模型到底能看到哪些几何与时序信息。
读完这个文件后,下一步看什么
下一步建议看 train.py,目前已经知道"要搭什么",接下来应该看"怎么启动和组装"。
train.py 解读
这个文件是做什么的
这是训练入口。它不负责定义论文算法,而负责把整个工程真正跑起来。
如果把配置文件比作"建筑蓝图",那 train.py 就是"施工总包":
-
读取蓝图
-
导入自定义模块
-
构建数据集
-
构建模型
-
构建优化器
-
构建 runner
-
开始训练 / 恢复训练 / 加载权重
顶层结构
这个文件主要由四部分组成:
-
BN 融合工具函数:
_fuse_conv_bn/fuse_conv_bn -
参数解析:
parse_args -
主入口:
main -
启动语句:脚本入口
一、BN 融合工具函数
_fuse_conv_bn
这个函数把一个卷积层和其后的 BN 融成一个卷积层。
用途:
-
推理时减少额外 BN 计算
-
简化网络图
这不是 DriveTransformer 特有逻辑,更像一个部署 / 推理优化选项。
fuse_conv_bn
递归地在整个模块里查找 Conv2d + BN 组合,并替换成融合后的结果。
如果你做推理优化、导出模型、部署,这个函数值得关注。
二、parse_args
这个函数定义了训练脚本的命令行接口。
重要的参数有:
-
config:配置文件路径 -
--work-dir:日志和 checkpoint 输出目录 -
--resume-from:断点续训 -
--load-from:加载预训练权重 -
--launcher:是否分布式训练 -
--cfg-options:临时覆盖配置项 -
--compile_before/--compile_after:尝试用torch.compile加速
这说明作者在工程上考虑了:
-
分布式训练
-
断点恢复
-
动态改配置
-
性能优化
三、main() 主流程
这是整个文件最重要的部分。
第 1 步:读取配置
python
cfg = Config.fromfile(args.config)
这里把配置文件变成 MMCV 的 Config 对象。后面所有 build 逻辑都从这里读取信息。
第 2 步:导入 plugin
DriveTransformer 的自定义模块不在 MMCV 默认注册表里,所以必须先 import:
-
自定义 dataset
-
自定义 detector
-
自定义 head
-
自定义 transformer 层
如果这一步漏掉,后面的 build_model / build_dataset 会直接失败。
第 3 步:设置工作目录和日志
这里会:
-
创建
work_dir -
备份当前配置
-
创建日志文件
这是训练工程中非常重要的一步,因为它保证了可复现性。
第 4 步:初始化分布式环境
通过 args.launcher 判断是否需要分布式训练。
如果使用 pytorch launcher,会调用 init_dist 初始化多卡环境。
第 5 步:记录环境信息和随机种子
包括:
-
CUDA / PyTorch / MMCV 环境
-
当前配置文本
-
随机种子
这是实验管理里非常关键的"元数据记录"。
第 6 步:构建数据集与 dataloader
python
datasets = [build_dataset(cfg.data.train)]
data_loaders = [build_dataloader(...)]
这里说明:
-
数据集类型与 pipeline 在 config 中定义;
-
train.py只负责调度,不写死具体数据逻辑。
这也是 MMCV 风格项目的一大特点:配置驱动构建。
第 7 步:构建模型
python
model = build_model(cfg.model, ...)
model.init_weights()
这一步会根据配置真正实例化 DriveTransformer。
注意它还会打印参数量,这对模型规模评估很有用。
第 8 步:包裹并行训练接口
-
分布式:
DistributedDataParallel -
单机多卡 / 单卡:
DataParallel
这说明模型实现本身是普通 PyTorch 模块,训练并行性由外层包装器负责。
第 9 步:构建优化器和优化钩子
这里既支持普通优化,也支持 fp16 hook。
如果你后续要做混合精度、梯度裁剪或 optimizer 替换,改动点通常就在这里附近。
第 10 步:构建 runner
MMCV 的 runner 负责:
-
驱动训练循环
-
注册 log hook
-
注册 checkpoint hook
-
注册 lr schedule
你可以把它理解为"训练主循环引擎"。
第 11 步:可选 torch.compile
作者提供了 compile 前后两个开关,尝试对 backbone、neck、head 的一些子模块做编译优化。
这是一个很工程化的设计点,表明作者在意吞吐和性能。
第 12 步:加载 checkpoint 或恢复训练
-
resume:恢复训练状态 -
load_checkpoint:只加载模型权重
二者区别很重要,新工程师经常混淆。
这个文件里你要建立的工程直觉
-
它不实现论文核心算法,它实现"把算法跑起来"。
-
它大量依赖 MMCV 的 registry + build 机制。
-
它是定位训练问题的第一现场:
-
模块没注册
-
配置路径错
-
DDP 没初始化
-
checkpoint 加载失败
-
dataloader 构建失败
调试建议
如果训练跑不起来,优先检查:
-
config是否指向正确文件 -
plugin_dir是否导入成功 -
dataset 是否能 build
-
model 是否能
init_weights -
checkpoint path 是否存在
读完这个文件后,下一步看什么
下一步看 mmdet3d_plugin/ours/drivetransformer.py。因为 train.py 已经告诉你"怎么启动",接下来要看"模型最外层怎么调度训练和推理"。
扩展
MMCV的机制
核心思路
这个项目对 MMCV 的用法,本质上是:用配置文件描述实验,用 registry 注册模块,用 builder 递归实例化,用 runner 驱动训练。
所以论文方法本身主要写在模型与数据模块里,而训练脚本 train.py 更像一个"装配与调度中心"。
- Config 机制
drivetransformer_small.py 里的 model=dict(...)、data=dict(...)、optimizer=dict(...) 不是普通参数表,而是 MMCV 的声明式实验配置。
Config.fromfile(...) 读入后,cfg.model、cfg.data.train、cfg.optimizer 都变成可递归访问的配置对象。
论文里的"统一 transformer + 多任务 head + 时序记忆 + map/vector prediction"并不是在 train.py 里硬编码拼出来的,而是先写成配置,再由 builder 去构建。
- Registry 机制
MMCV 的关键机制是 registry:模块先注册,builder 才能按名字找到它。
比如:
detector 在 drivetransformer.py 里通过 @DETECTORS.register_module()
head 在 drivetransformer_head.py 里通过 @HEADS.register_module()
transformer layer 在 drivetransformer_layers.py 里通过 @ATTENTION.register_module()、@TRANSFORMER_LAYER.register_module() 等
dataset 在 bench2drive_drivetransformer_dataset.py 里通过 @DATASETS.register_module()
这就是为什么 train.py 要先 import plugin:不 import,就还没注册,build_model / build_dataset 就找不到对应名字。
- Builder 机制
build_model(cfg.model, ...) 会读取配置里的 type='DriveTransformer',去 registry 里找到 DriveTransformer 类并实例化。
然后它会继续递归构建子模块:
img_backbone=dict(type='ResNet', ...)
img_neck=dict(type='FPN', ...)
pts_bbox_head=dict(type='DriveTransformerlHead', ...)
transformer=dict(type='DriveTransformerWrapper', ...)
也就是说,这篇论文的模型不是手写一长串 self.xxx = ... 在训练脚本里拼的,而是配置驱动 + 递归构建。
- Dataset / Pipeline 机制
build_dataset(cfg.data.train) 会根据 dataset_type = "B2D_DriveTransformer_Dataset" 去构建数据集对象。
然后 pipeline 也是配置驱动的,像:
LoadMultiViewImageFromFiles
ResizeCropFlipImage
TrajPreprocess
CustomCollect3D
这非常符合论文需求,因为这篇工作不是普通检测,它需要同时组织:
多相机图像
ego pose
command
agent future trajectory label
map polyline label
所以 MMCV 在这里的价值是:把复杂数据流拆成一串可配置的数据处理模块。
- Runner / Hook 机制
build_runner(cfg.runner, ...) 构建的是训练循环引擎,不是模型本身。
Runner 负责:
迭代 dataloader
调 model.forward
反向传播与 optimizer step
checkpoint 保存
日志输出
学习率调度
runner.register_training_hooks(...) 把 lr_config、optimizer_config、checkpoint_config、log_config 都接进去。
对这篇论文来说,MMCV 让作者不用自己手写完整训练循环,而把精力放在:
统一 query 设计
memory 机制
多任务损失
vector map / planning 输出
- 这篇论文里 MMCV 最"值钱"的地方
配置可组合:可以快速替换 backbone、decoder 层数、loss 权重。
模块可注册:论文自定义的 detector/head/layer/dataset 能无缝接入标准训练框架。
训练逻辑通用化:分布式、fp16、hook、runner 这些基础设施直接复用。
代码边界清晰:
论文创新写在模型文件
工程调度交给 MMCV
- 用这篇代码理解 MMCV 的最佳链路
先看配置怎么声明实验:drivetransformer_small.py
再看训练脚本怎么调用 builder:train.py
再看注册后的模块本体:
drivetransformer_head.py
drivetransformer_layers.py
bench2drive_drivetransformer_dataset.py
一句话总结
对这篇论文来说,MMCV 不是"算法本身",而是把论文算法模块化、配置化、可训练化的基础设施。
drivetransformer.py 解读
这个文件是做什么的
这是 DriveTransformer 的最外层模型封装 。它继承了 MMDetection3D 的 MVXTwoStageDetector,但实际上并没有做传统 two-stage 检测器那套复杂流程,而是借用了 MMCV/MMDet3D 的标准接口来承载 DriveTransformer 的统一多任务 head。
你可以把这个文件理解成:
-
上接训练框架
-
下接核心 head
-
左边接图像 backbone/neck
-
右边接闭环推理输出
它在系统中的位置
调用链大致是:
train.py → build_model(cfg.model) → DriveTransformer → pts_bbox_head
这个文件自己不实现论文里最复杂的 attention 和 query 交互,但决定了:
-
输入如何进入模型
-
图像特征如何提取
-
训练输出如何变成 loss
-
推理输出如何变成 benchmark/agent 可以消费的结构
核心成员变量
self.grid_mask
训练时可选的数据增强,主要用于图像特征鲁棒性。
self.prev_scene_token
用于推理时判断是否切换到了新的场景。一旦切换场景,就必须重置 memory。
self.position_level
决定从 img_feats 的哪一层取视觉特征。当前实现基本只用一个 level。
关键方法一:forward
这是最标准的 PyTorch/MMCV 入口:
-
return_loss=True→ 训练路径 -
return_loss=False→ 测试路径
这里最重要的理解是:训练和推理的数据组织方式不一样。训练通常是 batch tensor,推理常常带更复杂的嵌套结构。
关键方法二:extract_img_feat
这是图像特征提取入口。
它主要做四件事:
-
整理输入维度
-
可选执行
grid_mask -
调用图像 backbone
-
调用 neck,并重新 reshape 回多相机/多时序形式
这个方法是理解多相机输入形状的关键点。
为什么要 reshape
backbone 通常希望输入是 [B*N_cam, C, H, W] 这种 4D tensor;
而自动驾驶系统原始输入更像 [B, T, N_cam, C, H, W]。
所以这里的 reshape 本质上是在:
-
让 backbone 易于复用
-
让 head 还能保留多相机/时序结构信息
关键方法三:forward_train
这是训练时的主逻辑。
训练流程
-
如有必要,先重置 memory
-
提取图像特征
-
对输入字典做形状整理
-
调用
pts_bbox_head得到多任务预测 -
调用
pts_bbox_head.loss得到多任务 loss -
用
_parse_losses汇总日志与标量 loss
这里体现出的设计思想
这个 detector 本身不做各任务的细节计算,只做统一调度。真正的多任务耦合都在 pts_bbox_head 里。
所以如果你要研究:
-
detection query 怎么来的
-
ego planning 怎么预测
-
map polyline 怎么输出
都不要停留在这个文件,要继续往 head 里看。
关键方法四:forward_test
这是闭环评测 / 推理时的重要入口。
这里最关键的工程点:scene memory
python
if img_metas[0]['scene_token'] != self.prev_scene_token:
self.pts_bbox_head.reset_memory()
说明模型在推理时使用了场景级的历史记忆。如果新场景不重置 memory,旧场景信息会污染新场景。
推理流程
-
整理嵌套输入结构
-
判断是否切换 scene
-
提取图像特征
-
调用
pts_bbox_head -
调用
get_bboxes生成结构化结果 -
补充 ego planning 输出
_parse_losses
这是 MMCV 风格常见辅助函数,用来:
-
汇总多个 loss 分量
-
生成日志字典
-
为分布式训练提供统一的 loss 结构
这个文件对新工程师最大的价值
它帮你建立"系统边界感":
-
视觉特征到这里为止;
-
多任务预测从这里进入 head;
-
训练 / 推理分流在这里完成;
-
闭环 memory 生命周期在这里被管理。
读完后下一步看什么
下一步一定要看 drivetransformer_head.py。因为真正的论文主体------query、memory、统一多任务输出、loss 聚合------都在那里。
drivetransformer_head.py 解读
先给结论
如果你只能花时间认真读一个实现文件,优先读这个。
这个文件是 DriveTransformer 的核心大脑。它把四类任务统一到一个 head 里:
-
agent detection
-
agent trajectory prediction
-
online map prediction
-
ego planning
为什么这个文件如此重要
很多传统项目里,head 只是 backbone 后面接一个小分支;但这里不是。
这里的 DriveTransformerlHead 同时负责:
-
定义 query 和 reference points
-
构建多任务输出分支
-
维护时序 memory
-
计算前向预测
-
生成训练目标
-
计算多任务 loss
-
生成推理结果
所以这个文件既是"模型主体",也是"训练目标聚合器"。
顶层结构
主类:DriveTransformerlHead
这是核心类。
辅助类:MLN
文件末尾的 MLN 是一个小型条件调制模块,用在一些特征调节场景中。对入门阅读来说优先级低于主类。
一、__init__:把论文超参数落成代码结构
构造函数非常长,但可以按功能分块理解。
1. query 与模型容量相关参数
-
agent_num_query -
map_num_query -
memory_len_frame -
fut_mode -
fut_ts
这些参数决定了:
-
要追踪多少个 agent 候选
-
要预测多少个 map vector
-
保留多少时序记忆
-
每个目标预测多少条未来轨迹模式
2. 3D 位置编码相关参数
-
position_range -
depth_start -
depth_step -
depth_num
这一块说明模型不是简单地做 2D 图像 token attention,而是试图把多相机图像 token 提升到 3D-aware 的表示空间。
3. 损失函数构建
这个类里直接实例化了 detection / motion / map / planning 的所有 loss。
这意味着训练时的多任务平衡逻辑都汇聚在这里。
二、初始化部分:init_output_head / init_weights
init_output_head
这个方法负责组装多任务输出分支,例如:
-
agent 分类分支
-
agent 框回归分支
-
trajectory 分支
-
map 分支
-
ego trajectory 分支
可以把它理解成:"统一 transformer 表征之后,如何拆成各任务头"。
init_weights
这个方法不仅初始化常规线性层,还初始化 query/reference points。
这里要特别注意:当前文件中有一些 project / TODO / 替换此处代码 标记,说明这份代码可能还带有课程作业式占位实现。阅读时请把这些位置当作"待补全逻辑",不要误认为是论文正式版的完整实现。
三、forward:全文件最重要的方法
这是真正把四类任务串到一起的地方。
可以按以下顺序理解。
第 1 步:更新 memory
python
self.pre_update_memory(data)
模型会先根据当前 batch 的 prev_exists、ego pose、时间戳等信息更新内部记忆状态。
第 2 步:准备图像 token 与 3D 位置编码
核心方法:
-
prepare_location -
img_3d_position_embedding
这一步把多相机视觉特征变成带有 3D 几何意识的 token 表示。
第 3 步:准备 agent/map query 与 reference points
这一步定义模型要"找什么东西"。
-
agent_query:候选动态目标 token -
map_query:候选地图元素 token
对 DETR 系背景的工程师来说,这和 object query 很像;但这里是多任务统一版。
第 4 步:temporal alignment
python
agent_query, map_query, ... = self.temporal_alignment(...)
这一步把历史 memory 和当前 query 对齐起来,是 streaming 处理的关键。
第 5 步:预解码
通过 agent_prep_decoder / map_prep_decoder,让初始 query 先和视觉 token 做一次较轻量交互。
预解码之后,模型会生成初步预测:
-
初始 agent class / bbox
-
初始 map class / polyline
-
初始 motion trajectory
-
初始 ego planning trajectory
第 6 步:ego planning query 构造
ego query 来自多个信号拼接:
-
自车动力学特征
ego_lcf_feat -
自车历史轨迹
ego_his_trajs -
规划命令
ego_fut_cmd
这一点非常值得新工程师记住:端到端规划并不是只看图像,它还明确使用了自车状态和导航命令。
第 7 步:进入统一 transformer 主干
python
self.transformer(...)
这是整篇论文"统一化"的核心落点:agent、map、ego 三类 query 在同一个主干里共同演化。
第 8 步:收集结果并写回 memory
通过 post_update_memory 把当前帧的重要表示写回历史记忆,供下一帧使用。
然后把每一层 decoder 的输出拼成:
-
all_cls_scores -
all_bbox_preds -
all_traj_preds -
map_all_cls_scores -
ego_fut_preds_fix_time
这就是后面 loss 计算和推理输出的基础。
四、memory 相关方法
reset_memory
场景切换时调用,清空历史记忆。
pre_update_memory
在当前帧计算前,先把 memory 做时序推进和有效位处理。
post_update_memory
在当前帧计算后,把新的 agent/map/ego 信息写入记忆池。
temporal_alignment
作用是把当前 query 和历史 memory 对齐,使不同帧之间能在统一参考系下交互。
五、目标分配与 loss
detection target
-
_get_target_single -
get_targets
map target
-
_map_get_target_single -
map_get_targets
planning loss
loss_planning
总 loss
loss
你可以把这个类理解成同时实现了:
-
DETR-style matching for detection
-
vector map matching
-
multimodal motion loss
-
ego planning regression/classification loss
六、推理输出
get_bboxes
这个方法把内部预测结果变成更接近 benchmark 输出的数据结构:
-
3D boxes
-
scores
-
labels
-
trajectories
-
map vectors
阅读这个文件的重点问题
建议带着下面 5 个问题读:
-
query 是怎么初始化的?
-
历史 memory 是怎么进入当前帧的?
-
图像 token 的 3D 位置编码怎么构造?
-
为什么 detection/map/planning 可以共享一个主干?
-
最终 loss 是怎样把多任务绑在一起训练的?
读完后下一步看什么
接下来去看 drivetransformer_layers.py。因为你已经知道 head 在调用统一 transformer,但还不知道每一层 transformer 究竟在做什么。
drivetransformer_layers.py 解读
这个文件是做什么的
这是 DriveTransformer 的统一 transformer 层定义文件。
如果说 drivetransformer_head.py 是大脑总成,那么这个文件就是"大脑皮层的神经回路实现"。论文里提到的三类统一操作,主要都在这里落实:
-
task self-attention
-
temporal cross-attention
-
sensor cross-attention
核心类总览
AttentionLayer
最底层的通用注意力模块。
SwiGLULayer
FFN 模块,使用 SwiGLU 激活结构。
DriveTransformerPreDecoderLayer
预解码层,负责 query 的初步视觉对齐。
DriveTransformerDecoderLayer
主解码层,是论文 unified operation 的核心实现。
DriveTransformerDecoder
把多层 DriveTransformerDecoderLayer 串起来,并在每层输出中做迭代 refinement。
DriveTransformerWrapper
对 decoder 做一个更工程化的包装。
DriveTransformerPreDecoder
预解码 sequence 包装类。
一、AttentionLayer
这是这个文件里最底层也最关键的通用积木。
它做什么
输入:
-
query
-
key
-
value
-
可选位置编码
-
可选 attention mask
输出:
- attention 更新后的 token
它的价值
同一套 AttentionLayer 可以被不同场景复用:
-
query 看 query:task self-attention
-
query 看 memory:temporal cross-attention
-
query 看 image token:sensor cross-attention
这也是为什么论文能强调"统一操作"。
二、SwiGLULayer
这是 transformer 中的 FFN 层变体。
相比传统 ReLU FFN,SwiGLU 通常更适合大模型和 transformer,表达能力更好。
三、DriveTransformerPreDecoderLayer
预解码层的目标是:
-
在正式多层联合推理之前
-
让初始 query 先从视觉 token 中吸收一轮信息
它更像"query 热启动器"。
四、DriveTransformerDecoderLayer:整篇论文最重要的层级实现
这是这个文件里最值得认真读的类。
它的输入是什么
它同时接收:
-
当前 agent/map/ego query
-
当前图像 token
-
历史 agent/map/ego memory
-
各类位置编码
这已经体现出 DriveTransformer 的核心野心:把多任务 + 多传感器 + 多时间放进同一个 layer 里。
operation_order
这个字段非常关键。当前配置中大致是:
-
task_self_attn -
norm -
temporal_cross_attn -
norm -
sensor_cross_attn -
norm -
ffn -
norm
这就是论文思路的工程化顺序:
先让任务 token 彼此交互,再让当前帧对历史记忆交互,再让 query 去读取图像特征,最后过 FFN。
task_self_attn
这里的重点不是普通 self-attention,而是不同任务 query 之间能否互相看见。
代码里通过 attention mask 控制:
-
agent 看哪些 token
-
map 看哪些 token
-
ego 看哪些 token
这一步决定了"统一但不混乱"的信息交换边界。
temporal_cross_attentions
这里把当前 query 分成三组:
-
agent
-
map
-
ego
然后分别和自己的历史 memory 做 cross-attention。
这非常符合自动驾驶实际需求,因为:
-
agent 历史用于目标跟踪与轨迹连续性
-
map 历史用于拓扑与稳定性
-
ego 历史用于规划连续性
sensor_cross_attn
这一步让 query 去读取图像 token。
本质上这是当前帧感知信息进入统一 query 表征的关键接口。
ffn
最后对 agent/map/ego 各自过独立 FFN。说明作者在共享 attention 的同时,保留了任务内部的独立非线性变换能力。
五、DriveTransformerDecoder
这个类把多层 decoder layer 串起来,并负责每层迭代 refinement。
它做的核心事情
-
根据当前 reference points 构造位置编码
-
调用每一层 decoder layer
-
用各类输出头更新:
-
detection box
-
class
-
map point/polyline
-
motion trajectory
-
ego planning trajectory
- 保存每层中间结果
这本质上和 DETR 系方法的 iterative refinement 很像,但这里是多任务版本。
六、DriveTransformerWrapper
这是一个轻包装器,主要作用是:
-
包装 decoder
-
初始化权重
-
暴露统一 forward 接口
从阅读优先级来说,它低于 DriveTransformerDecoderLayer 和 DriveTransformerDecoder。
新工程师读这个文件时最容易卡住的点
1. 为什么会有这么多位置编码
因为这里不是单一检测任务,而是同时处理:
-
agent box 位置
-
map polyline 位置
-
ego trajectory 位置
-
图像 token 位置
-
历史 memory 位置
2. 为什么 query 要拆成 agent/map/ego 三段
因为三类任务虽然共享主干,但各自语义不同、reference form 不同、输出头不同。
3. 为什么同时要有 self-attn 和两种 cross-attn
因为 DriveTransformer 要同时解决:
-
当前任务间协同
-
当前与历史对齐
-
当前与视觉观测对齐
读完后下一步看什么
下一步看 dataset:bench2drive_drivetransformer_dataset.py。因为你已经知道模型如何处理 query 了,接下来应该知道这些 query 的监督和输入到底从哪里来。
bench2drive_drivetransformer_dataset.py 解读
这个文件是做什么的
这是 DriveTransformer 在 Bench2Drive 数据集上的数据适配器。
它的任务远不止"读取图片和标注",还包括:
-
组织多相机输入
-
构造 ego 历史 / 未来轨迹
-
构造 agent 未来属性标签
-
提取在线地图标注
-
维护 clip 级缓存
-
把时序样本打包成模型需要的格式
为什么这个文件很重要
端到端自动驾驶项目里,dataset 往往决定了模型最终"看到了什么监督"。
很多论文读起来只说:输入图像、输出规划;但工程里真正困难的是把:
-
多相机几何
-
自车状态
-
未来轨迹
-
地图折线
-
路径命令
全部拼成统一训练样本。这个文件就在做这件事。
顶层类:B2D_DriveTransformer_Dataset
它继承自 Custom3DDataset,但实际已经做了大量自动驾驶场景特化。
一、__init__
构造函数主要做四类事:
-
保存配置参数
-
加载地图信息
-
初始化 route/clip 缓存
-
生成 sequence grouping flag
为什么要做 clip 缓存
Bench2Drive 数据按 route/clip 组织。如果每次取样本都直接重新读整段 pkl,I/O 会很慢。
所以这里实现了按 route 的缓存机制,减少重复读盘。
这对大规模时序训练非常重要。
二、序列分组:_set_sequence_group_flag*
这两个函数负责给每个样本打"属于哪个序列"的标记。
为什么需要这个?
因为时序训练时,相邻样本之间是否连续、是否来自同一 clip,非常重要。
三、缓存读取:get_data_by_index
这是 dataset 工程优化的关键函数。
它会:
-
先检查当前 route cache
-
再检查历史缓存
-
如果都没命中,再从磁盘读取 route pkl
-
用 LRU 风格策略更新缓存
这个函数很有"工业界自动驾驶训练数据系统"的味道,而不是普通学术 demo。
四、样本准备主入口
prepare_train_data
训练样本生成主入口:
-
get_data_info(index)获取结构化原始输入 -
执行 pipeline
-
补充 map 标注
-
统一打包成训练输入
prepare_test_data
测试样本版本,与训练类似,但不会做训练阶段的空样本回退逻辑。
五、union2one
这个函数很关键,它把队列样本合成当前模型需要的单条输入格式。
它尤其做了两件重要事情:
-
把多帧图像 stack 起来
-
把
can_bus等时序元信息转成相对位姿
这一步相当于把"连续帧序列"整理成"当前帧可消费的时序包"。
六、get_data_info
这是全文件最值得重点读的方法之一。
它会构造模型真正使用的输入字典,包括:
-
图像路径
-
lidar2img -
cam_intrinsic -
cam_extrinsic -
ego_pose -
ego_pose_inv -
can_bus -
ego_lcf_feat -
ego_fut_cmd -
ego 历史与未来轨迹
这里体现了几个非常重要的工程思想
1. 统一几何坐标系
代码里反复在 world / ego / lidar / camera 坐标系之间转换。这是自动驾驶工程的基础功。
2. 规划命令不是离散 ID,而是结构化表示
command_far / command_near 不仅做 one-hot,还把目标点坐标做位置编码。
这比只给一个左转/右转标签更强,因为带入了几何位置信息。
3. ego 特征是显式输入
模型不只看图像,还显式使用:
-
速度
-
加速度
-
角速度
-
车身尺寸
-
转向
这对规划任务非常合理。
七、地图标注:get_map_info
这个方法把城镇地图中的 lane / trigger volume 转换成当前局部坐标系下的 polyline 标注。
核心过程
-
先粗筛一定范围内的地图元素
-
再投到当前 lidar 坐标系
-
裁剪到有效感兴趣区域
-
组织成
LiDARInstanceLines
这说明作者把地图任务建模为 vector map prediction,而不是像素级栅格地图分割。
八、目标框标注:get_ann_info
这里会:
-
过滤空点目标
-
做类别映射
-
生成
LiDARInstance3DBoxes -
附加 agent future attribute labels
这里的 attr_labels 很重要,因为它不只是检测框标签,还包含运动相关监督。
九、ego 轨迹相关函数
get_ego_past_trajs
生成自车历史轨迹的相对位移序列。
get_ego_future_trajs
生成固定时间间隔的未来轨迹。
get_ego_future_trajs_fix_dis
生成固定距离间隔的未来轨迹,适合规划任务。
这两个版本一起存在,说明作者希望规划输出兼顾时间一致性和空间一致性。
十、agent 未来属性:get_box_attr_labels
这是很容易被忽视、但非常有价值的函数。
它会对每个当前帧目标生成未来运动监督,包括:
-
future offset trajectory
-
valid mask
-
goal direction class
-
局部几何特征
-
yaw offset
这一步把普通检测框扩展成了"可做运动预测的目标表示"。
十一、__getitem__
普通 PyTorch dataset 入口。
如果是训练模式:
-
随机增强
-
准备训练样本
-
如果样本无效,重新采样
如果是测试模式:
- 直接准备测试样本
新工程师最应该从这个文件学到什么
-
自动驾驶 dataset 远不止读图片和标注。
-
坐标系转换是系统核心能力。
-
ego 规划、agent motion、map prediction 的监督都要在 dataset 层仔细构造。
-
工业级时序训练必须考虑缓存、route 连续性和样本打包方式。
读完整条链路之后,你对项目应该形成的整体图景
-
配置文件定义系统蓝图
-
train.py负责启动工程 -
drivetransformer.py负责最外层训练/推理调度 -
drivetransformer_head.py负责统一多任务建模 -
drivetransformer_layers.py负责 unified transformer 细节 -
bench2drive_drivetransformer_dataset.py负责把数据和监督喂给模型
到这里,你就已经掌握了这个项目最关键的主链路。