LibTorch实战四:模型序列化

目录

在C++环境中加载一个TORCHSCRIP

[Step1: 将pytorch模型转为torch scrip类型的模型](#Step1: 将pytorch模型转为torch scrip类型的模型)

[1.1、基于Tracing的方法来转换为Torch Script](#1.1、基于Tracing的方法来转换为Torch Script)

[1.2、基于Annotating (Script)的方法来转换为Torch Script](#1.2、基于Annotating (Script)的方法来转换为Torch Script)

[Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件](#Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件)

[Step3: 在libtorch中加载ScriptModule模型](#Step3: 在libtorch中加载ScriptModule模型)

总结


在C++环境中加载一个TORCHSCRIP

一般地,类似python的脚本语言可用于算法快速实现、验证;但在产品化过程中,一般采用效率更高的C++语言,下面的工作就是将模型从python环境中移植到c++环境。

Step1: 将pytorch模型转为torch scrip类型的模型

通过TorchSript,我们可将pytorch模型从python转为c++。那么,什么是TorchScript呢?其实,它也是Pytorch模型的一种,这种模型能够被TorchScript的编译器识别读取、序列化。一般地,在处理模型过程中,我们都会先将模型转为torch script格式,例如:".pt" -> "yolov5x.torchscript.pt"。转为torchscript格式有两种方法:一是函数torch.jit.trace;二是函数torch.jit.script。

torch.jit.trace原理:基于跟踪机制,需要输入一张图(0矩阵、张量亦可),模型会对输入的tensor进行处理,并记录所有张量的操作,torch::jit::trace能够捕获模型的结构、参数并保存。由于跟踪仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。

torch.jit.script原理:需要开发者先定义好神经网络模型结构,即:提前写好 class MyModule(torch.nn.Module),这样TorchScript可以根据定义好的MyModule来解析网络结构。

1.1、基于Tracing的方法来转换为Torch Script

如下代码,给 torch.jit.trace 函数输入一个指定size的随机张量、ResNet18的网络模型,得到一个类型为 torch.jit.ScriptModule 的对象,即:traced_script_module

python 复制代码
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

经过上述处理,traced_script_module变量已经包含网络的结构和参数,可以直接用于推理,如下代码:

python 复制代码
1 In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
2 In[2]: output[0, :5]
3 Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

1.2、基于Annotating (Script)的方法来转换为Torch Script

如果你的模型中有类似于控制流操作(例如:if or for循环),基于上述tracing的方式不再适用,这种方式会排上用场,下面以vanilla模型为例子,注:下面网络结构中有个if判断。

python 复制代码
# 定义一个vanilla模型
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

这里调用 torch.jit.script 来获取 torch.jit.ScriptModule 类型的对象,即:sm

python 复制代码
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件

注:上述的tacing和script方法都将得到一个类型为torch.jit.ScriptModule的对象(这里简单记为:ScriptModule ),该对象就是常规的前向传播模块。不管是哪一种方法,此时,只需要将ScriptModule进行序列化保存就行。这里保存的是上述基于Tracing得到的ResNet推理模块traced_script_module。

python 复制代码
traced_script_module.save("traced_resnet_model.pt") # 序列化,保存
# 保存后可用工具:https://netron.app/ 进行可视化

同理,如下是保存基于Annotating得到推理模块my_module 后续,在libtorch中加载上述保存的模型文件就行,不再依赖任何python包。

python 复制代码
my_module.save("my_module_model.pt") # 为什么不是sm

Step3: 在libtorch中加载ScriptModule模型

如何配置libtorh?,我这里仅贴下vs环境下的属性表:

python 复制代码
1 include:
2 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include
4 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include\torch\csrc\api\include
7 lib:
8 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib
9 
11 链接器:
12 c10.lib
13 c10_cuda.lib
14 torch.lib
15 torch_cpu.lib
16 torch_cuda.lib
17 
18 环境变量:
19 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib

以下c++代码加载上述模型文件

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}

总结

python模型的序列化、保存代码:

python 复制代码
import torchvision
import torch

model = torchvision.models.resnet18()

example = torch.rand(1, 3, 224, 224)

traced_script_module = torch.jit.trace(model, example)

output = traced_script_module(torch.ones(1, 3, 224, 224))

#traced_script_module.save("traced_resnet_model.pt") # 和下面等价,格式名称不同,仅此而已,在libtorch中是一样的
traced_script_module.save("traced_resnet_model.torchscript.pt")
print()

libtorch的模型加载,推理代码:

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    //std::string str = "traced_resnet_model.torchscript.pt"; // 和上面等价,模型格式而已
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}
相关推荐
知来者逆1 分钟前
探索大型语言模型在文化常识方面的理解能力与局限性
人工智能·gpt·深度学习·语言模型·自然语言处理·chatgpt·llm
Python极客之家38 分钟前
基于深度学习的乳腺癌分类识别与诊断系统
人工智能·深度学习·分类
Mopes__38 分钟前
Python | Leetcode Python题解之第452题用最少数量的箭引爆气球
python·leetcode·题解
AI视觉网奇1 小时前
pymeshlab 学习笔记
开发语言·python
mftang1 小时前
TMR传感器的实现原理和特性介绍
人工智能
纪伊路上盛名在1 小时前
如何初步部署自己的服务器,达到生信分析的及格线
linux·运维·服务器·python·学习·r语言·github
吃什么芹菜卷1 小时前
深度学习:词嵌入embedding和Word2Vec
人工智能·算法·机器学习
计算机源码社1 小时前
分享一个餐饮连锁店点餐系统 餐馆食材采购系统Java、python、php三个版本(源码、调试、LW、开题、PPT)
java·python·php·毕业设计项目·计算机课程设计·计算机毕业设计源码·计算机毕业设计选题
汤兰月1 小时前
Python中的观察者模式:从基础到实战
开发语言·python·观察者模式
chnyi6_ya1 小时前
论文笔记:Online Class-Incremental Continual Learning with Adversarial Shapley Value
论文阅读·人工智能