PyTorch 复杂模型转 ONNX 踩坑纪实:从 diff 到 nan_to_num 的三关突破

1. 背景

最近在将一个用 PyTorch 实现的睡眠分期模型(基于多尺度 CNN、Transformer 与注意力机制,结构较复杂)导出为 ONNX 格式,以便在不同框架中部署推理。模型代码包含大量自定义算子与模块,导出过程中接连遇到了三个 ONNX 不支持的算子错误,最终通过逐步替换和升级 opset 版本解决了问题。本文将完整记录这一过程,并提供通用的排查思路与解决方案。

2. 导出环境

  • PyTorch 1.10+

  • ONNX opset 11(初始)

  • Python 3.7

  • 模型输入:三个参数,分别为 (batch, epochs, 1, samples) 的 EEG 数据、(batch, epochs, samples) 的辅助特征以及一个字符串(实际未使用)

  • 导出代码简化如下:

    (myenv) ga@lab406:/mnt/DEV2/ga/remix_eog_self_supervise/depoly$ python -c "import torch; print('PyTorch:', torch.version); import onnx; print('onnx:', onnx.version); import onnxruntime; print('onnxruntime:', onnxruntime.version); import einops; print('einops:', einops.version)"
    PyTorch: 1.11.0+cu113
    onnx: 1.14.1
    onnxruntime: 1.14.1
    einops: 0.6.1

复制代码
model = FeatureExtractor_Deepseek_MLA()`
`model.eval()`

`x1 = torch.randn(1,` `200,` `1,` `3000)`
`x2 = torch.randn(1,` `200,` `3000)`
`x3 =` `""`

`torch.onnx.export(`
`    model,`
    `(x1, x2, x3),`
    `"my_model.onnx",`
`    input_names=["input1",` `"input2",` `"input3"],`
`    output_names=["output"],`
`    opset_version=11`
`)`
`

3. 第一关: torch.diff 不被 opset 11 支持

3.1 错误信息

复制代码
RuntimeError: Exporting the operator diff to ONNX opset version 11 is not supported.`
`

3.2 错误定位

出现在自定义函数 find_extrema_3d_corrected 中:

复制代码
def` `find_extrema_3d_corrected(data, enhance_factor=10.0,` `...):`
`    data_flat = data.reshape(-1, F)`
`    diff = torch.diff(data_flat, dim=1)`    `# ❌ 这行报错`
    `...`
`

3.3 解决方案

torch.diff 实际上就是相邻元素相减,可以直接用切片改写,无需依赖更高 opset 版本:

复制代码
# 替换为`
`diff = data_flat[:,` `1:]` `- data_flat[:,` `:-1]`   `# 形状 [N, F-1]`
`

该替换完全等价,且对任何 ONNX opset 版本都友好。

4. 第二关: torch.triu 要求 opset ≥ 14

4.1 错误信息

复制代码
RuntimeError: Exporting the operator triu to ONNX opset version 11 is not supported.`
`Support for this operator was added in version 14, try exporting with this version.`
`

4.2 错误定位

出现在 MultiheadDiffAttn 的 forward 方法中,用于构造因果注意力掩码:

复制代码
attn_mask = torch.triu(`
`    torch.zeros([tgt_len, src_len]).float().fill_(float("-inf")).type_as(attn_weights),`
    `1` `+ offset,`
`)`
`

4.3 解决方案

最直接的解法:将导出时的 opset_version 提升至 14

因为 ONNX 从 opset 14 开始原生支持 Trilu 算子(含 triu 和 tril)。

修改导出代码:

复制代码
torch.onnx.export(`
    `...,`
`    opset_version=14`   `# ✅ 升级到这里`
`)`
`

备选方案 :若因环境限制不能升级 opset,也可用基础运算手动生成上三角掩码:

复制代码
mask = torch.ones(tgt_len, src_len, device=attn_weights.device)`
`mask = torch.tril(mask, diagonal=offset)`       `# 下三角`
`attn_mask =` `(1` `- mask)` `*` `float('-inf')`         `# 上三角部分置 -inf`
`

但注意 tril 同样需要 opset 14,所以该方案与升级 opset 本质相同。

5. 第三关: torch.nan_to_num 在 opset 14 中仍不被支持

5.1 错误信息

复制代码
RuntimeError: Exporting the operator nan_to_num to ONNX opset version 14 is not supported.`
`

5.2 错误定位

同样在 MultiheadDiffAttn.forward 中:

复制代码
attn_weights = torch.nan_to_num(attn_weights)`   `# ❌ 这行报错`
`

当 softmax 的输入全部为 -inf 时,输出会产生 NaN,此处原意是将其替换为 0。

5.3 解决方案

用 torch.where 与 torch.isnan 手动实现相同功能:

复制代码
# 替换 nan_to_num,仅处理 NaN`
`attn_weights = torch.where(`
`    torch.isnan(attn_weights),`
`    torch.zeros_like(attn_weights),`
`    attn_weights`
`)`
`

若也要处理无穷值(实际上此处不会出现),可继续添加:

复制代码
attn_weights = torch.where(`
`    torch.isinf(attn_weights)` `&` `(attn_weights >` `0),`
`    torch.full_like(attn_weights,` `float('inf')),`
`    attn_weights`
`)`
`attn_weights = torch.where(`
`    torch.isinf(attn_weights)` `&` `(attn_weights <` `0),`
`    torch.full_like(attn_weights,` `float('-inf')),`
`    attn_weights`
`)`
`

这些基础算子(torch.where、torch.isnan、torch.isinf)在 opset 11 即已支持,因此完美兼容。

6. 导出成功

经过以上三步修改后,再次运行导出脚本,控制台输出:

复制代码
✅ 你的模型已成功转为 ONNX!`
`

导出成功生成的 my_model.onnx 文件可直接用于 ONNX Runtime 推理。

7. 其他注意事项与警告解读

导出过程中还出现了一些 TracerWarning 和 UserWarning,虽然不影响功能,但值得关注:

  • TracerWarning: Converting a tensor to a Python boolean

例如 if F > 1 会在 Trace 时固定为常量。若推理时输入的特征维度不变则无碍,否则应避免此类 Python 条件判断。

  • floordiv is deprecated

PyTorch 新版改进了整数除法行为,当前代码仍能正确运行,可忽略或改用 torch.div(..., rounding_mode='trunc')。

  • LSTM 初始状态警告

模型包含 LSTM,导出时初始状态 h0/c0 被固化。若推理时 batch size 固定为导出时的值则安全;否则需将初始状态作为模型输入。

  • 高级索引警告

aten::index 被拆分为多个 ONNX 算子,需确保索引中没有负数,否则结果可能错误。

快速验证 ONNX 模型:

复制代码
import onnxruntime`
`import numpy as np`

`session = onnxruntime.InferenceSession("my_model.onnx")`
`out = session.run(None,` `{`
    `"input1": np.random.randn(1,` `200,` `1,` `3000).astype(np.float32),`
    `"input2": np.random.randn(1,` `200,` `3000).astype(np.float32),`
    `"input3": np.array("")`  `# 字符串可能不支持,实际未使用时可改传零张量`
`})`
`print("Output shape:", out[0].shape)`
`

8. 总结

复杂 PyTorch 模型导出 ONNX 时,核心挑战在于算子兼容性。本次踩坑经验可归纳为:

  1. 优先升级 ONNX opset 版本 ,新版本支持更多算子(如 triu 需要 opset 14)。
  2. 对于无法通过升级解决的算子 (如 diff、nan_to_num),用等价的基础操作替换。
  3. 善用 PyTorch 的 torch.onnx 文档 查看各算子的支持情况。
  4. 导出后务必用 onnx.checker 和实际数据验证 模型结构和输出正确性。

通过这三关,我们成功将一个包含 CNN、LSTM、Multi-head Attention、稀疏注意力等组件的混合模型导出为 ONNX。希望本文能为遇到类似问题的读者提供清晰的解决思路。

附录:主要错误与修改速查表

|------------------|---------------------------|--------------------------------------|
| 错误算子 | 出现位置 | 最终解决方案 |
| torch.diff | find_extrema_3d_corrected | 切片相减:x:,1:-x:,:-1 |
| torch.triu | MultiheadDiffAttn.forward | 升级 opset 至 14 |
| torch.nan_to_num | MultiheadDiffAttn.forward | 用 torch.where(torch.isnan(...), ...) |

相关推荐
不爱吃糖の糖糖5 小时前
RAG 06:RAG 多路召回与检索优化策略详解
人工智能·embedding
Xuantong_905 小时前
玄同科技亮相2026金砖新工业革命展览会,智启全球合作新篇
大数据·人工智能
python在学ing5 小时前
Django框架学习笔记:从零基础到项目实战
数据库·python·django·sqlite
曾经我也有梦想5 小时前
机器学习入门(四):三种学习方式 + 数据从原料到模型
人工智能
PAK向日葵6 小时前
从零实现 Python 虚拟机(二):S.A.A.U.S.O 的总体架构设计
c++·python
独自归家的兔6 小时前
AI界的 GitHub?Hugging Face 全面解析
人工智能·github
程序员小远6 小时前
系统性能指标全解析
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·性能测试
一次旅行6 小时前
全场景AI智能体工作台WorkBuddy实战操作详解
人工智能
逻辑君6 小时前
Foresight研究报告【20260009】
人工智能