LibTorch实战四:模型序列化

目录

在C++环境中加载一个TORCHSCRIP

[Step1: 将pytorch模型转为torch scrip类型的模型](#Step1: 将pytorch模型转为torch scrip类型的模型)

[1.1、基于Tracing的方法来转换为Torch Script](#1.1、基于Tracing的方法来转换为Torch Script)

[1.2、基于Annotating (Script)的方法来转换为Torch Script](#1.2、基于Annotating (Script)的方法来转换为Torch Script)

[Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件](#Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件)

[Step3: 在libtorch中加载ScriptModule模型](#Step3: 在libtorch中加载ScriptModule模型)

总结


在C++环境中加载一个TORCHSCRIP

一般地,类似python的脚本语言可用于算法快速实现、验证;但在产品化过程中,一般采用效率更高的C++语言,下面的工作就是将模型从python环境中移植到c++环境。

Step1: 将pytorch模型转为torch scrip类型的模型

通过TorchSript,我们可将pytorch模型从python转为c++。那么,什么是TorchScript呢?其实,它也是Pytorch模型的一种,这种模型能够被TorchScript的编译器识别读取、序列化。一般地,在处理模型过程中,我们都会先将模型转为torch script格式,例如:".pt" -> "yolov5x.torchscript.pt"。转为torchscript格式有两种方法:一是函数torch.jit.trace;二是函数torch.jit.script。

torch.jit.trace原理:基于跟踪机制,需要输入一张图(0矩阵、张量亦可),模型会对输入的tensor进行处理,并记录所有张量的操作,torch::jit::trace能够捕获模型的结构、参数并保存。由于跟踪仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。

torch.jit.script原理:需要开发者先定义好神经网络模型结构,即:提前写好 class MyModule(torch.nn.Module),这样TorchScript可以根据定义好的MyModule来解析网络结构。

1.1、基于Tracing的方法来转换为Torch Script

如下代码,给 torch.jit.trace 函数输入一个指定size的随机张量、ResNet18的网络模型,得到一个类型为 torch.jit.ScriptModule 的对象,即:traced_script_module

python 复制代码
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

经过上述处理,traced_script_module变量已经包含网络的结构和参数,可以直接用于推理,如下代码:

python 复制代码
1 In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
2 In[2]: output[0, :5]
3 Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

1.2、基于Annotating (Script)的方法来转换为Torch Script

如果你的模型中有类似于控制流操作(例如:if or for循环),基于上述tracing的方式不再适用,这种方式会排上用场,下面以vanilla模型为例子,注:下面网络结构中有个if判断。

python 复制代码
# 定义一个vanilla模型
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

这里调用 torch.jit.script 来获取 torch.jit.ScriptModule 类型的对象,即:sm

python 复制代码
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件

注:上述的tacing和script方法都将得到一个类型为torch.jit.ScriptModule的对象(这里简单记为:ScriptModule ),该对象就是常规的前向传播模块。不管是哪一种方法,此时,只需要将ScriptModule进行序列化保存就行。这里保存的是上述基于Tracing得到的ResNet推理模块traced_script_module。

python 复制代码
traced_script_module.save("traced_resnet_model.pt") # 序列化,保存
# 保存后可用工具:https://netron.app/ 进行可视化

同理,如下是保存基于Annotating得到推理模块my_module 后续,在libtorch中加载上述保存的模型文件就行,不再依赖任何python包。

python 复制代码
my_module.save("my_module_model.pt") # 为什么不是sm

Step3: 在libtorch中加载ScriptModule模型

如何配置libtorh?,我这里仅贴下vs环境下的属性表:

python 复制代码
1 include:
2 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include
4 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include\torch\csrc\api\include
7 lib:
8 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib
9 
11 链接器:
12 c10.lib
13 c10_cuda.lib
14 torch.lib
15 torch_cpu.lib
16 torch_cuda.lib
17 
18 环境变量:
19 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib

以下c++代码加载上述模型文件

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}

总结

python模型的序列化、保存代码:

python 复制代码
import torchvision
import torch

model = torchvision.models.resnet18()

example = torch.rand(1, 3, 224, 224)

traced_script_module = torch.jit.trace(model, example)

output = traced_script_module(torch.ones(1, 3, 224, 224))

#traced_script_module.save("traced_resnet_model.pt") # 和下面等价,格式名称不同,仅此而已,在libtorch中是一样的
traced_script_module.save("traced_resnet_model.torchscript.pt")
print()

libtorch的模型加载,推理代码:

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    //std::string str = "traced_resnet_model.torchscript.pt"; // 和上面等价,模型格式而已
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}
相关推荐
贝多财经2 分钟前
千里科技报考港股上市:高度依赖吉利,AI智驾转型收入仍为零
大数据·人工智能·科技
嵌入式-老费13 分钟前
自己动手写深度学习框架(pytorch入门)
人工智能·pytorch·深度学习
irisMoon0622 分钟前
yolov5单目测距+速度测量+目标跟踪
人工智能·yolo·目标跟踪
Linux猿25 分钟前
365科技简报 2025年11月13日 星期四
人工智能·科技简报
终端域名32 分钟前
当今前沿科技:脑机共生界面(脑机接口)深度解析
人工智能·智能电视
麦麦大数据33 分钟前
F047 vue3+flask微博舆情推荐可视化问答系统
python·flask·知识图谱·neo4j·推荐算法·舆情分析·舆情监测
MediaTea36 分钟前
Python 第三方库:Flask(轻量级 Web 框架)
开发语言·前端·后端·python·flask
java干货1 小时前
Spring Boot 为什么“抛弃”了 spring.factories?
spring boot·python·spring
2501_941111821 小时前
使用Python进行网络设备自动配置
jvm·数据库·python
源码之家1 小时前
基于python租房大数据分析系统 房屋数据分析推荐 scrapy爬虫+可视化大屏 贝壳租房网 计算机毕业设计 推荐系统(源码+文档)✅
大数据·爬虫·python·scrapy·数据分析·推荐算法·租房