LibTorch实战四:模型序列化

目录

在C++环境中加载一个TORCHSCRIP

[Step1: 将pytorch模型转为torch scrip类型的模型](#Step1: 将pytorch模型转为torch scrip类型的模型)

[1.1、基于Tracing的方法来转换为Torch Script](#1.1、基于Tracing的方法来转换为Torch Script)

[1.2、基于Annotating (Script)的方法来转换为Torch Script](#1.2、基于Annotating (Script)的方法来转换为Torch Script)

[Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件](#Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件)

[Step3: 在libtorch中加载ScriptModule模型](#Step3: 在libtorch中加载ScriptModule模型)

总结


在C++环境中加载一个TORCHSCRIP

一般地,类似python的脚本语言可用于算法快速实现、验证;但在产品化过程中,一般采用效率更高的C++语言,下面的工作就是将模型从python环境中移植到c++环境。

Step1: 将pytorch模型转为torch scrip类型的模型

通过TorchSript,我们可将pytorch模型从python转为c++。那么,什么是TorchScript呢?其实,它也是Pytorch模型的一种,这种模型能够被TorchScript的编译器识别读取、序列化。一般地,在处理模型过程中,我们都会先将模型转为torch script格式,例如:".pt" -> "yolov5x.torchscript.pt"。转为torchscript格式有两种方法:一是函数torch.jit.trace;二是函数torch.jit.script。

torch.jit.trace原理:基于跟踪机制,需要输入一张图(0矩阵、张量亦可),模型会对输入的tensor进行处理,并记录所有张量的操作,torch::jit::trace能够捕获模型的结构、参数并保存。由于跟踪仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。

torch.jit.script原理:需要开发者先定义好神经网络模型结构,即:提前写好 class MyModule(torch.nn.Module),这样TorchScript可以根据定义好的MyModule来解析网络结构。

1.1、基于Tracing的方法来转换为Torch Script

如下代码,给 torch.jit.trace 函数输入一个指定size的随机张量、ResNet18的网络模型,得到一个类型为 torch.jit.ScriptModule 的对象,即:traced_script_module

python 复制代码
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

经过上述处理,traced_script_module变量已经包含网络的结构和参数,可以直接用于推理,如下代码:

python 复制代码
1 In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
2 In[2]: output[0, :5]
3 Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

1.2、基于Annotating (Script)的方法来转换为Torch Script

如果你的模型中有类似于控制流操作(例如:if or for循环),基于上述tracing的方式不再适用,这种方式会排上用场,下面以vanilla模型为例子,注:下面网络结构中有个if判断。

python 复制代码
# 定义一个vanilla模型
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

这里调用 torch.jit.script 来获取 torch.jit.ScriptModule 类型的对象,即:sm

python 复制代码
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

Step2: 序列化torch.jit.ScriptModule类型的对象,并保存为文件

注:上述的tacing和script方法都将得到一个类型为torch.jit.ScriptModule的对象(这里简单记为:ScriptModule ),该对象就是常规的前向传播模块。不管是哪一种方法,此时,只需要将ScriptModule进行序列化保存就行。这里保存的是上述基于Tracing得到的ResNet推理模块traced_script_module。

python 复制代码
traced_script_module.save("traced_resnet_model.pt") # 序列化,保存
# 保存后可用工具:https://netron.app/ 进行可视化

同理,如下是保存基于Annotating得到推理模块my_module 后续,在libtorch中加载上述保存的模型文件就行,不再依赖任何python包。

python 复制代码
my_module.save("my_module_model.pt") # 为什么不是sm

Step3: 在libtorch中加载ScriptModule模型

如何配置libtorh?,我这里仅贴下vs环境下的属性表:

python 复制代码
1 include:
2 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include
4 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\include\torch\csrc\api\include
7 lib:
8 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib
9 
11 链接器:
12 c10.lib
13 c10_cuda.lib
14 torch.lib
15 torch_cpu.lib
16 torch_cuda.lib
17 
18 环境变量:
19 D:\ThirdParty\libtorch-win-shared-with-deps-1.7.1+cu110\libtorch\lib

以下c++代码加载上述模型文件

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}

总结

python模型的序列化、保存代码:

python 复制代码
import torchvision
import torch

model = torchvision.models.resnet18()

example = torch.rand(1, 3, 224, 224)

traced_script_module = torch.jit.trace(model, example)

output = traced_script_module(torch.ones(1, 3, 224, 224))

#traced_script_module.save("traced_resnet_model.pt") # 和下面等价,格式名称不同,仅此而已,在libtorch中是一样的
traced_script_module.save("traced_resnet_model.torchscript.pt")
print()

libtorch的模型加载,推理代码:

cpp 复制代码
#include<torch/torch.h>
#include<torch/script.h>
#include<iostream>
#include<memory>

int main()
{
    torch::jit::script::Module module;
    std::string str = "traced_resnet_model.pt";
    //std::string str = "traced_resnet_model.torchscript.pt"; // 和上面等价,模型格式而已
    try
    {
        module = torch::jit::load(str);
    }
    catch (const c10::Error& e)
    {
        std::cerr << "12313";
        return -1;
    }

    // 创建一个输入
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }));
    // 推理
    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

    return 1;
}
相关推荐
__lost6 分钟前
Python图像变清晰与锐化,调整对比度,高斯滤波除躁,卷积锐化,中值滤波钝化,神经网络变清晰
python·opencv·计算机视觉
海绵波波10712 分钟前
玉米产量遥感估产系统的开发实践(持续迭代与更新)
python·flask
欣然~15 分钟前
借助 OpenCV 和 PyTorch 库,利用卷积神经网络提取图像边缘特征
人工智能·计算机视觉
谦行24 分钟前
工欲善其事,必先利其器—— PyTorch 深度学习基础操作
pytorch·深度学习·ai编程
逢生博客1 小时前
使用 Python 项目管理工具 uv 快速创建 MCP 服务(Cherry Studio、Trae 添加 MCP 服务)
python·sqlite·uv·deepseek·trae·cherry studio·mcp服务
堕落似梦1 小时前
Pydantic增强SQLALchemy序列化(FastAPI直接输出SQLALchemy查询集)
python
白熊1881 小时前
【计算机视觉】CV实战项目 - 基于YOLOv5的人脸检测与关键点定位系统深度解析
人工智能·yolo·计算机视觉
nenchoumi31191 小时前
VLA 论文精读(十六)FP3: A 3D Foundation Policy for Robotic Manipulation
论文阅读·人工智能·笔记·学习·vln
后端小肥肠1 小时前
文案号搞钱潜规则:日入四位数的Coze工作流我跑通了
人工智能·coze
LCHub低代码社区1 小时前
钧瓷产业原始创新的许昌共识:技术破壁·产业再造·生态重构(一)
大数据·人工智能·维格云·ai智能体·ai自动化·大禹智库·钧瓷码