将Pytorch搭建的ViT模型转为onnx模型

本文尝试将pytorch搭建的ViT模型转为onnx模型。

首先将博主上一篇文章中搭建的模型ViT Vision Transformer超详细解析,网络构建,可视化,数据预处理,全流程实例教程-CSDN博客转存为.pth

python 复制代码
torch.save(model, 'my_vit_model.pth')

然后新建一个py文件,要新建py文件的原因是,博主上一篇文章的main.py文件引用了很多torch相关的库,如果还是在main.py文件中运行转onnx的代码,回报错circle import 重复循环引用的错误,所以姑且将.pth作为一个中转。

新建一个py文件,写入

python 复制代码
import importlib
torch = importlib.import_module('torch')


model = torch.load("my_vit_model.pth")


model.cpu()
# 创建一个随机的输入张量
dummy_input = torch.randn(1, 3, 16, 16)
torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=18)

引入importlib,通过它来引用torch也是为了解决循环引用的问题。

这时运行这段代码,会报错onnx 不支持aten::unflatten运算。这里有两种解决方法,一种是将自己pytorch模型中的unflatten运算全部换成onnx支持的reshape函数(参见文章:https://www.cnblogs.com/antelx/p/17564039.html

还有一种方法是,修改onnx库中的 symbolic_opset18.py 文件(/home/.local/lib/python3.8/site-packages/torch/onnx),改为如下形式

python 复制代码
"""This file exports ONNX ops for opset 18.

Note [ONNX Operators that are added/updated in opset 18]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:
    CenterCropPad
    Col2Im
    Mish
    OptionalGetElement
    OptionalHasElement
    Pad
    Resize
    ScatterElements
    ScatterND
"""

import functools
from typing import Sequence

import torch
import torch._C._onnx as _C_onnx
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset11 as opset11,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = ["col2im"]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)


@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(
    g,
    input: _C.Value,
    output_size: _C.Value,
    kernel_size: _C.Value,
    dilation: Sequence[int],
    padding: Sequence[int],
    stride: Sequence[int],
):
    # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
    adjusted_padding = []
    for pad in padding:
        for _ in range(2):
            adjusted_padding.append(pad)

    num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
    if not adjusted_padding:
        adjusted_padding = [0, 0] * num_dimensional_axis

    if not dilation:
        dilation = [1] * num_dimensional_axis

    if not stride:
        stride = [1] * num_dimensional_axis

    return g.op(
        "Col2Im",
        input,
        output_size,
        kernel_size,
        dilations_i=dilation,
        pads_i=adjusted_padding,
        strides_i=stride,
    )



@_onnx_symbolic("aten::unflatten")
def unflatten(g:jit_utils.GraphContext, input, dim, unflattened_size):
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    # dim could be negative
    input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))
    dim = g.op("Add", input_dim, dim)
    dim = g.op("Mod", dim, input_dim)

    input_size = g.op("Shape", input)

    head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
    head_end_idx = g.op(
        "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)

    dim_plus_one = g.op(
        "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    tail_start_idx = g.op(
        "Reshape",
        dim_plus_one,
        g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),
    )
    tail_end_idx = g.op(
        "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
    )
    tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)

    final_shape = g.op(
        "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0
    )

    return symbolic_helper._reshape_helper(g, input, final_shape)

这里这样做是相当于自己在onnx库中注册aten::unflatten运算。

再新建一个py文件,写入

python 复制代码
import onnxruntime as rt
import numpy as np

# 加载模型
sess = rt.InferenceSession("model.onnx")

# 获取输入和输出名称
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# 创建输入数据
input_data = np.random.rand(1, 3, 16, 16).astype(np.float32)

# 运行模型
pred_onnx = sess.run([output_name], {input_name: input_data})

# 打印预测结果
print(pred_onnx)

就可以运行onnx模型了。

相关推荐
white-persist3 分钟前
MCP协议深度解析:AI时代的通用连接器
网络·人工智能·windows·爬虫·python·自动化
递归不收敛4 分钟前
吴恩达机器学习课程(PyTorch 适配)学习笔记:3.4 强化学习
pytorch·学习·机器学习
新智元4 分钟前
谷歌杀入诺奖神殿,两年三冠五得主!世界TOP3重现贝尔实验室神话
人工智能·openai
StarPrayers.7 分钟前
卷积层(Convolutional Layer)学习笔记
人工智能·笔记·深度学习·学习·机器学习
skywalk816310 分钟前
AutoCoder Nano 是一款轻量级的编码助手, 利用大型语言模型(LLMs)帮助开发者编写, 理解和修改代码。
人工智能
金井PRATHAMA16 分钟前
描述逻辑对人工智能自然语言处理中深层语义分析的影响与启示
人工智能·自然语言处理·知识图谱
却道天凉_好个秋22 分钟前
OpenCV(四):视频采集与保存
人工智能·opencv·音视频
minhuan23 分钟前
构建AI智能体:五十七、LangGraph + Gradio:构建可视化AI工作流的趣味指南
人工智能·语言模型·workflow·langgraph·自定义工作流
codists34 分钟前
2025年9月文章一览
python
语落心生36 分钟前
FastDeploy SD & Flux 扩散模型边缘端轻量化推理部署实现
python