ONNX简介
ONNX(Open Neural Network Exchange)即开放神经网络交换格式,是一种开源通用的深度学习模型标准格式。它统一定义了神经网络算子、计算图与数据存储规范,打破PyTorch、TensorFlow、Paddle等不同训练框架之间的模型壁垒,实现模型跨框架自由迁移。
开发者可将不同框架训练完成的模型导出为ONNX文件,再借助ONNX Runtime、TensorRT、NCNN等推理引擎完成端侧、服务器、嵌入式设备的高效部署,同时支持模型简化、算子优化与量化压缩,极大简化深度学习模型上线流程,是工业界AI模型部署最主流的中间格式。
1. 环境依赖
bash
pip install torch onnx onnxsim onnxruntime onnxscript
2. 极简模型 + 导出 ONNX 代码
python
import torch
import torch.nn as nn
# 1. 定义超简单单层网络
class MiniModel(nn.Module):
def __init__(self):
super().__init__()
# 输入10维,输出2维
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
# 2. 初始化模型
model = MiniModel()
model.eval() # 推理模式
# 3. 构造虚拟输入 (batch=1, 输入维度10)
dummy_input = torch.randn(1, 10)
# 4. 导出 ONNX
torch.onnx.export(
model,
dummy_input,
"mini_model.onnx", # 导出文件名
input_names=["input"], # 输入节点名
output_names=["output"], # 输出节点名
opset_version=17, # ONNX算子版本
do_constant_folding=True # 常量折叠优化
)
print("✅ ONNX 导出完成:mini_model.onnx")
3. 验证 ONNX 是否合法
python
import onnx
# 加载校验
onnx_model = onnx.load("mini_model.onnx")
onnx.checker.check_model(onnx_model)
print("✅ ONNX 模型格式合法无错误")
# 打印模型结构
print(onnx.helper.printable_graph(onnx_model.graph))
4. 简化 ONNX(去除冗余节点)
python
import onnxsim
model_simplified, ok = onnxsim.simplify("mini_model.onnx")
assert ok, "模型简化失败"
onnx.save(model_simplified, "mini_model_simplified.onnx")
print("✅ ONNX 简化完成")
5. ONNX Runtime 推理测试
python
import onnxruntime as ort
import numpy as np
# 加载模型
session = ort.InferenceSession("mini_model_simplified.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 构造输入
inp = np.random.randn(1, 10).astype(np.float32)
res = session.run([output_name], {input_name: inp})
print("推理结果:", res[0])
关键参数说明
opset_version:越高支持算子越多,部署兼容性优先选 13~17do_constant_folding:自动合并常量,减小模型体积- 输入形状:
(batch_size, feature_dim)按需修改
扩展:带ReLU激活的常用小模型
python
class MiniModel(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(10, 32),
nn.ReLU(),
nn.Linear(32, 2)
)
def forward(self, x):
return self.net(x)
输出为
✅ ONNX 模型格式合法无错误
/workspace/onnx/t2.py:9: DeprecationWarning: Deprecated since 1.19. Consider using onnx.printer.to_text() instead.
print(onnx.helper.printable_graph(onnx_model.graph))
graph main_graph (
%input[FLOAT, 1x10]
) initializers (
%net.0.weight[FLOAT, 32x10]
%net.0.bias[FLOAT, 32]
%net.2.weight[FLOAT, 2x32]
%net.2.bias[FLOAT, 2]
) {
%linear = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%input, %net.0.weight, %net.0.bias)
%relu = Relu(%linear)
%output = Gemm[alpha = 1, beta = 1, transA = 0, transB = 1](%relu, %net.2.weight, %net.2.bias)
return %output
}
详解

