PyTorch 转 ONNX 实用教程
1. 模型转换使用场景
ONNX 是一个用于机器学习模型的开放格式,旨在解决不同框架之间的模型互操作性和跨平台部署问题。它是一个中间表示格式,当你使用一个框架训练模型,但需要在另一个不同的框架中运行它时,可以使用 ONNX 进行转换。或者当你需要将模型部署到不同的硬件设备上时,例如从云端的 GPU 迁移到边缘设备的 CPU 时,ONNX 可以提供一个通用的部署桥梁。
2. pytorch转onnx的流程
2.1 环境准备
- 安装依赖:
pip install torch onnx onnxruntime onnxsim - 可选(量化/优化):
pip install onnxruntime-gpu onnxruntime-tools onnxruntime-extensions
2.2 导出方法 torch.onnx.export(经典方式)
关键要点
- 调用前设置
model.eval(),避免 Dropout/BatchNorm 行为不一致。 - 导出时提供"样例输入",可指定动态维度
dynamic_axes。 - 选择合适
opset_version(常用 13/17),并用 ONNX Runtime 对齐输出。
代码示例
以下代码包含:1)pytorch模型转onnx模型部分 2)模型验证部分
python
import torch
import torch.nn as nn
import onnx, onnxruntime as ort
import numpy as np
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 10)
def forward(self, x):
x = torch.relu(self.conv(x))
x = self.pool(x).flatten(1)
return self.fc(x)
model = Net().eval()
dummy = torch.randn(1, 3, 224, 224) # 样例输入
torch.onnx.export(
model, dummy, "model.onnx",
input_names=["input"], output_names=["logits"],
opset_version=13, do_constant_folding=True,
dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
"logits": {0: "batch"}}
)
# 结构检查
onnx.checker.check_model(onnx.load("model.onnx"))
# 结果校验(与 PyTorch 对齐)
with torch.no_grad():
torch_out = model(dummy).cpu().numpy()
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
ort_out, = sess.run(None, {"input": dummy.cpu().numpy()})
print("max abs diff:", np.max(np.abs(torch_out - ort_out)))
多输入/多输出
python
y1, y2 = model(x1, x2)
torch.onnx.export(
model, (x1, x2), "m.onnx",
input_names=["x1", "x2"], output_names=["y1", "y2"],
dynamic_axes={"x1": {0: "batch"}, "x2": {0: "batch"},
"y1": {0: "batch"}, "y2": {0: "batch"}}
)
校验与对齐输出
- 结构校验:
onnx.checker.check_model(onnx.load("model.onnx")) - 数值对齐:用相同输入在 PyTorch 与 ORT 推理,比较
np.allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5);不同算子/后端可能需放宽阈值。 - 注意:输入需在 CPU 且为
numpy;混合精度需在两侧一致。
3. 常见问题与建议
参数说明
opset_version: 13 兼容性好;若需要新算子/特性可选 17+,确保推理引擎支持。do_constant_folding=True: 常量折叠,简化图。export_params=True: 将权重打包进 ONNX(默认)。- 大模型(>2GB):
use_external_data_format=True,并指定all_tensors_to_one_file=True控制权重落盘方式。 - 训练/推理模式:确保
model.eval();如确需训练图,设置training=torch.onnx.TrainingMode.TRAINING(受限)。
常见问题
- 未设
eval():导出或对齐误差大;务必model.eval()。 - 控制流/数据依赖:经典导出可能失败或数值偏差;优先试
dynamo_export。 - 不支持的算子:升级
torch/onnxruntime,尝试更高opset;或替换为等价算子。 - 上采样/插值:建议使用明确的
size或scale_factor,并确保align_corners与后端一致;优选opset>=11。 - 自定义算子:经典导出可通过
register_custom_op_symbolic提供 symbolic;部署侧需实现同域自定义算子。 - 设备/精度不一致:确保导出与验证使用相同 dtype(FP32/FP16)和数据布局;验证输入转为 CPU。
- Inplace 操作:个别 in-place 可能阻碍导出,改用 out-of-place 或升级导出器。
版本与兼容性建议
- 首选
opset_version=13以获得跨后端的稳健兼容;若后端版本较新且需要特性(如更优 Resize/Conv 行为),再选择更高 opset。 - 保持
torch和onnxruntime同步更新;导出失败优先尝试升级与切换导出器(dynamo_export↔export)。