pytorch支持更多onnx算子

pytorch支持更多onnx算子

本文主要参考扩展onnx算子

而要使 PyTorch 算子顺利转换到 ONNX ,我们需要保证以下三个环节都不出错:

  • 算子在 PyTorch 中有实现
  • 有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法
  • ONNX 有相应的算子

PyTorch 算子

  • 组合现有算子
  • 添加 TorchScript 算子
  • 添加普通 C++ 拓展算子
    映射方法
  • 为 ATen 算子添加符号函数
  • 为 TorchScript 算子添加符号函数
  • 封装成 torch.autograd.Function 并添加符号函数
    ONNX 算子
  • 使用现有 ONNX 算子
  • 定义新 ONNX 算子

支持ATen算子

ATen 是 PyTorch 内置的 C++ 张量计算库,PyTorch 算子在底层绝大多数计算都是用 ATen 实现的。

针对的问题:ATen有定义,但缺少和ONNX的映射规则。

解决的思路:

  1. 获取Aten算子接口定义。去 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi搜索算子名。如asinh,对应的接口为def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
  2. 添加符号函数
    添加符号函数def symbolic(g: torch._C.Graph, input_0: torch._C.Value, input_1: torch._C.Value, ...): ,g有一个op方法,在把 PyTorch 算子转换成 ONNX 算子时,需要在符号函数中调用此方法来为最终的计算图添加一个 ONNX 算子。在最简单的情况下,我们只要把 PyTorch 算子的输入用g.op()一一对应到 ONNX 算子上即可,并把g.op()的返回值作为符号函数的返回值。在情况更复杂时,我们转换一个 PyTorch 算子可能要新建若干个 ONNX 算子。我们先去翻阅一下 ONNX 算子文档,学习一下我们在符号函数里的映射关系 g.op() 里应该怎么写。Asinh 的文档写道:该算子有一个输入 input,一个输出 output,二者的类型都为张量。

代码汇总如下

python 复制代码
import torch 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
 
    def forward(self, x): 
        return torch.asinh(x) 
 
from torch.onnx.symbolic_registry import register_op 
 
def asinh_symbolic(g, input, *, out=None): 
    return g.op("Asinh", input) 
 
register_op('asinh', asinh_symbolic, '', 9) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, input, 'asinh.onnx') 

自定义算子

针对的问题:ONNX中没有对应算子的定义,需要自定义ONNX算子,执行转换。

g.op() 是用来定义 ONNX 算子的函数,对于 ONNX 官方定义的算子,g.op() 的第一个参数就是该算子的名称。而对于一个自定义算子,g.op() 的第一个参数是一个带命名空间的算子名。

完整代码

python 复制代码
import torch 
import torchvision 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
        self.conv1 = torch.nn.Conv2d(3, 18, 3) 
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3) 
 
    def forward(self, x): 
        return self.conv2(x, self.conv1(x)) 
 
from torch.onnx import register_custom_op_symbolic 
from torch.onnx.symbolic_helper import parse_args 
 
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none") 
def symbolic(g,  
        input, 
        weight, 
        offset, 
        mask, 
        bias, 
        stride_h, stride_w, 
        pad_h, pad_w, 
        dil_h, dil_w, 
        n_weight_grps, 
        n_offset_grps, 
        use_mask): 
    return g.op("custom::deform_conv2d", input, offset) 
 
register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch.onnx.export(model, input, 'dcn.onnx')
相关推荐
0思必得02 小时前
[Web自动化] Selenium处理动态网页
前端·爬虫·python·selenium·自动化
水如烟2 小时前
孤能子视角:“组织行为学–组织文化“
人工智能
韩立学长2 小时前
【开题答辩实录分享】以《基于Python的大学超市仓储信息管理系统的设计与实现》为例进行选题答辩实录分享
开发语言·python
大山同学2 小时前
图片补全-Context Encoder
人工智能·机器学习·计算机视觉
qq_192779872 小时前
高级爬虫技巧:处理JavaScript渲染(Selenium)
jvm·数据库·python
薛定谔的猫19822 小时前
十七、用 GPT2 中文对联模型实现经典上联自动对下联:
人工智能·深度学习·gpt2·大模型 训练 调优
u0109272712 小时前
使用Plotly创建交互式图表
jvm·数据库·python
爱学习的阿磊2 小时前
Python GUI开发:Tkinter入门教程
jvm·数据库·python
壮Sir不壮2 小时前
2026年奇点:Clawdbot引爆个人AI代理
人工智能·ai·大模型·claude·clawdbot·moltbot·openclaw
PaperRed ai写作降重助手2 小时前
高性价比 AI 论文写作软件推荐:2026 年预算友好型
人工智能·aigc·论文·写作·ai写作·智能降重