PyTorch 转 ONNX 实用教程

PyTorch 转 ONNX 实用教程

1. 模型转换使用场景

ONNX 是一个用于机器学习模型的开放格式,旨在解决不同框架之间的模型互操作性和跨平台部署问题。它是一个中间表示格式,当你使用一个框架训练模型,但需要在另一个不同的框架中运行它时,可以使用 ONNX 进行转换。或者当你需要将模型部署到不同的硬件设备上时,例如从云端的 GPU 迁移到边缘设备的 CPU 时,ONNX 可以提供一个通用的部署桥梁。

2. pytorch转onnx的流程

2.1 环境准备

  • 安装依赖:pip install torch onnx onnxruntime onnxsim
  • 可选(量化/优化):pip install onnxruntime-gpu onnxruntime-tools onnxruntime-extensions

2.2 导出方法 torch.onnx.export(经典方式)

关键要点

  • 调用前设置 model.eval(),避免 Dropout/BatchNorm 行为不一致。
  • 导出时提供"样例输入",可指定动态维度 dynamic_axes
  • 选择合适 opset_version(常用 13/17),并用 ONNX Runtime 对齐输出。

代码示例

以下代码包含:1)pytorch模型转onnx模型部分 2)模型验证部分

python 复制代码
import torch
import torch.nn as nn
import onnx, onnxruntime as ort
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)
    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = Net().eval()
dummy = torch.randn(1, 3, 224, 224)  # 样例输入

torch.onnx.export(
    model, dummy, "model.onnx",
    input_names=["input"], output_names=["logits"],
    opset_version=13, do_constant_folding=True,
    dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
                  "logits": {0: "batch"}}
)

# 结构检查
onnx.checker.check_model(onnx.load("model.onnx"))

# 结果校验(与 PyTorch 对齐)
with torch.no_grad():
    torch_out = model(dummy).cpu().numpy()
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
ort_out, = sess.run(None, {"input": dummy.cpu().numpy()})
print("max abs diff:", np.max(np.abs(torch_out - ort_out)))

多输入/多输出

python 复制代码
y1, y2 = model(x1, x2)
torch.onnx.export(
    model, (x1, x2), "m.onnx",
    input_names=["x1", "x2"], output_names=["y1", "y2"],
    dynamic_axes={"x1": {0: "batch"}, "x2": {0: "batch"},
                  "y1": {0: "batch"}, "y2": {0: "batch"}}
)

校验与对齐输出

  • 结构校验:onnx.checker.check_model(onnx.load("model.onnx"))
  • 数值对齐:用相同输入在 PyTorch 与 ORT 推理,比较 np.allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5);不同算子/后端可能需放宽阈值。
  • 注意:输入需在 CPU 且为 numpy;混合精度需在两侧一致。

3. 常见问题与建议

参数说明

  • opset_version: 13 兼容性好;若需要新算子/特性可选 17+,确保推理引擎支持。
  • do_constant_folding=True: 常量折叠,简化图。
  • export_params=True: 将权重打包进 ONNX(默认)。
  • 大模型(>2GB):use_external_data_format=True,并指定 all_tensors_to_one_file=True 控制权重落盘方式。
  • 训练/推理模式:确保 model.eval();如确需训练图,设置 training=torch.onnx.TrainingMode.TRAINING(受限)。

常见问题

  • 未设 eval():导出或对齐误差大;务必 model.eval()
  • 控制流/数据依赖:经典导出可能失败或数值偏差;优先试 dynamo_export
  • 不支持的算子:升级 torch/onnxruntime,尝试更高 opset;或替换为等价算子。
  • 上采样/插值:建议使用明确的 sizescale_factor,并确保 align_corners 与后端一致;优选 opset>=11
  • 自定义算子:经典导出可通过 register_custom_op_symbolic 提供 symbolic;部署侧需实现同域自定义算子。
  • 设备/精度不一致:确保导出与验证使用相同 dtype(FP32/FP16)和数据布局;验证输入转为 CPU。
  • Inplace 操作:个别 in-place 可能阻碍导出,改用 out-of-place 或升级导出器。

版本与兼容性建议

  • 首选 opset_version=13 以获得跨后端的稳健兼容;若后端版本较新且需要特性(如更优 Resize/Conv 行为),再选择更高 opset。
  • 保持 torchonnxruntime 同步更新;导出失败优先尝试升级与切换导出器(dynamo_exportexport)。
相关推荐
SUPER526615 小时前
本地开发环境_spring-ai项目启动异常
java·人工智能·spring
上进小菜猪20 小时前
基于 YOLOv8 的智能车牌定位检测系统设计与实现—从模型训练到 PyQt 可视化落地的完整实战方案
人工智能
AI浩20 小时前
UNIV:红外与可见光模态的统一基础模型
人工智能·深度学习
GitCode官方20 小时前
SGLang AI 金融 π 对(杭州站)回顾:大模型推理的工程实践全景
人工智能·金融·sglang
醒过来摸鱼20 小时前
Java classloader
java·开发语言·python
superman超哥20 小时前
仓颉语言中元组的使用:深度剖析与工程实践
c语言·开发语言·c++·python·仓颉
小鸡吃米…20 小时前
Python - 继承
开发语言·python
木头左21 小时前
LSTM模型入参有效性验证基于量化交易策略回测的方法学实践
人工智能·rnn·lstm
祁思妙想21 小时前
Python中的FastAPI框架的设计特点和性能优势
开发语言·python·fastapi
找方案21 小时前
我的 all-in-rag 学习笔记:文本分块 ——RAG 系统的 “信息切菜术“
人工智能·笔记·all-in-rag