从PyTorch `.pth` 导出 ONNX图文件

PyTorch .pth 导出 ONNX:带 TensorRT Plugin 的工程化路线(以 MultiScaleDeformableAttnTRT 为例)

在纯 CNN/ResNet 这类模型里,torch.onnx.export() 往往是一行命令的事;但到了 UniAD / Deformable DETR / BEV 时序模型这种级别,事情立刻变复杂:

  • 模型里有 CUDA 扩展算子(例如 Multi-Scale Deformable Attention)
  • 计算图里带 时序状态(prev_bev、track instances、timestamp 等)
  • 并且部署目标通常是 TensorRT ,想要性能就得用 TRT plugin

这篇文章就从「.pth 到底是什么」开始,讲到「导出 ONNX 的稳定流程」,再深入到「遇到 TRT plugin 算子怎么做」以及你给出的两种插件节点(MultiScaleDeformableAttnTRT / MSDAPlugin)分别在干什么。


1. .pth ≠ 计算图,.onnx 才是"图 + 权重"

先澄清一个误区:
.pth 只是 checkpoint ,通常是 state_dict(权重参数)+ meta(训练信息),它并不包含一张可部署的静态计算图。

.onnx 是静态计算图(节点、边、属性)+ 权重常量,它是为了跨框架/跨后端而设计的部署格式。

所以"导出 ONNX"这件事,本质不是把 .pth 文件格式转换一下,而是:

加载模型结构 + 加载权重 → 用一组输入跑通 forward → 让导出器把 forward 的算子序列变成 ONNX 图


2. 通用导出流程(简单模型/复杂模型都适用)

Step A:构建模型并 load checkpoint

典型步骤如下:

  1. 根据 config 构建模型结构
  2. load_checkpoint(model, pth) 加载权重
  3. 切推理态:model.cuda(); model.eval()

注意:eval() 很关键,BN/Dropout 行为不一致会让导出图不稳定。


Step B:准备导出用输入(dummy inputs)

导出 ONNX 的输入必须满足三点:

  • shape 对得上:forward 里所有 reshape/view 都要能跑通
  • dtype 对得上:fp16/fp32/int32 不能乱
  • device 对得上:很多 CUDA 扩展要求输入必须在 GPU 上

复杂模型(时序 BEV、tracking)还会多一个要求:

  • 状态要"滚起来":prev_bev、prev_tracks 等不是全 0 就能正常走图,你往往需要先跑几帧把状态打通,然后再 export。

Step C:调用 torch.onnx.export()

核心参数里,带 plugin 的导出最关键是这两个:

  • opset_version=16(根据后端支持选)
  • operator_export_type=ONNX_FALLTHROUGH(重点:让不支持的算子"原样保留"为自定义节点)

另外常见搭配是:

  • do_constant_folding=False:复杂图/插件图里常关掉,避免导出器过度折叠导致后端解析出问题
  • dynamic_axes={...}:把希望动态的维度标注出来(例如 track 数 N)

3. 遇到 TensorRT plugin 算子:三条路线怎么选?

当模型里出现 "ONNX 标准算子无法表达/后端不支持"的运算(比如 MS-DeformAttn),通常有三条路:

路线 1:导出自定义 ONNX 节点 + TensorRT plugin 接管(推荐)

这就是当前代码走的路:

  • PyTorch 导出时,把复杂算子变成一个自定义 op 节点
  • TensorRT 解析 ONNX 时,遇到该节点就用 plugin 来执行

优点:性能好、图简洁、工程上可控

缺点:要保证 TRT 侧 plugin 编译/注册链路正确


路线 2:把算子改写成 ONNX 原生子图

理论上可行,工程上通常"很痛":

  • 图会爆炸式变大
  • 性能通常明显不如 plugin
  • 容易出现数值误差或行为差异

路线 3:换后端(ORT 自定义算子 / 直接 PyTorch runtime)

如果不执着 TRT,可以走 ORT 或 PyTorch fallback,但你这套脚本明显目标是 TRT,所以主线还是路线 1。


4. 你的 MultiScaleDeformableAttnTRT:为什么能导出成 ONNX 自定义节点?

核心代码是一个 torch.autograd.Function

python 复制代码
class _MultiScaleDeformableAttnFunction(Function):
    @staticmethod
    def symbolic(g, value, value_spatial_shapes, reference_points, sampling_offsets, attention_weights):
        return g.op("MultiScaleDeformableAttnTRT", value, value_spatial_shapes, reference_points, sampling_offsets, attention_weights)

这段 symbolic() 是关键中的关键:
导出 ONNX 时,不会执行 forward 的 CUDA 扩展逻辑,而是调用 symbolic 来"画图"。

也就是说,只要 forward 路径里调用了:

python 复制代码
_multi_scale_deformable_attn_gpu = _MultiScaleDeformableAttnFunction.apply

那么导出 ONNX 时就会出现一个节点:

  • op_type = MultiScaleDeformableAttnTRT
  • inputs = 传入的 5 个 tensor

最终 ONNX 图里不是一堆 scatter/gather/reshape 的展开实现,而是一个"壳节点",交给 TensorRT plugin 实现。


forward() 里到底干了什么?(理解插件输入输出必须看懂)

你 forward 的逻辑可以拆成三步:

1) reshape/view,把 offsets / weights 变成算子需要的布局

例如把 sampling_offsets reshape 成:

  • [bs, num_queries, num_heads, num_level, ..., 2]
2) 计算 sampling_locations(reference + normalized offsets)

核心就是:

  • sampling_locations = reference_points + sampling_offsets / offset_normalizer

其中 offset_normalizer 来自 value_spatial_shapes(每层 feature map 的 (h,w)),用于把 offset 归一化到统一坐标尺度。

3) 调 CUDA 扩展做真正 attention 聚合

最后调用:

python 复制代码
ext_module.ms_deform_attn_forward(...)

输出再 reshape 成 [bs, num_queries, num_heads, channel]

另外还做了两件工程常见事:

  • fp16 输入时内部临时转 fp32 算(取决于 ext 的数值/实现)
  • 计算 level_start_index(多层 feature map 在 value tensor 里的分段起点)

5. _MSDAPlugin / MSDAPlugin:为什么又多一套插件节点?

你还定义了:

python 复制代码
return g.op("MSDAPlugin", value, spatial_shapes, level_start_index, sampling_locations, attention_weights)

它和 MultiScaleDeformableAttnTRT 的差异点非常明确:

MultiScaleDeformableAttnTRT(大插件)

  • 输入:reference_pointssampling_offsets 等"更原始"的输入
  • 插件内部可能需要负责更多逻辑(或你希望插件把所有重活都吃掉)

MSDAPlugin(小插件)

  • 输入更"展开":你直接把 sampling_locationslevel_start_index 这些中间结果也传进 plugin
  • 好处:前处理(reshape/normalize/softmax)都可以在 ONNX 图里用标准算子实现
  • 插件只负责最重的聚合核(attention 聚合)

这在工程部署里很常见:
把能用 ONNX 表达的部分留在图里,把最难/最慢的一坨用 plugin 接管。


6. export_onnx.py:按执行顺序拆解一遍

这份脚本不是"加载 → export",而是一个典型的时序导出套路:

先跑几帧,把状态滚起来;最后一帧切 forward 并导出。

下面按关键点解释。


6.1 输入/输出契约:input_shapes / output_shapes / dynamic_axes

定义了非常多输入:

  • prev_track_instances0~13
  • prev_bev
  • timestamp / prev_timestamp
  • l2g_r_mat / l2g_t(位姿)
  • img / img_metas_xxx
  • use_prev_bev / max_obj_id / command 等

并且定义了输出:

  • 新的 prev_track_instances*_out
  • bev_embed
  • bboxes / scores / labels ...
  • outs_planning 等

同时给 dynamic_axes 标注了很多 track 的第 0 维动态,这一点对部署非常重要:

  • tracking 数量不是常数
  • 不标动态就可能把 engine 固死在某个 N 上

6.2 关键循环:for iid in range(6) 先滚状态

逻辑是:

  • 每帧构造 inputs:

    • 如果磁盘有 npy 就加载(真实输入/真实历史状态)
    • 否则就构造(如 image_shape、prev_bev、max_obj_id 等)
  • 然后跑:

    python 复制代码
    dummy_outputs = model.forward_uniad_trt(*inputs)
  • 再把输出保存成下一帧的 prev_* 输入(落盘 npy)

这里最关键的三条状态链是:

  1. prev_bev 链:下一帧的 prev_bev 来自上一帧的 bev_embed
  2. prev_track_instances 链:下一帧的 prev_tracks 来自上一帧输出
  3. 位姿/时间链:prev_l2g_*/prev_timestamp 连续传递

对(2,7,10)这些"不会进 ONNX 图"的 track 输入做了 dummy 填充,这是为了:

  • PyTorch forward 要求参数齐全
  • 但导出 ONNX 时这些变量不会被使用(所以可以给"形状正确"的假输入)

6.3 真正导出发生在 iid == 5:三件大事

iid == 5

1) 把 forward 切到部署路径
python 复制代码
model.forward = model.forward_uniad_trt

这一步决定了:

导出器捕获的是 forward_uniad_trt 的图,而不是训练/测试时的默认 forward。

2) torch.onnx.export() + FALLTHROUGH

你导出调用的关键参数组合是:

  • opset_version=16
  • operator_export_type=ONNX_FALLTHROUGH
  • do_constant_folding=False
  • dynamic_axes=dynamic_axes

这套组合基本就是"带 TRT plugin 的 ONNX 导出"常见配置。

3) onnx-graphsurgeon 修 Reshape allowzero

导出后你又修了一遍:

python 复制代码
if node.op == "Reshape":
    node.attrs["allowzero"] = 1

这通常是为了规避某些工具链对 reshape 里 "0 维度语义" 的解析差异(尤其不同版本 TRT/Parser/Graph 优化器之间很容易踩坑)。


7. 部署到 TensorRT:最常见的坑位清单(强烈建议逐条自查)

7.1 导出成功但 ONNX 没有 plugin 节点

原因往往是:

  • 你的导出 forward 路径里没走到 Function.apply
  • 或者被某些条件分支绕开了

建议:

  • 用 Netron / onnx.helper 看图里是否存在 MultiScaleDeformableAttnTRTMSDAPlugin 节点

7.2 TensorRT 解析 ONNX 报 "unsupported op / no importer"

这说明:

  • ONNX 里确实有自定义节点
  • 但 TRT 侧没找到对应 plugin creator

解决方向:

  • 确认 plugin .so 被加载
  • 确认 plugin name/namespace/version 与 ONNX 节点 op_type 匹配
  • 确认 plugin registry 在构建 engine 前已注册

7.3 FP16/FP32 混乱导致 build 失败或精度异常

  • 你的 PyTorch 侧 forward 对 fp16 做了 cast 逻辑
  • TRT plugin 是否支持 fp16、输入 dtype 是否一致,需要严格对齐

建议:

  • 明确一条主线:全 fp32 或明确开启 fp16(并保证 plugin 实现一致支持)

8. 一个最小"可复用"导出模板(带自定义算子保留)

如果你想把复杂脚本拆成最小骨架,核心就是:

python 复制代码
import torch
from torch.onnx import OperatorExportTypes

model.cuda().eval()
dummy_inputs = (...)  # tuple of cuda tensors

torch.onnx.export(
    model,
    dummy_inputs,
    "model.onnx",
    opset_version=16,
    input_names=[...],
    output_names=[...],
    dynamic_axes={...},
    do_constant_folding=False,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
)

只要自定义 Function.symbolic() 写对了,且 forward 路径能走到,就能在 ONNX 图里保留自定义节点,让 TensorRT plugin 接管执行。


模型导出与权重加载

python 复制代码
model.forward = model.forward_uniad_trt
...
torch.onnx.export(model, inputs, ...)

model.forward = model.forward_uniad_trt 这一步本身完全不涉及"加载权重" 。它做的事情只有一个:把模型实例 modelforward() 方法指针,替换成另一个前向函数实现 ,让后面 torch.onnx.export() 捕获的是 forward_uniad_trt 这条图。

权重是在更早的阶段、通过 load_checkpoint(...) 进入 model 的参数里了。


1) 权重真正加载发生在哪里?

在你贴的 export_onnx.py 里,权重加载是这几行:

python 复制代码
model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
...
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')

这里的 args.checkpoint 就是 .pth 文件路径。

load_checkpoint(model, ...) 会把 checkpoint 里的 state_dict 写进 model 的各个参数张量(Parameters / Buffers)里,比如 conv 的 weight、bn 的 running_mean/running_var、transformer 的 projection 权重等等。

从这一刻开始,model 已经"带权重"了;后面无论你调用 model.forward(...) 还是 model.forward_uniad_trt(...),它们用的都是同一套参数。

可以理解为:

  • load_checkpoint 改的是 model.parameters() 里的值
  • model.forward = ... 改的是 怎么用这些参数做计算(走哪条前向路径)

2) 为什么 forward 替换不会触发再加载?

因为 PyTorch 模型的"权重"不是放在 forward 里的,也不是 forward 调用时临时读取 .pth 文件的。

PyTorch 的权重(Parameter)是 nn.Module 的成员变量(例如 self.conv.weight),一旦 load_state_dict 成功,这些张量就常驻在内存/GPU 显存中。

所以:

  • 你替换 forward 只是换了函数入口
  • forward_uniad_trt 仍然会引用 self.xxx 里的参数
  • 不会发生再次 IO / 再次 load checkpoint

3) 这一步的真实作用:导出"部署专用前向图"

你脚本里这么写:

python 复制代码
model.forward = model.forward_uniad_trt
torch.onnx.export(model, inputs, ...)

等价于告诉导出器:

"请把 model(inputs) 时执行的那条路径,固定成 forward_uniad_trt(*inputs),并把它转成 ONNX。"

这通常用于:

  • 绕开训练时 forward 的控制流(loss、gt、augmentation 等)
  • 绕开 Python 后处理(NMS、decode)或让输出更适合 TRT
  • 保证导出图里包含你想要的 plugin 节点(例如 MSDA)

4) 常见误会:map_location='cpu' 会不会导致导出时"没权重"?

不会。

代码里是:

python 复制代码
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
model = model.cuda()

这表示:

  • checkpoint 先被读到 CPU(避免直接在加载时占用 GPU)
  • load_checkpoint 把权重写进 model(此时 model 可能仍在 CPU)
  • 之后 model.cuda() 把参数整体搬到 GPU

导出时权重就在 GPU 上正常参与计算和图捕获。


一句话总结:

权重是在 load_checkpoint(model, args.checkpoint) 时加载进 model 的;
model.forward = model.forward_uniad_trt 只是切换前向路径,不会也不需要再加载权重。

相关推荐
AKAMAI2 小时前
Akamai Cloud客户案例 | 全球教育科技公司TalentSprint依托Akamai云计算服务实现八倍增长并有效控制成本
人工智能·云计算
蛋王派2 小时前
GME-多模态嵌入 训练和工程落地的逻辑解析
人工智能
写代码的【黑咖啡】2 小时前
Python 中的 Requests 库:轻松进行 HTTP 请求
开发语言·python·http
栗子叶2 小时前
Spring 中 Servlet 容器和 Python FastAPI 对比
python·spring·servlet·fastapi
Duang007_2 小时前
拆解 Transformer 的灵魂:全景解析 Attention 家族 (Self, Cross, Masked & GQA)
人工智能·深度学习·transformer
磊-3 小时前
AI Agent 学习计划(一)
人工智能·学习
杨杨杨大侠3 小时前
DeepAgents 框架深度解析:从理论到实践的智能代理架构
后端·python·llm
不会打球的摄影师不是好程序员3 小时前
dify实战-个人知识库搭建
人工智能
袁袁袁袁满3 小时前
Python读取doc文件打印内容
开发语言·python·python读取doc文件
xixixi777773 小时前
对 两种不同AI范式——Transformer 和 LSTM 进行解剖和对比
人工智能·深度学习·大模型·lstm·transformer·智能·前沿