Pytorch转onnx

pytorch 转 onnx 模型需要函数 torch.onnx.export。

python 复制代码
def export(
    model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
    args: Union[Tuple[Any, ...], torch.Tensor],
    f: Union[str, io.BytesIO],
    export_params: bool = True,
    verbose: bool = False,
    training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
    opset_version: Optional[int] = None,
    do_constant_folding: bool = True,
    dynamic_axes: Optional[
        Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]
    ] = None,
    keep_initializers_as_inputs: Optional[bool] = None,
    custom_opsets: Optional[Mapping[str, int]] = None,
    export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
) -> None:

常用参数说明

python 复制代码
model------需要导出的pytorch模型
args------模型的输入参数,满足输入层的shape正确即可。
f------输出的onnx模型的位置。例如'yolov5.onnx'。
export_params------输出模型是否可训练。default=True,表示导出trained model,否则untrained。
verbose------是否打印模型转换信息。default=False。
input_names------输入节点名称。default=None。
output_names------输出节点名称。default=None。
opset_version------算子指令集合
do_constant_folding------是否使用常量折叠,默认即可。default=True。
dynamic_axes------模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道

参数说明
ONNX算子文档

ONNX 算子的定义情况,都可以在官方的算子文档中查看

这份文档中最重要的开头的这个算子变更表格。表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export中提到的opset_version表示的算子集版本号。通过查看算子第一次发生变动的版本号,我们可以知道某个算子是从哪个版本开始支持的;通过查看某算子小于等于opset_version的第一个改动记录,我们可以知道当前算子集版本中该算子的定义规则。

练习

python 复制代码
import torch
import torch.nn as nn
import torch.onnx

class Model(torch.nn.Module):
    def __init__(self, in_features, out_features, weights, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        with torch.no_grad():
            self.linear.weight.copy_(weights)
    
    def forward(self, x):
        x = self.linear(x)
        return x

def infer():
    in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    
    model = Model(4, 3, weights)
    x = model(in_features)
    print("result is: ", x)

def export_onnx():
    input   = torch.zeros(1, 1, 1, 4)
    weights = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    model   = Model(4, 3, weights)
    model.eval() #添加eval防止权重继续更新

    # pytorch导出onnx的方式,参数有很多,也可以支持动态size
    # 我们先做一些最基本的导出,从netron学习一下导出的onnx都有那些东西
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = "../models/example.onnx",
        input_names   = ["input0"],
        output_names  = ["output0"],
        opset_version = 12)
    print("Finished onnx export")


if __name__ == "__main__":
    infer()
    export_onnx()

然后使用netron打开onnx文件,如果没有安装netron,在终端使用pip install netron。

参考链接
模型部署入门教程(三):PyTorch 转 ONNX 详解

相关推荐
Tianyanxiao3 分钟前
如何利用探商宝精准营销,抓住行业机遇——以AI技术与大数据推动企业信息精准筛选
大数据·人工智能·科技·数据分析·深度优先·零售
撞南墙者10 分钟前
OpenCV自学系列(1)——简介和GUI特征操作
人工智能·opencv·计算机视觉
OCR_wintone42111 分钟前
易泊车牌识别相机,助力智慧工地建设
人工智能·数码相机·ocr
王哈哈^_^33 分钟前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
一者仁心39 分钟前
【AI技术】PaddleSpeech
人工智能
是瑶瑶子啦1 小时前
【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)
论文阅读·人工智能·深度学习·视觉检测·空间变换
EasyCVR1 小时前
萤石设备视频接入平台EasyCVR多品牌摄像机视频平台海康ehome平台(ISUP)接入EasyCVR不在线如何排查?
运维·服务器·网络·人工智能·ffmpeg·音视频
柳鲲鹏1 小时前
OpenCV视频防抖源码及编译脚本
人工智能·opencv·计算机视觉
西柚小萌新1 小时前
8.机器学习--决策树
人工智能·决策树·机器学习
向阳12181 小时前
Bert快速入门
人工智能·python·自然语言处理·bert