PyTorch .pth 导出 ONNX:带 TensorRT Plugin 的工程化路线(以 MultiScaleDeformableAttnTRT 为例)
在纯 CNN/ResNet 这类模型里,torch.onnx.export() 往往是一行命令的事;但到了 UniAD / Deformable DETR / BEV 时序模型这种级别,事情立刻变复杂:
- 模型里有 CUDA 扩展算子(例如 Multi-Scale Deformable Attention)
- 计算图里带 时序状态(prev_bev、track instances、timestamp 等)
- 并且部署目标通常是 TensorRT ,想要性能就得用 TRT plugin
这篇文章就从「.pth 到底是什么」开始,讲到「导出 ONNX 的稳定流程」,再深入到「遇到 TRT plugin 算子怎么做」以及你给出的两种插件节点(MultiScaleDeformableAttnTRT / MSDAPlugin)分别在干什么。
1. .pth ≠ 计算图,.onnx 才是"图 + 权重"
先澄清一个误区:
.pth 只是 checkpoint ,通常是 state_dict(权重参数)+ meta(训练信息),它并不包含一张可部署的静态计算图。
而 .onnx 是静态计算图(节点、边、属性)+ 权重常量,它是为了跨框架/跨后端而设计的部署格式。
所以"导出 ONNX"这件事,本质不是把 .pth 文件格式转换一下,而是:
加载模型结构 + 加载权重 → 用一组输入跑通 forward → 让导出器把 forward 的算子序列变成 ONNX 图
2. 通用导出流程(简单模型/复杂模型都适用)
Step A:构建模型并 load checkpoint
典型步骤如下:
- 根据 config 构建模型结构
load_checkpoint(model, pth)加载权重- 切推理态:
model.cuda(); model.eval()
注意:
eval()很关键,BN/Dropout 行为不一致会让导出图不稳定。
Step B:准备导出用输入(dummy inputs)
导出 ONNX 的输入必须满足三点:
- shape 对得上:forward 里所有 reshape/view 都要能跑通
- dtype 对得上:fp16/fp32/int32 不能乱
- device 对得上:很多 CUDA 扩展要求输入必须在 GPU 上
复杂模型(时序 BEV、tracking)还会多一个要求:
- 状态要"滚起来":prev_bev、prev_tracks 等不是全 0 就能正常走图,你往往需要先跑几帧把状态打通,然后再 export。
Step C:调用 torch.onnx.export()
核心参数里,带 plugin 的导出最关键是这两个:
opset_version=16(根据后端支持选)operator_export_type=ONNX_FALLTHROUGH(重点:让不支持的算子"原样保留"为自定义节点)
另外常见搭配是:
do_constant_folding=False:复杂图/插件图里常关掉,避免导出器过度折叠导致后端解析出问题dynamic_axes={...}:把希望动态的维度标注出来(例如 track 数 N)
3. 遇到 TensorRT plugin 算子:三条路线怎么选?
当模型里出现 "ONNX 标准算子无法表达/后端不支持"的运算(比如 MS-DeformAttn),通常有三条路:
路线 1:导出自定义 ONNX 节点 + TensorRT plugin 接管(推荐)
这就是当前代码走的路:
- PyTorch 导出时,把复杂算子变成一个自定义 op 节点
- TensorRT 解析 ONNX 时,遇到该节点就用 plugin 来执行
优点:性能好、图简洁、工程上可控
缺点:要保证 TRT 侧 plugin 编译/注册链路正确
路线 2:把算子改写成 ONNX 原生子图
理论上可行,工程上通常"很痛":
- 图会爆炸式变大
- 性能通常明显不如 plugin
- 容易出现数值误差或行为差异
路线 3:换后端(ORT 自定义算子 / 直接 PyTorch runtime)
如果不执着 TRT,可以走 ORT 或 PyTorch fallback,但你这套脚本明显目标是 TRT,所以主线还是路线 1。
4. 你的 MultiScaleDeformableAttnTRT:为什么能导出成 ONNX 自定义节点?
核心代码是一个 torch.autograd.Function:
python
class _MultiScaleDeformableAttnFunction(Function):
@staticmethod
def symbolic(g, value, value_spatial_shapes, reference_points, sampling_offsets, attention_weights):
return g.op("MultiScaleDeformableAttnTRT", value, value_spatial_shapes, reference_points, sampling_offsets, attention_weights)
这段 symbolic() 是关键中的关键:
导出 ONNX 时,不会执行 forward 的 CUDA 扩展逻辑,而是调用 symbolic 来"画图"。
也就是说,只要 forward 路径里调用了:
python
_multi_scale_deformable_attn_gpu = _MultiScaleDeformableAttnFunction.apply
那么导出 ONNX 时就会出现一个节点:
- op_type =
MultiScaleDeformableAttnTRT - inputs = 传入的 5 个 tensor
最终 ONNX 图里不是一堆 scatter/gather/reshape 的展开实现,而是一个"壳节点",交给 TensorRT plugin 实现。
forward() 里到底干了什么?(理解插件输入输出必须看懂)
你 forward 的逻辑可以拆成三步:
1) reshape/view,把 offsets / weights 变成算子需要的布局
例如把 sampling_offsets reshape 成:
[bs, num_queries, num_heads, num_level, ..., 2]
2) 计算 sampling_locations(reference + normalized offsets)
核心就是:
sampling_locations = reference_points + sampling_offsets / offset_normalizer
其中 offset_normalizer 来自 value_spatial_shapes(每层 feature map 的 (h,w)),用于把 offset 归一化到统一坐标尺度。
3) 调 CUDA 扩展做真正 attention 聚合
最后调用:
python
ext_module.ms_deform_attn_forward(...)
输出再 reshape 成 [bs, num_queries, num_heads, channel]。
另外还做了两件工程常见事:
- fp16 输入时内部临时转 fp32 算(取决于 ext 的数值/实现)
- 计算
level_start_index(多层 feature map 在 value tensor 里的分段起点)
5. _MSDAPlugin / MSDAPlugin:为什么又多一套插件节点?
你还定义了:
python
return g.op("MSDAPlugin", value, spatial_shapes, level_start_index, sampling_locations, attention_weights)
它和 MultiScaleDeformableAttnTRT 的差异点非常明确:
MultiScaleDeformableAttnTRT(大插件)
- 输入:
reference_points、sampling_offsets等"更原始"的输入 - 插件内部可能需要负责更多逻辑(或你希望插件把所有重活都吃掉)
MSDAPlugin(小插件)
- 输入更"展开":你直接把
sampling_locations、level_start_index这些中间结果也传进 plugin - 好处:前处理(reshape/normalize/softmax)都可以在 ONNX 图里用标准算子实现
- 插件只负责最重的聚合核(attention 聚合)
这在工程部署里很常见:
把能用 ONNX 表达的部分留在图里,把最难/最慢的一坨用 plugin 接管。
6. export_onnx.py:按执行顺序拆解一遍
这份脚本不是"加载 → export",而是一个典型的时序导出套路:
先跑几帧,把状态滚起来;最后一帧切 forward 并导出。
下面按关键点解释。
6.1 输入/输出契约:input_shapes / output_shapes / dynamic_axes
定义了非常多输入:
- prev_track_instances0~13
- prev_bev
- timestamp / prev_timestamp
- l2g_r_mat / l2g_t(位姿)
- img / img_metas_xxx
- use_prev_bev / max_obj_id / command 等
并且定义了输出:
- 新的 prev_track_instances*_out
- bev_embed
- bboxes / scores / labels ...
- outs_planning 等
同时给 dynamic_axes 标注了很多 track 的第 0 维动态,这一点对部署非常重要:
- tracking 数量不是常数
- 不标动态就可能把 engine 固死在某个 N 上
6.2 关键循环:for iid in range(6) 先滚状态
逻辑是:
-
每帧构造 inputs:
- 如果磁盘有 npy 就加载(真实输入/真实历史状态)
- 否则就构造(如 image_shape、prev_bev、max_obj_id 等)
-
然后跑:
pythondummy_outputs = model.forward_uniad_trt(*inputs) -
再把输出保存成下一帧的 prev_* 输入(落盘 npy)
这里最关键的三条状态链是:
- prev_bev 链:下一帧的 prev_bev 来自上一帧的 bev_embed
- prev_track_instances 链:下一帧的 prev_tracks 来自上一帧输出
- 位姿/时间链:prev_l2g_*/prev_timestamp 连续传递
对(2,7,10)这些"不会进 ONNX 图"的 track 输入做了 dummy 填充,这是为了:
- PyTorch forward 要求参数齐全
- 但导出 ONNX 时这些变量不会被使用(所以可以给"形状正确"的假输入)
6.3 真正导出发生在 iid == 5:三件大事
当 iid == 5:
1) 把 forward 切到部署路径
python
model.forward = model.forward_uniad_trt
这一步决定了:
导出器捕获的是 forward_uniad_trt 的图,而不是训练/测试时的默认 forward。
2) torch.onnx.export() + FALLTHROUGH
你导出调用的关键参数组合是:
opset_version=16operator_export_type=ONNX_FALLTHROUGHdo_constant_folding=Falsedynamic_axes=dynamic_axes
这套组合基本就是"带 TRT plugin 的 ONNX 导出"常见配置。
3) onnx-graphsurgeon 修 Reshape allowzero
导出后你又修了一遍:
python
if node.op == "Reshape":
node.attrs["allowzero"] = 1
这通常是为了规避某些工具链对 reshape 里 "0 维度语义" 的解析差异(尤其不同版本 TRT/Parser/Graph 优化器之间很容易踩坑)。
7. 部署到 TensorRT:最常见的坑位清单(强烈建议逐条自查)
7.1 导出成功但 ONNX 没有 plugin 节点
原因往往是:
- 你的导出 forward 路径里没走到
Function.apply - 或者被某些条件分支绕开了
建议:
- 用 Netron / onnx.helper 看图里是否存在
MultiScaleDeformableAttnTRT或MSDAPlugin节点
7.2 TensorRT 解析 ONNX 报 "unsupported op / no importer"
这说明:
- ONNX 里确实有自定义节点
- 但 TRT 侧没找到对应 plugin creator
解决方向:
- 确认 plugin
.so被加载 - 确认 plugin name/namespace/version 与 ONNX 节点 op_type 匹配
- 确认 plugin registry 在构建 engine 前已注册
7.3 FP16/FP32 混乱导致 build 失败或精度异常
- 你的 PyTorch 侧 forward 对 fp16 做了 cast 逻辑
- TRT plugin 是否支持 fp16、输入 dtype 是否一致,需要严格对齐
建议:
- 明确一条主线:全 fp32 或明确开启 fp16(并保证 plugin 实现一致支持)
8. 一个最小"可复用"导出模板(带自定义算子保留)
如果你想把复杂脚本拆成最小骨架,核心就是:
python
import torch
from torch.onnx import OperatorExportTypes
model.cuda().eval()
dummy_inputs = (...) # tuple of cuda tensors
torch.onnx.export(
model,
dummy_inputs,
"model.onnx",
opset_version=16,
input_names=[...],
output_names=[...],
dynamic_axes={...},
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
)
只要自定义 Function.symbolic() 写对了,且 forward 路径能走到,就能在 ONNX 图里保留自定义节点,让 TensorRT plugin 接管执行。
模型导出与权重加载
python
model.forward = model.forward_uniad_trt
...
torch.onnx.export(model, inputs, ...)
model.forward = model.forward_uniad_trt 这一步本身完全不涉及"加载权重" 。它做的事情只有一个:把模型实例 model 的 forward() 方法指针,替换成另一个前向函数实现 ,让后面 torch.onnx.export() 捕获的是 forward_uniad_trt 这条图。
权重是在更早的阶段、通过 load_checkpoint(...) 进入 model 的参数里了。
1) 权重真正加载发生在哪里?
在你贴的 export_onnx.py 里,权重加载是这几行:
python
model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
...
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
这里的 args.checkpoint 就是 .pth 文件路径。
load_checkpoint(model, ...) 会把 checkpoint 里的 state_dict 写进 model 的各个参数张量(Parameters / Buffers)里,比如 conv 的 weight、bn 的 running_mean/running_var、transformer 的 projection 权重等等。
从这一刻开始,model 已经"带权重"了;后面无论你调用 model.forward(...) 还是 model.forward_uniad_trt(...),它们用的都是同一套参数。
可以理解为:
load_checkpoint改的是 model.parameters() 里的值model.forward = ...改的是 怎么用这些参数做计算(走哪条前向路径)
2) 为什么 forward 替换不会触发再加载?
因为 PyTorch 模型的"权重"不是放在 forward 里的,也不是 forward 调用时临时读取 .pth 文件的。
PyTorch 的权重(Parameter)是 nn.Module 的成员变量(例如 self.conv.weight),一旦 load_state_dict 成功,这些张量就常驻在内存/GPU 显存中。
所以:
- 你替换 forward 只是换了函数入口
forward_uniad_trt仍然会引用self.xxx里的参数- 不会发生再次 IO / 再次 load checkpoint
3) 这一步的真实作用:导出"部署专用前向图"
你脚本里这么写:
python
model.forward = model.forward_uniad_trt
torch.onnx.export(model, inputs, ...)
等价于告诉导出器:
"请把
model(inputs)时执行的那条路径,固定成forward_uniad_trt(*inputs),并把它转成 ONNX。"
这通常用于:
- 绕开训练时 forward 的控制流(loss、gt、augmentation 等)
- 绕开 Python 后处理(NMS、decode)或让输出更适合 TRT
- 保证导出图里包含你想要的 plugin 节点(例如 MSDA)
4) 常见误会:map_location='cpu' 会不会导致导出时"没权重"?
不会。
代码里是:
python
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
model = model.cuda()
这表示:
- checkpoint 先被读到 CPU(避免直接在加载时占用 GPU)
load_checkpoint把权重写进 model(此时 model 可能仍在 CPU)- 之后
model.cuda()把参数整体搬到 GPU
导出时权重就在 GPU 上正常参与计算和图捕获。
一句话总结:
权重是在
load_checkpoint(model, args.checkpoint)时加载进model的;
model.forward = model.forward_uniad_trt只是切换前向路径,不会也不需要再加载权重。