1. 背景
最近在将一个用 PyTorch 实现的睡眠分期模型(基于多尺度 CNN、Transformer 与注意力机制,结构较复杂)导出为 ONNX 格式,以便在不同框架中部署推理。模型代码包含大量自定义算子与模块,导出过程中接连遇到了三个 ONNX 不支持的算子错误,最终通过逐步替换和升级 opset 版本解决了问题。本文将完整记录这一过程,并提供通用的排查思路与解决方案。
2. 导出环境
-
PyTorch 1.10+
-
ONNX opset 11(初始)
-
Python 3.7
-
模型输入:三个参数,分别为 (batch, epochs, 1, samples) 的 EEG 数据、(batch, epochs, samples) 的辅助特征以及一个字符串(实际未使用)
-
导出代码简化如下:
(myenv) ga@lab406:/mnt/DEV2/ga/remix_eog_self_supervise/depoly$ python -c "import torch; print('PyTorch:', torch.version); import onnx; print('onnx:', onnx.version); import onnxruntime; print('onnxruntime:', onnxruntime.version); import einops; print('einops:', einops.version)"
PyTorch: 1.11.0+cu113
onnx: 1.14.1
onnxruntime: 1.14.1
einops: 0.6.1
model = FeatureExtractor_Deepseek_MLA()`
`model.eval()`
`x1 = torch.randn(1,` `200,` `1,` `3000)`
`x2 = torch.randn(1,` `200,` `3000)`
`x3 =` `""`
`torch.onnx.export(`
` model,`
`(x1, x2, x3),`
`"my_model.onnx",`
` input_names=["input1",` `"input2",` `"input3"],`
` output_names=["output"],`
` opset_version=11`
`)`
`
3. 第一关: torch.diff 不被 opset 11 支持
3.1 错误信息
RuntimeError: Exporting the operator diff to ONNX opset version 11 is not supported.`
`
3.2 错误定位
出现在自定义函数 find_extrema_3d_corrected 中:
def` `find_extrema_3d_corrected(data, enhance_factor=10.0,` `...):`
` data_flat = data.reshape(-1, F)`
` diff = torch.diff(data_flat, dim=1)` `# ❌ 这行报错`
`...`
`
3.3 解决方案
torch.diff 实际上就是相邻元素相减,可以直接用切片改写,无需依赖更高 opset 版本:
# 替换为`
`diff = data_flat[:,` `1:]` `- data_flat[:,` `:-1]` `# 形状 [N, F-1]`
`
该替换完全等价,且对任何 ONNX opset 版本都友好。
4. 第二关: torch.triu 要求 opset ≥ 14
4.1 错误信息
RuntimeError: Exporting the operator triu to ONNX opset version 11 is not supported.`
`Support for this operator was added in version 14, try exporting with this version.`
`
4.2 错误定位
出现在 MultiheadDiffAttn 的 forward 方法中,用于构造因果注意力掩码:
attn_mask = torch.triu(`
` torch.zeros([tgt_len, src_len]).float().fill_(float("-inf")).type_as(attn_weights),`
`1` `+ offset,`
`)`
`
4.3 解决方案
最直接的解法:将导出时的 opset_version 提升至 14 。
因为 ONNX 从 opset 14 开始原生支持 Trilu 算子(含 triu 和 tril)。
修改导出代码:
torch.onnx.export(`
`...,`
` opset_version=14` `# ✅ 升级到这里`
`)`
`
备选方案 :若因环境限制不能升级 opset,也可用基础运算手动生成上三角掩码:
mask = torch.ones(tgt_len, src_len, device=attn_weights.device)`
`mask = torch.tril(mask, diagonal=offset)` `# 下三角`
`attn_mask =` `(1` `- mask)` `*` `float('-inf')` `# 上三角部分置 -inf`
`
但注意 tril 同样需要 opset 14,所以该方案与升级 opset 本质相同。
5. 第三关: torch.nan_to_num 在 opset 14 中仍不被支持
5.1 错误信息
RuntimeError: Exporting the operator nan_to_num to ONNX opset version 14 is not supported.`
`
5.2 错误定位
同样在 MultiheadDiffAttn.forward 中:
attn_weights = torch.nan_to_num(attn_weights)` `# ❌ 这行报错`
`
当 softmax 的输入全部为 -inf 时,输出会产生 NaN,此处原意是将其替换为 0。
5.3 解决方案
用 torch.where 与 torch.isnan 手动实现相同功能:
# 替换 nan_to_num,仅处理 NaN`
`attn_weights = torch.where(`
` torch.isnan(attn_weights),`
` torch.zeros_like(attn_weights),`
` attn_weights`
`)`
`
若也要处理无穷值(实际上此处不会出现),可继续添加:
attn_weights = torch.where(`
` torch.isinf(attn_weights)` `&` `(attn_weights >` `0),`
` torch.full_like(attn_weights,` `float('inf')),`
` attn_weights`
`)`
`attn_weights = torch.where(`
` torch.isinf(attn_weights)` `&` `(attn_weights <` `0),`
` torch.full_like(attn_weights,` `float('-inf')),`
` attn_weights`
`)`
`
这些基础算子(torch.where、torch.isnan、torch.isinf)在 opset 11 即已支持,因此完美兼容。
6. 导出成功
经过以上三步修改后,再次运行导出脚本,控制台输出:
✅ 你的模型已成功转为 ONNX!`
`
导出成功生成的 my_model.onnx 文件可直接用于 ONNX Runtime 推理。
7. 其他注意事项与警告解读
导出过程中还出现了一些 TracerWarning 和 UserWarning,虽然不影响功能,但值得关注:
- TracerWarning: Converting a tensor to a Python boolean
例如 if F > 1 会在 Trace 时固定为常量。若推理时输入的特征维度不变则无碍,否则应避免此类 Python 条件判断。
- floordiv is deprecated
PyTorch 新版改进了整数除法行为,当前代码仍能正确运行,可忽略或改用 torch.div(..., rounding_mode='trunc')。
- LSTM 初始状态警告
模型包含 LSTM,导出时初始状态 h0/c0 被固化。若推理时 batch size 固定为导出时的值则安全;否则需将初始状态作为模型输入。
- 高级索引警告
aten::index 被拆分为多个 ONNX 算子,需确保索引中没有负数,否则结果可能错误。
快速验证 ONNX 模型:
import onnxruntime`
`import numpy as np`
`session = onnxruntime.InferenceSession("my_model.onnx")`
`out = session.run(None,` `{`
`"input1": np.random.randn(1,` `200,` `1,` `3000).astype(np.float32),`
`"input2": np.random.randn(1,` `200,` `3000).astype(np.float32),`
`"input3": np.array("")` `# 字符串可能不支持,实际未使用时可改传零张量`
`})`
`print("Output shape:", out[0].shape)`
`
8. 总结
复杂 PyTorch 模型导出 ONNX 时,核心挑战在于算子兼容性。本次踩坑经验可归纳为:
- 优先升级 ONNX opset 版本 ,新版本支持更多算子(如 triu 需要 opset 14)。
- 对于无法通过升级解决的算子 (如 diff、nan_to_num),用等价的基础操作替换。
- 善用 PyTorch 的 torch.onnx 文档 查看各算子的支持情况。
- 导出后务必用 onnx.checker 和实际数据验证 模型结构和输出正确性。
通过这三关,我们成功将一个包含 CNN、LSTM、Multi-head Attention、稀疏注意力等组件的混合模型导出为 ONNX。希望本文能为遇到类似问题的读者提供清晰的解决思路。
附录:主要错误与修改速查表
|------------------|---------------------------|--------------------------------------|
| 错误算子 | 出现位置 | 最终解决方案 |
| torch.diff | find_extrema_3d_corrected | 切片相减:x:,1:-x:,:-1 |
| torch.triu | MultiheadDiffAttn.forward | 升级 opset 至 14 |
| torch.nan_to_num | MultiheadDiffAttn.forward | 用 torch.where(torch.isnan(...), ...) |