pytorch支持更多onnx算子
本文主要参考扩展onnx算子。
而要使 PyTorch 算子顺利转换到 ONNX ,我们需要保证以下三个环节都不出错:
- 算子在 PyTorch 中有实现
- 有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法
- ONNX 有相应的算子
PyTorch 算子
- 组合现有算子
- 添加 TorchScript 算子
- 添加普通 C++ 拓展算子
映射方法 - 为 ATen 算子添加符号函数
- 为 TorchScript 算子添加符号函数
- 封装成 torch.autograd.Function 并添加符号函数
ONNX 算子 - 使用现有 ONNX 算子
- 定义新 ONNX 算子
支持ATen算子
ATen 是 PyTorch 内置的 C++ 张量计算库,PyTorch 算子在底层绝大多数计算都是用 ATen 实现的。
针对的问题:ATen有定义,但缺少和ONNX的映射规则。
解决的思路:
- 获取Aten算子接口定义。去
torch/_C/_VariableFunctions.pyi
和torch/nn/functional.pyi
搜索算子名。如asinh
,对应的接口为def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
- 添加符号函数
添加符号函数def symbolic(g: torch._C.Graph, input_0: torch._C.Value, input_1: torch._C.Value, ...):
,g有一个op方法,在把 PyTorch 算子转换成 ONNX 算子时,需要在符号函数中调用此方法来为最终的计算图添加一个 ONNX 算子。在最简单的情况下,我们只要把 PyTorch 算子的输入用g.op()一一对应到 ONNX 算子上即可,并把g.op()的返回值作为符号函数的返回值。在情况更复杂时,我们转换一个 PyTorch 算子可能要新建若干个 ONNX 算子。我们先去翻阅一下 ONNX 算子文档,学习一下我们在符号函数里的映射关系 g.op() 里应该怎么写。Asinh 的文档写道:该算子有一个输入 input,一个输出 output,二者的类型都为张量。
代码汇总如下
python
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.asinh(x)
from torch.onnx.symbolic_registry import register_op
def asinh_symbolic(g, input, *, out=None):
return g.op("Asinh", input)
register_op('asinh', asinh_symbolic, '', 9)
model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'asinh.onnx')
自定义算子
针对的问题:ONNX中没有对应算子的定义,需要自定义ONNX算子,执行转换。
g.op()
是用来定义 ONNX 算子的函数,对于 ONNX 官方定义的算子,g.op() 的第一个参数就是该算子的名称。而对于一个自定义算子,g.op() 的第一个参数是一个带命名空间的算子名。
完整代码
python
import torch
import torchvision
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 18, 3)
self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)
def forward(self, x):
return self.conv2(x, self.conv1(x))
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none")
def symbolic(g,
input,
weight,
offset,
mask,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps,
use_mask):
return g.op("custom::deform_conv2d", input, offset)
register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9)
model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'dcn.onnx')