LibTorch实战四:模型序列化

目录

在C++环境中加载一个TORCHSCRIP

[Step1: 将pytorch模型转为torch scrip类型的模型](#Step1: 将pytorch模型转为torch scrip类型的模型)

[1.1、基于Tracing的方法来转换为Torch Script](#1.1、基于Tracing的方法来转换为Torch Script)

[1.2、基于Annotating (Script)的方法来转换为Torch Script](#1.2、基于Annotating (Script)的方法来转换为Torch Script)

[Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件](#Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件)

[Step3: 在libtorch中加载ScriptModule模型](#Step3: 在libtorch中加载ScriptModule模型)

总结


在C++环境中加载一个TORCHSCRIP

一般地,类似python的脚本语言可用于算法快速实现、验证;但在产品化过程中,一般采用效率更高的C++语言,下面的工作就是将模型从python环境中移植到c++环境。

Step1: 将pytorch模型转为torch scrip类型的模型

通过TorchSript,我们可将pytorch模型从python转为c++。那么,什么是TorchScript呢?其实,它也是Pytorch模型的一种,这种模型能够被TorchScript的编译器识别读取、序列化。一般地,在处理模型过程中,我们都会先将模型转为torch script格式,例如:".pt" -> "yolov5x.torchscript.pt"。转为torchscript格式有两种方法:一是函数torch.jit.trace;二是函数torch.jit.script。

torch.jit.trace原理:基于跟踪机制,需要输入一张图(0矩阵、张量亦可),模型会对输入的tensor进行处理,并记录所有张量的操作,torch::jit::trace能够捕获模型的结构、参数并保存。由于跟踪仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。

torch.jit.script原理:需要开发者先定义好神经网络模型结构,即:提前写好 class MyModule(torch.nn.Module),这样TorchScript可以根据定义好的MyModule来解析网络结构。

1.1、基于Tracing的方法来转换为Torch Script

如下代码,给 torch.jit.trace 函数输入一个指定size的随机张量、ResNet18的网络模型,得到一个类型为 torch.jit.ScriptModule 的对象,即:traced_script_module

python 复制代码
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

经过上述处理,traced_script_module变量已经包含网络的结构和参数,可以直接用于推理,如下代码:

python 复制代码
1 In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
2 In[2]: output[0, :5]
3 Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

1.2、基于Annotating (Script)的方法来转换为Torch Script

如果你的模型中有类似于控制流操作(例如:if or for循环),基于上述tracing的方式不再适用,这种方式会排上用场,下面以vanilla模型为例子,注:下面网络结构中有个if判断。

python 复制代码
# 定义一个vanilla模型
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

这里调用 torch.jit.script 来获取 torch.jit.ScriptModule 类型的对象,即:sm

python 复制代码
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件

注:上述的tacing和script方法都将得到一个类型为torch.jit.ScriptModule的对象(这里简单记为:ScriptModule ),该对象就是常规的前向传播模块。不管是哪一种方法,此时,只需要将ScriptModule进行序列化保存就行。这里保存的是上述基于Tracing得到的ResNet推理模块traced_script_module。

python 复制代码
traced_script_module.save("traced_resnet_model.pt") # 序列化,保存
# 保存后可用工具:https://netron.app/ 进行可视化

同理,如下是保存基于Annotating得到推理模块my_module 后续,在libtorch中加载上述保存的模型文件就行,不再依赖任何python包。

python 复制代码
my_module.save("my_module_model.pt") # 为什么不是sm

Step3: 在libtorch中加载ScriptModule模型

如何配置libtorh?,我这里仅贴下vs环境下的属性表:

python 复制代码
1 include:
2 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include
4 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include\torch\csrc\api\include
7 lib:
8 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib
9 
11 链接器:
12 c10.lib
13 c10_cuda.lib
14 torch.lib
15 torch_cpu.lib
16 torch_cuda.lib
17 
18 环境变量:
19 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib

以下c++代码加载上述模型文件

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}

总结

python模型的序列化、保存代码:

python 复制代码
import torchvision
import torch

model = torchvision.models.resnet18()

example = torch.rand(1, 3, 224, 224)

traced_script_module = torch.jit.trace(model, example)

output = traced_script_module(torch.ones(1, 3, 224, 224))

#traced_script_module.save("traced_resnet_model.pt") # 和下面等价,格式名称不同,仅此而已,在libtorch中是一样的
traced_script_module.save("traced_resnet_model.torchscript.pt")
print()

libtorch的模型加载,推理代码:

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    //std::string str = "traced_resnet_model.torchscript.pt"; // 和上面等价,模型格式而已
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}
相关推荐
TMT星球16 分钟前
生数科技携手央视新闻《文博日历》,推动AI视频技术的创新应用
大数据·人工智能·科技
五味香24 分钟前
Java学习,List 元素替换
android·java·开发语言·python·学习·golang·kotlin
AI视觉网奇30 分钟前
图生3d算法学习笔记
人工智能
小锋学长生活大爆炸38 分钟前
【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式
人工智能·pytorch·深度学习·图神经网络·gnn·dgl
计算机徐师兄40 分钟前
Python基于Django的花卉商城系统的设计与实现(附源码,文档说明)
python·django·python django·花卉商城系统·花卉·花卉商城·python花卉商城系统
机械心1 小时前
pytorch深度学习模型推理和部署、pytorch&ONNX&tensorRT模型转换以及python和C++版本部署
pytorch·python·深度学习
佛州小李哥1 小时前
在亚马逊云科技上用AI提示词优化功能写出漂亮提示词(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
鸭鸭鸭进京赶烤1 小时前
计算机工程:解锁未来科技之门!
人工智能·科技·opencv·ai·机器人·硬件工程·软件工程
ALISHENGYA1 小时前
精讲Python之turtle库(二):设置画笔颜色、回旋伞、变色回旋伞、黄色三角形、五角星,附源代码
python·turtle
ModelWhale1 小时前
十年筑梦,再创鲸彩!庆祝和鲸科技十周年
人工智能·科技