如何自动生成ONNX模型?
实际开发中,我们通常从现有深度学习框架自动导出ONNX模型,而非手动编写。以下是主流框架的自动转换方法:
1. PyTorch → ONNX(最常用)
PyTorch内置了ONNX导出功能,只需一行代码:
python
import torch
import torch.nn as nn
假设有一个PyTorch模型
python
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
model = MyModel()
dummy_input = torch.randn(1, 10) # 虚拟输入(用于追踪计算图)
自动导出为ONNX
python
torch.onnx.export(
model, # PyTorch模型
dummy_input, # 示例输入(用于确定输入形状)
"model.onnx", # 输出文件名
input_names=["X"], # 输入节点名称
output_names=["Y"], # 输出节点名称
dynamic_axes={
"X": {0: "batch"}, # 动态维度(如可变batch_size)
"Y": {0: "batch"}
}
)
关键点:
torch.onnx.export会自动追踪模型的计算图并转换为ONNX格式。
dynamic_axes允许定义动态维度(如可变batch_size)。
2. TensorFlow/Keras → ONNX
使用 tf2onnx工具自动转换
python
import tensorflow as tf
import tf2onnx
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(5, input_shape=(10,))
])
保存为SavedModel格式(或直接转换)
python
tf.saved_model.save(model, "tmp_model")
转换为ONNX
python
cmd = f"python -m tf2onnx.convert --saved-model tmp_model --output model.onnx"
!{cmd} # 在Jupyter中执行命令行(或直接在终端运行)
总结
95%的实战场景:直接用 torch.onnx.export或 tf2onnx自动转换。
特殊需求才需要手动编写ONNX(如你的代码),但需注意手动编写容易出错(例如形状不匹配)。需要熟悉ONNX的算子规范(如支持哪些操作、属性如何设置)。