PyTorch自定义C++拓展算子、转换ONNX模型、ORT部署实现

文章涵盖从具有PyTorch自定义算子的模型训练开始,到基于ONNXRuntime框架的端侧部署整个过程,以为SNN(脉冲神经网络,Spiking Neural Network)添加Lif算子支持作为示例,为读者提供技术路径和官方文档、相关优质教程链接。不懂Lif算子的读者不要有压力,实际上只是替换了激活函数而已,以下是Lif原理及实现的一些说明,有兴趣的可以看源码。

示例模型说明:文章建立了一个一维多层卷积模型(数据排布:NCT),将卷积层后的激活函数改为Lif,Lif层将会遍历每一个元素,并在时间维度T进行衰减和发射迭代,模拟SNN中神经元的发射机制。代码依赖于SNN-HAR,对于Lif的基本实现可以详见仓库内文件:models/spike.py

文章目录

PyTorch自定义算子

我们先来看一下建立的PyTorch模型,这是一个三层一维卷积模型,最后接了线性层完成分类任务输出。可以看到将应有的激活函数替换为了LIFSpike,并且传入了threshbeta两个参数,模型中的其他参数省略。

python 复制代码
class SFCN(FCN):
    def __init__(self, n_channels, n_classes, out_channels=128, backbone=True, **kwargs):
        super(SFCN, self).__init__(n_channels,
                                   n_classes, out_channels, backbone)
       	self.conv_block = nn.Sequential(
            nn.Conv1d(n_channels, 32, kernel_size=8,
                      stride=1, padding=4, bias=False),
            nn.BatchNorm1d(32),
            LIFSpike(thresh=kwargs['thresh'], beta=kwargs['tau']),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
            nn.Dropout(0.35),
            nn.Conv1d(32, 64, kernel_size=8, stride=1, padding=4, bias=False),
            nn.BatchNorm1d(64),
            LIFSpike(thresh=kwargs['thresh'], beta=kwargs['tau']),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
            nn.Conv1d(64, out_channels, kernel_size=8,
                      stride=1, padding=4, bias=False),
            nn.BatchNorm1d(out_channels),
            LIFSpike(thresh=kwargs['thresh'], beta=kwargs['tau']),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        )
        self.logits = nn.Linear(self.out_len * out_channels, n_classes)
    def forward(self, x):
        x = self.conv_block(x)
        x_flat = x.reshape(x.shape[0], -1)
        logits = self.logits(x_flat)
        return logits

以下是LIFSpike的python部分实现,完整代码请见SNN-HAR,这里为了方便理解把源码中的ZIF替换为了FastSigmoid。构建模型使用继承自nn.Module的LIFSpike,FastSigmoid继承自torch.autograd.Function,他们都具有前向和后向实现(LIFSpike依赖PyTorch自动推导实现后向)。LIFSpike中具有一个激活act 操作,他被赋值为FastSigmoid.apply从而引用FastSigmoid,在LIFSpike的forward中,对act 直接传参即可调用FastSigmoid的forward。我们可以再封装一层需要引用的FastSigmoid来减少每次act函数调用所需的参数量,详见两行注释。

python 复制代码
class FastSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, slope=25):
        ctx.save_for_backward(input_)
        ctx.slope = slope
        out = (input_ > 0).float()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input_,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input / (ctx.slope * torch.abs(input_) + 1.0) ** 2
        return grad, None

def fast_sigmoid(slope=25):
    slope = slope
    def inner(x):
        return FastSigmoid.apply(x, slope)
    return inner

class LIFSpike(nn.Module):
    def __init__(self, thresh=0.5, tau=0.75, gamma=1.0, dspike=True, soft_reset=True):
        super(LIFSpike, self).__init__()
        self._act = FastSigmoid.apply
        # self._act = fast_sigmoid(slope=25)

    def forward(self, x):
        spike = self._act(mem - self.thresh, slope=25)
        # spike = self._act(mem - self.thresh)
        return torch.stack(spike_out, dim=2)

C++拓展算子的实现

为了提高算子执行速度,我们选择使用C++拓展来实现LIFSpike,PyTorch的C++拓展算子不支持通过前向过程实现反向过程的自动推理,必须要自己写出反向过程。本示例中只写了前向过程用于加速模型推理,模型训练阶段仍然使用基于python描述的算子,不过算子的反向过程C++拓展与前向的添加步骤相同。

这部分有非常好的视频教程和配套的源码库可供学习:y2b or bilibil

首先需要建立一个用于放置C++拓展算子的目录,需要创建的文件如下:

shell 复制代码
lif/
├── include
│   └── utils.h
├── lif.cpp
├── setup.py
└── test.py

utils.h:(没啥好说的,CHECK用于检查输入是否正确)

cpp 复制代码
#include <torch/extension.h>

#define CHECK_CPU(x) TORCH_CHECK(x.is_cpu(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CPU(x);        \
  CHECK_CONTIGUOUS(x)

torch::Tensor lif_fw(
    torch::Tensor input,
    float thresh_,
    float beta_);

setup.py:(也没啥好说的,extra_compile_args可以设置编译器优化级别)

python 复制代码
import glob
import os.path as osp
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

ROOT_DIR = osp.dirname(osp.abspath(__file__))
include_dirs = [osp.join(ROOT_DIR, "include")]

sources = glob.glob('*.cpp')

setup(
    name='lif',
    version='0.1',
    author='',
    author_email='',
    description='lif for snn',
    long_description='lif for snn',
    ext_modules=[
        CppExtension(
            name='lif',
            sources=sources,
            include_dirs=include_dirs,
            extra_compile_args={'cxx': ['-O2']}
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

lif.cpp:这边把完整的LIF实现写出来了,可以参考一下取值、赋值、取大小的几个api。(fw是forward的意思)

python 复制代码
#include "include/utils.h"

torch::Tensor
lif_fw(
    torch::Tensor input,
    float thresh_,
    float beta_)
{
  CHECK_INPUT(input);
  
  float *data = input.data_ptr<float>();

  float thresh = thresh_;
  float beta = beta_;
  
  auto sizes = input.sizes();
  int CIN = sizes[1];
  int T = sizes[2];

  // ASSUME THAT THE INPUT HAS ONLY ONE BATCH
  float mem[CIN] = {0};
  float spike_out[CIN * T] = {0};
  float spike = 0;
  for (int i = 0; i < T; i++)
  {
    for (int c = 0; c < CIN; c++)
    {
      mem[c] *= beta;
      mem[c] += data[c * T + i];
      spike = mem[c] - thresh;
      if (spike > 0)
      {
        mem[c] -= thresh;
        spike_out[c * T + i] = 1;
      }
    }
  }

  torch::Tensor output = torch::zeros(sizes, input.options());
  output.copy_(torch::from_blob(spike_out, sizes, input.options()));
  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("lif_fw", &lif_fw);
}

最后在创建的目录下执行:

bash 复制代码
pip install .

然后可以按照前面的视频教程链接通过test.py文件进行测试。

C++拓展算子的引入

回到python部分,前一节已经将lif算子通过pip安装成了一个包,现在可以直接导入并直接使用。这里建立了一层Lif封装,定义了用于onnx模型导出的symbolic和forward前向,然后使用继承自nn.Module的LifSpike通过apply调用。

python 复制代码
import lif
class Lif(torch.autograd.Function):
    @staticmethod
    def symbolic(g, input, thresh, beta):
        return g.op("custom::Lif", input, thresh_f=thresh, beta_f=beta)

    @staticmethod
    def forward(ctx, input, thresh, beta):
        output = lif.lif_fw(input, thresh, beta)
        ctx.save_for_backward(input)
        return output
    
class LifSpike(nn.Module):
    def __init__(self, thresh, beta):
        super(LifSpike, self).__init__()
        self.thresh = thresh
        self.beta = beta
    def forward(self, x):
        return Lif.apply(x, self.thresh, self.beta)

定义算子symbol

这里要先讲一下把包含自定义算子的模型导出为ONNX模型的几种方法,由于ONNX并不知道我们这个自定义算子究竟是什么样子,不知道算子的输入输出和基本属性。在导出的时候可以使用现有ONNX算子来描述你的自定义算子,这会导致潜在的算子执行效率降低,并且对于复杂运算并不一定能通过现有算子来实现,所以文章主要介绍将我们的自定义算子包装成一个独立的描述完整的算子,并在ONNXRuntime中添加对这个算子的支持。

属性即attribute,padding、stride等就是卷积层的属性,属性在ONNXRuntime中的获取方式不同于input输入,这个下一节会讲
关于PyTorch ONNX Exporter的详细介绍:官方文档 or 中文翻译教程

上一节出现的Lif算子中包含一个symbolic定义,这就是将自定义算子包装成独立算子的必要条件。g.op() 是用来定义 ONNX 算子的函数。对于 ONNX 官方定义的算子,g.op() 的第一个参数就是该算子的名称。而对于一个自定义算子,g.op() 的第一个参数是一个带命名空间的算子名。命名空间(OpDomain)用于防止命名冲突,就这样我们建立一个名为"custom"的命名空间,并在该空间中定义了一个名为"Lif"的算子,这种独立包装是不包含任何具体实现的。

python 复制代码
def symbolic(g, input, thresh, beta):
    return g.op("custom::Lif", input, thresh_f=thresh, beta_f=beta)

对于LIF算子,threshbeta 是他的属性,input是他的输入。可以注意到在给g.op传参的时候,threshbeta加了一个_f的后缀,请ChatGPT帮我们解释:

在ONNX中,当您通过symbolic函数定义自定义算子的属性时,需要在属性名后加上一个类型后缀来明确属性值的类型。这些类型后缀有:

_i:整数

_f:浮点数

_s:字符串

_t:张量

_is:整数列表

_fs:浮点数列表

_ss:字符串列表

_ts:张量列表

模型定义部分全部结束,接下来我们将模型导出。

导出为ONNX模型

包含自定义算子的模型导出与正常的模型相比,需要传入你的自定义命名空间和一个版本号,版本号可以自行设置。

python 复制代码
torch.onnx.export(model_test, ipt, 'results/onnx/' + model_name, 
                    input_names=['input'],
                    output_names=['output'],
                    custom_opsets={"custom":1})

以下是官方源码中对custom_opsets这一参数的解释:

custom_opsets (dict[str, int], default empty dict): A dict with schema:

​ * KEY (str): opset domain name

​ * VALUE (int): opset version

​ If a custom opset is referenced by model but not mentioned in this dictionary, the opset version is set to 1. Only custom opset domain name and version should be indicated through this argument.

通过Netron来看看自定义算子Lif,用于参考。

为ONNXRuntime提供自定义算子实现

从 onnxruntime 1.16开始,可以使用编译后的api非常方便地直接完成自定义算子的实现。在官方文档的第一个示例中,提供了一个v1::CustomOpOne算子的实现,这个算子没有属性,所以实现起来更加简单。

Custom operators:官方文档

官方文档中的第二个示例提供了一个具有属性的算子该如何获取自己的属性,主要通过ort_api->KernelInfoGetAttribute_int64来获取,不同的属性类型有对应的api,这里直接放出我写的Lif实现:

cpp 复制代码
struct Lif
{
  float thresh;
  float beta;
  Lif(const OrtApi *ort_api, const OrtKernelInfo *info)
  {
    // 从这里获取对应名字的属性,并赋值到结构体中声明的属性变量以便Compute函数调用
    auto ret = ort_api->KernelInfoGetAttribute_float(info, "thresh", &thresh);
    ret = ort_api->KernelInfoGetAttribute_float(info, "beta", &beta);
  }
  void Compute(const Ort::Custom::Tensor<float> &X,
               Ort::Custom::Tensor<float> &Z)
  {
    auto input_shape = X.Shape();
    auto x_raw = X.Data();
    auto z_raw = Z.Allocate(input_shape);
    int CIN = input_shape[1];
    int T = input_shape[2];

    // ASSUME THAT THE INPUT HAS ONLY ONE BATCH
    int idx = 0;
    for (int c = 0; c < CIN; c++)
    {
      float mem_ = 0;
      for (int t = 0; t < T; t++)
      {
        mem_ = mem_ * beta + x_raw[idx];
        if (mem_ > thresh)
        {
          mem_ -= thresh;
          z_raw[idx] = 1;
        }
        else
        {
          z_raw[idx] = 0;
        }
        idx++;
      }
    }
  }
};

onnxruntime通过创建的session来执行模型推理,session_options可以用于指定session执行时的算子实现:

cpp 复制代码
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_lite_custom_op.h>

Ort::SessionOptions session_options;
// Add Custom Op
Ort::CustomOpDomain custom_domain{"custom"};
std::unique_ptr<Ort::Custom::OrtLiteCustomOp> custom_op_lif{Ort::Custom::CreateLiteCustomOp<Lif>("Lif", "CPUExecutionProvider")};
custom_domain.Add(custom_op_lif.get());
session_options.Add(custom_domain);
// Create Session
Ort::Session session = Ort::Session(env, model_file.c_str(), session_options);
相关推荐
FL16238631293 分钟前
[数据集][目标检测]车油口挡板开关闭合检测数据集VOC+YOLO格式138张2类别
人工智能·yolo·目标检测
YesPMP平台官方5 分钟前
AI+教育|拥抱AI智能科技,让课堂更生动高效
人工智能·科技·ai·数据分析·软件开发·教育
FL162386312931 分钟前
AI健身体能测试之基于paddlehub实现引体向上计数个数统计
人工智能
黑客-雨34 分钟前
构建你的AI职业生涯:从基础知识到专业实践的路线图
人工智能·产品经理·ai大模型·ai产品经理·大模型学习·大模型入门·大模型教程
子午36 分钟前
动物识别系统Python+卷积神经网络算法+TensorFlow+人工智能+图像识别+计算机毕业设计项目
人工智能·python·cnn
AlexMercer10121 小时前
【C++】二、数据类型 (同C)
c语言·开发语言·数据结构·c++·笔记·算法
大耳朵爱学习1 小时前
掌握Transformer之注意力为什么有效
人工智能·深度学习·自然语言处理·大模型·llm·transformer·大语言模型
TAICHIFEI1 小时前
目标检测-数据集
人工智能·目标检测·目标跟踪
qq_15321452641 小时前
【2023工业异常检测文献】SimpleNet
图像处理·人工智能·深度学习·神经网络·机器学习·计算机视觉·视觉检测
洛阳泰山1 小时前
如何使用Chainlit让所有网站快速嵌入一个AI聊天助手Copilot
人工智能·ai·llm·copilot·网站·chainlit·copliot