PyTorch 转 ONNX 实用教程

PyTorch 转 ONNX 实用教程

1. 模型转换使用场景

ONNX 是一个用于机器学习模型的开放格式,旨在解决不同框架之间的模型互操作性和跨平台部署问题。它是一个中间表示格式,当你使用一个框架训练模型,但需要在另一个不同的框架中运行它时,可以使用 ONNX 进行转换。或者当你需要将模型部署到不同的硬件设备上时,例如从云端的 GPU 迁移到边缘设备的 CPU 时,ONNX 可以提供一个通用的部署桥梁。

2. pytorch转onnx的流程

2.1 环境准备

  • 安装依赖:pip install torch onnx onnxruntime onnxsim
  • 可选(量化/优化):pip install onnxruntime-gpu onnxruntime-tools onnxruntime-extensions

2.2 导出方法 torch.onnx.export(经典方式)

关键要点

  • 调用前设置 model.eval(),避免 Dropout/BatchNorm 行为不一致。
  • 导出时提供"样例输入",可指定动态维度 dynamic_axes
  • 选择合适 opset_version(常用 13/17),并用 ONNX Runtime 对齐输出。

代码示例

以下代码包含:1)pytorch模型转onnx模型部分 2)模型验证部分

python 复制代码
import torch
import torch.nn as nn
import onnx, onnxruntime as ort
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)
    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = Net().eval()
dummy = torch.randn(1, 3, 224, 224)  # 样例输入

torch.onnx.export(
    model, dummy, "model.onnx",
    input_names=["input"], output_names=["logits"],
    opset_version=13, do_constant_folding=True,
    dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
                  "logits": {0: "batch"}}
)

# 结构检查
onnx.checker.check_model(onnx.load("model.onnx"))

# 结果校验(与 PyTorch 对齐)
with torch.no_grad():
    torch_out = model(dummy).cpu().numpy()
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
ort_out, = sess.run(None, {"input": dummy.cpu().numpy()})
print("max abs diff:", np.max(np.abs(torch_out - ort_out)))

多输入/多输出

python 复制代码
y1, y2 = model(x1, x2)
torch.onnx.export(
    model, (x1, x2), "m.onnx",
    input_names=["x1", "x2"], output_names=["y1", "y2"],
    dynamic_axes={"x1": {0: "batch"}, "x2": {0: "batch"},
                  "y1": {0: "batch"}, "y2": {0: "batch"}}
)

校验与对齐输出

  • 结构校验:onnx.checker.check_model(onnx.load("model.onnx"))
  • 数值对齐:用相同输入在 PyTorch 与 ORT 推理,比较 np.allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5);不同算子/后端可能需放宽阈值。
  • 注意:输入需在 CPU 且为 numpy;混合精度需在两侧一致。

3. 常见问题与建议

参数说明

  • opset_version: 13 兼容性好;若需要新算子/特性可选 17+,确保推理引擎支持。
  • do_constant_folding=True: 常量折叠,简化图。
  • export_params=True: 将权重打包进 ONNX(默认)。
  • 大模型(>2GB):use_external_data_format=True,并指定 all_tensors_to_one_file=True 控制权重落盘方式。
  • 训练/推理模式:确保 model.eval();如确需训练图,设置 training=torch.onnx.TrainingMode.TRAINING(受限)。

常见问题

  • 未设 eval():导出或对齐误差大;务必 model.eval()
  • 控制流/数据依赖:经典导出可能失败或数值偏差;优先试 dynamo_export
  • 不支持的算子:升级 torch/onnxruntime,尝试更高 opset;或替换为等价算子。
  • 上采样/插值:建议使用明确的 sizescale_factor,并确保 align_corners 与后端一致;优选 opset>=11
  • 自定义算子:经典导出可通过 register_custom_op_symbolic 提供 symbolic;部署侧需实现同域自定义算子。
  • 设备/精度不一致:确保导出与验证使用相同 dtype(FP32/FP16)和数据布局;验证输入转为 CPU。
  • Inplace 操作:个别 in-place 可能阻碍导出,改用 out-of-place 或升级导出器。

版本与兼容性建议

  • 首选 opset_version=13 以获得跨后端的稳健兼容;若后端版本较新且需要特性(如更优 Resize/Conv 行为),再选择更高 opset。
  • 保持 torchonnxruntime 同步更新;导出失败优先尝试升级与切换导出器(dynamo_exportexport)。
相关推荐
梦梦代码精11 小时前
为什么这个开源的AI平台会火?有点东西。。。
人工智能·算法·机器学习·docker·开源
大模型真好玩11 小时前
智能体从入门到精通:6个必学GitHub开源项目
人工智能·agent·deepseek
极客笔记Jack12 小时前
Scanpy AnnData 对象深度解析:高效操作数据结构的10个技巧
python
源图客12 小时前
Aitoearn:OPC(一人公司)的AI内容智能体
人工智能·dreamweaver
颜酱12 小时前
LangChain调用向量模型,存入向量数据库
python·langchain
逸模12 小时前
AI+BIM 重构连锁公装新范式 逸模打造数字化营建核心底座
大数据·人工智能·笔记·其他·信息可视化·重构
2501_9289455212 小时前
七本性全面签名体系:从互递归类型到∞-范畴生成语法
python
phltxy12 小时前
MCP 从协议到 Spring AI 实战
人工智能·spring·oracle
Sirius Wu13 小时前
Agentic端到端&分离式RL技术建设
人工智能·深度学习·机器学习·caffe
AI导出鸭PC端13 小时前
智谱清言怎么生成word文档?AI导出鸭终结乱码烦恼
人工智能·ai·c#·word·豆包·ai导出鸭