PyTorch 最小模型转 ONNX 完整样例

ONNX简介

ONNX(Open Neural Network Exchange)即开放神经网络交换格式,是一种开源通用的深度学习模型标准格式。它统一定义了神经网络算子、计算图与数据存储规范,打破PyTorch、TensorFlow、Paddle等不同训练框架之间的模型壁垒,实现模型跨框架自由迁移。

开发者可将不同框架训练完成的模型导出为ONNX文件,再借助ONNX Runtime、TensorRT、NCNN等推理引擎完成端侧、服务器、嵌入式设备的高效部署,同时支持模型简化、算子优化与量化压缩,极大简化深度学习模型上线流程,是工业界AI模型部署最主流的中间格式。

1. 环境依赖

bash 复制代码
pip install torch onnx onnxsim onnxruntime onnxscript

2. 极简模型 + 导出 ONNX 代码

python 复制代码
import torch
import torch.nn as nn

# 1. 定义超简单单层网络
class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 输入10维,输出2维
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 2. 初始化模型
model = MiniModel()
model.eval()  # 推理模式

# 3. 构造虚拟输入 (batch=1, 输入维度10)
dummy_input = torch.randn(1, 10)

# 4. 导出 ONNX
torch.onnx.export(
    model,
    dummy_input,
    "mini_model.onnx",       # 导出文件名
    input_names=["input"],    # 输入节点名
    output_names=["output"],  # 输出节点名
    opset_version=17,         # ONNX算子版本
    do_constant_folding=True  # 常量折叠优化
)
print("✅ ONNX 导出完成:mini_model.onnx")

3. 验证 ONNX 是否合法

python 复制代码
import onnx

# 加载校验
onnx_model = onnx.load("mini_model.onnx")
onnx.checker.check_model(onnx_model)
print("✅ ONNX 模型格式合法无错误")

# 打印模型结构
print(onnx.helper.printable_graph(onnx_model.graph))

4. 简化 ONNX(去除冗余节点)

python 复制代码
import onnxsim

model_simplified, ok = onnxsim.simplify("mini_model.onnx")
assert ok, "模型简化失败"
onnx.save(model_simplified, "mini_model_simplified.onnx")
print("✅ ONNX 简化完成")

5. ONNX Runtime 推理测试

python 复制代码
import onnxruntime as ort
import numpy as np

# 加载模型
session = ort.InferenceSession("mini_model_simplified.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# 构造输入
inp = np.random.randn(1, 10).astype(np.float32)
res = session.run([output_name], {input_name: inp})
print("推理结果:", res[0])

关键参数说明

  1. opset_version:越高支持算子越多,部署兼容性优先选 13~17
  2. do_constant_folding:自动合并常量,减小模型体积
  3. 输入形状:(batch_size, feature_dim) 按需修改

扩展:带ReLU激活的常用小模型

python 复制代码
class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        )
    def forward(self, x):
        return self.net(x)

输出为

复制代码
✅ ONNX 模型格式合法无错误
/workspace/onnx/t2.py:9: DeprecationWarning: Deprecated since 1.19. Consider using onnx.printer.to_text() instead.
  print(onnx.helper.printable_graph(onnx_model.graph))
graph main_graph (
  %input[FLOAT, 1x10]
) initializers (
  %net.0.weight[FLOAT, 32x10]
  %net.0.bias[FLOAT, 32]
  %net.2.weight[FLOAT, 2x32]
  %net.2.bias[FLOAT, 2]
) {
  %linear = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%input, %net.0.weight, %net.0.bias)
  %relu = Relu(%linear)
  %output = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%relu, %net.2.weight, %net.2.bias)
  return %output
}

详解

相关推荐
_oP_i6 小时前
FFmpeg 如何与ai结合剪辑出效果好的视频
人工智能·ffmpeg·音视频
脑极体6 小时前
嗜血的AI
人工智能·chatgpt
z202305087 小时前
RDMA之RoCEv2 无损网络PFC 、DCQCN 和ECN (7)
linux·服务器·网络·人工智能·ai
必须会一定会7 小时前
我用 AI 做记账 App:技术方案怎么选,才能既简单又能落地
人工智能
m0_380167147 小时前
CoinGlass API vs Glassnode:全面对比分析
人工智能·ai·区块链
陆业聪7 小时前
Gemini Spark深度拆解:Google给AI一台永不关机的云服务器
人工智能·aigc
我星期八休息7 小时前
Linux系统编程—库制作与原理
linux·运维·服务器·数据结构·人工智能·python·散列表
AI品信智慧数智人7 小时前
✨AI 赋能医疗,智启健康新未来
人工智能
AiTop1007 小时前
智谱AI推出ZCube组网架构:大模型推理性能与成本双突破,重构智算基础设施
人工智能·重构·架构