PyTorch 转 ONNX 实用教程

PyTorch 转 ONNX 实用教程

1. 模型转换使用场景

ONNX 是一个用于机器学习模型的开放格式,旨在解决不同框架之间的模型互操作性和跨平台部署问题。它是一个中间表示格式,当你使用一个框架训练模型,但需要在另一个不同的框架中运行它时,可以使用 ONNX 进行转换。或者当你需要将模型部署到不同的硬件设备上时,例如从云端的 GPU 迁移到边缘设备的 CPU 时,ONNX 可以提供一个通用的部署桥梁。

2. pytorch转onnx的流程

2.1 环境准备

  • 安装依赖:pip install torch onnx onnxruntime onnxsim
  • 可选(量化/优化):pip install onnxruntime-gpu onnxruntime-tools onnxruntime-extensions

2.2 导出方法 torch.onnx.export(经典方式)

关键要点

  • 调用前设置 model.eval(),避免 Dropout/BatchNorm 行为不一致。
  • 导出时提供"样例输入",可指定动态维度 dynamic_axes
  • 选择合适 opset_version(常用 13/17),并用 ONNX Runtime 对齐输出。

代码示例

以下代码包含:1)pytorch模型转onnx模型部分 2)模型验证部分

python 复制代码
import torch
import torch.nn as nn
import onnx, onnxruntime as ort
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)
    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = Net().eval()
dummy = torch.randn(1, 3, 224, 224)  # 样例输入

torch.onnx.export(
    model, dummy, "model.onnx",
    input_names=["input"], output_names=["logits"],
    opset_version=13, do_constant_folding=True,
    dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
                  "logits": {0: "batch"}}
)

# 结构检查
onnx.checker.check_model(onnx.load("model.onnx"))

# 结果校验(与 PyTorch 对齐)
with torch.no_grad():
    torch_out = model(dummy).cpu().numpy()
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
ort_out, = sess.run(None, {"input": dummy.cpu().numpy()})
print("max abs diff:", np.max(np.abs(torch_out - ort_out)))

多输入/多输出

python 复制代码
y1, y2 = model(x1, x2)
torch.onnx.export(
    model, (x1, x2), "m.onnx",
    input_names=["x1", "x2"], output_names=["y1", "y2"],
    dynamic_axes={"x1": {0: "batch"}, "x2": {0: "batch"},
                  "y1": {0: "batch"}, "y2": {0: "batch"}}
)

校验与对齐输出

  • 结构校验:onnx.checker.check_model(onnx.load("model.onnx"))
  • 数值对齐:用相同输入在 PyTorch 与 ORT 推理,比较 np.allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5);不同算子/后端可能需放宽阈值。
  • 注意:输入需在 CPU 且为 numpy;混合精度需在两侧一致。

3. 常见问题与建议

参数说明

  • opset_version: 13 兼容性好;若需要新算子/特性可选 17+,确保推理引擎支持。
  • do_constant_folding=True: 常量折叠,简化图。
  • export_params=True: 将权重打包进 ONNX(默认)。
  • 大模型(>2GB):use_external_data_format=True,并指定 all_tensors_to_one_file=True 控制权重落盘方式。
  • 训练/推理模式:确保 model.eval();如确需训练图,设置 training=torch.onnx.TrainingMode.TRAINING(受限)。

常见问题

  • 未设 eval():导出或对齐误差大;务必 model.eval()
  • 控制流/数据依赖:经典导出可能失败或数值偏差;优先试 dynamo_export
  • 不支持的算子:升级 torch/onnxruntime,尝试更高 opset;或替换为等价算子。
  • 上采样/插值:建议使用明确的 sizescale_factor,并确保 align_corners 与后端一致;优选 opset>=11
  • 自定义算子:经典导出可通过 register_custom_op_symbolic 提供 symbolic;部署侧需实现同域自定义算子。
  • 设备/精度不一致:确保导出与验证使用相同 dtype(FP32/FP16)和数据布局;验证输入转为 CPU。
  • Inplace 操作:个别 in-place 可能阻碍导出,改用 out-of-place 或升级导出器。

版本与兼容性建议

  • 首选 opset_version=13 以获得跨后端的稳健兼容;若后端版本较新且需要特性(如更优 Resize/Conv 行为),再选择更高 opset。
  • 保持 torchonnxruntime 同步更新;导出失败优先尝试升级与切换导出器(dynamo_exportexport)。
相关推荐
老蒋新思维1 小时前
创客匠人峰会复盘:AI 赋能 IP 创新增长,知识变现的 4 大实战路径与跨行业案例
大数据·网络·人工智能·tcp/ip·创始人ip·创客匠人·知识变现
AI-嘉文哥哥1 小时前
ADAS自动驾驶-前车碰撞预警(追尾预警、碰撞检测)系统
人工智能·深度学习·yolo·目标检测·数据分析·课程设计·qt5
ManageEngineITSM1 小时前
IT 资产扫描工具与企业服务台的数字化底层价值
大数据·运维·人工智能·itsm·工单系统
skywalk81631 小时前
智能营养食谱平台 - 项目创意策划书
人工智能
酷柚易汛智推官1 小时前
从“废歌”到热单:Mureka新模型改写AIGC音乐的产业规则
人工智能·aigc·酷柚易汛
海底的星星fly1 小时前
【Prompt学习技能树地图】LangChain原理及应用操作指南
人工智能·语言模型·langchain·prompt
撷思、1 小时前
python+flask学习
python·flask
free-elcmacom1 小时前
机器学习入门<2>决策树算法
人工智能·python·机器学习
lxmyzzs1 小时前
vLLM、SGLang 与 TensorRT-LLM 综合对比分析报告
人工智能·自然语言处理