PyTorch 最小模型转 ONNX 完整样例

ONNX简介

ONNX(Open Neural Network Exchange)即开放神经网络交换格式,是一种开源通用的深度学习模型标准格式。它统一定义了神经网络算子、计算图与数据存储规范,打破PyTorch、TensorFlow、Paddle等不同训练框架之间的模型壁垒,实现模型跨框架自由迁移。

开发者可将不同框架训练完成的模型导出为ONNX文件,再借助ONNX Runtime、TensorRT、NCNN等推理引擎完成端侧、服务器、嵌入式设备的高效部署,同时支持模型简化、算子优化与量化压缩,极大简化深度学习模型上线流程,是工业界AI模型部署最主流的中间格式。

1. 环境依赖

bash 复制代码
pip install torch onnx onnxsim onnxruntime onnxscript

2. 极简模型 + 导出 ONNX 代码

python 复制代码
import torch
import torch.nn as nn

# 1. 定义超简单单层网络
class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 输入10维,输出2维
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 2. 初始化模型
model = MiniModel()
model.eval()  # 推理模式

# 3. 构造虚拟输入 (batch=1, 输入维度10)
dummy_input = torch.randn(1, 10)

# 4. 导出 ONNX
torch.onnx.export(
    model,
    dummy_input,
    "mini_model.onnx",       # 导出文件名
    input_names=["input"],    # 输入节点名
    output_names=["output"],  # 输出节点名
    opset_version=17,         # ONNX算子版本
    do_constant_folding=True  # 常量折叠优化
)
print("✅ ONNX 导出完成:mini_model.onnx")

3. 验证 ONNX 是否合法

python 复制代码
import onnx

# 加载校验
onnx_model = onnx.load("mini_model.onnx")
onnx.checker.check_model(onnx_model)
print("✅ ONNX 模型格式合法无错误")

# 打印模型结构
print(onnx.helper.printable_graph(onnx_model.graph))

4. 简化 ONNX(去除冗余节点)

python 复制代码
import onnxsim

model_simplified, ok = onnxsim.simplify("mini_model.onnx")
assert ok, "模型简化失败"
onnx.save(model_simplified, "mini_model_simplified.onnx")
print("✅ ONNX 简化完成")

5. ONNX Runtime 推理测试

python 复制代码
import onnxruntime as ort
import numpy as np

# 加载模型
session = ort.InferenceSession("mini_model_simplified.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# 构造输入
inp = np.random.randn(1, 10).astype(np.float32)
res = session.run([output_name], {input_name: inp})
print("推理结果:", res[0])

关键参数说明

  1. opset_version:越高支持算子越多,部署兼容性优先选 13~17
  2. do_constant_folding:自动合并常量,减小模型体积
  3. 输入形状:(batch_size, feature_dim) 按需修改

扩展:带ReLU激活的常用小模型

python 复制代码
class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        )
    def forward(self, x):
        return self.net(x)

输出为

复制代码
✅ ONNX 模型格式合法无错误
/workspace/onnx/t2.py:9: DeprecationWarning: Deprecated since 1.19. Consider using onnx.printer.to_text() instead.
  print(onnx.helper.printable_graph(onnx_model.graph))
graph main_graph (
  %input[FLOAT, 1x10]
) initializers (
  %net.0.weight[FLOAT, 32x10]
  %net.0.bias[FLOAT, 32]
  %net.2.weight[FLOAT, 2x32]
  %net.2.bias[FLOAT, 2]
) {
  %linear = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%input, %net.0.weight, %net.0.bias)
  %relu = Relu(%linear)
  %output = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%relu, %net.2.weight, %net.2.bias)
  return %output
}

详解

相关推荐
用户8356290780511 小时前
Python 实现 PDF 文件加密与解密方法
后端·python
用户8356290780511 小时前
使用 Python 冻结与拆分 Excel 窗格教程
后端·python
阿里云大数据AI技术2 小时前
构建高转化海外电商搜索:阿里云OpenSearch行业算法版的全链路智能优化策略实战
人工智能·搜索引擎
Awu12272 小时前
⚡从零开发 Agent CLI(五)实现一个可治理、可扩展的工具系统
前端·人工智能·claude
字节跳动视频云技术团队2 小时前
让 Agent 成为音视频工作台:AI MediaKit CLI + Skill 发布
人工智能·音视频开发
魏祖潇2 小时前
framework 整合实战——DDD/TDD/SDD 三件套在 framework 仓的真实落地
人工智能·后端
Token炼金师3 小时前
去噪扩散:从随机噪声到高保真图像的数学之路
人工智能·aigc
这个DBA有点耶3 小时前
AI写的SQL跑崩了生产库,这锅谁背?
数据库·人工智能·程序员
阿里云大数据AI技术3 小时前
阿里云 EMR AI 助手正式发布:从问答工具到全栈智能运维助手
运维·人工智能
Larcher4 小时前
从零搭建 MCP 服务——让 AI 拥有无限扩展能力
人工智能·程序员