使用PyTorch导出JIT模型:C++ API与libtorch实战

PyTorch导出JIT模型并用C++ API libtorch调用

本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。

Step1:导出模型

首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。

# export_jit_model.py
import torch
import torchvision.models as models

model = models.resnet50(pretrained=True)
model.eval()

example_input = torch.rand(1, 3, 224, 224)

jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

导出 JIT 模型的方式有两种:trace 和 script。

我们采用
torch.jit.trace

的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。

在我们的工程目录
demo

下运行上面的
export_jit_model.py

,会得到一个 JIT 模型件:
resnet50_jit.pth

Step 2:安装libtorch

接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:

wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip

也解压在我们的工程目录
demo

下即可。

Step 3:安装OpenCV

用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo

下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt

文件中正确地指定本机的 OpenCV 地址即可。

git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6

Step 4:准备测试图像并用Python测试

我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。

kitten.jpg

写一个脚本用 PyTorch 运行一下模型:

# pytorch_test.py

import torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image

# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([
                    transforms.Resize(224),
                    transforms.ToTensor()])
                    # normalize])

model = models.resnet50(pretrained=True)
model.eval()

img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())

输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系

,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。

Step 5:准备cpp源文件

我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。

这里有几个点要说一下,不注意可能会犯错:

  1. cv::imread()
    默认读取为三通道BGR,需要进行B/R通道交换,这里采用
    cv::cvtColor()
    实现。
  2. 图像尺寸需要调整到

224

×

224

224\times 224

2

2

4

×

2

2

4

,通过
cv::resize()

实现。

  1. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
    np.transpose()

操作,这里使用
tensor.permut()

实现,效果是一样的。

  1. 数据归一化,采用
    tensor.div(255)

实现。

// test_model.cpp
#include <vector>

#include <torch/torch.h>
#include <torch/script.h>

#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>

int main(int argc, char* argv[]) {
  // 加载JIT模型
  auto module = torch::jit::load(argv[1]);

  // 加载图像
  auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);
  cv::Mat image_transfomed;
  cv::resize(image, image_transfomed, cv::Size(224, 224));
  cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);

  // 图像转换为Tensor
  torch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);
  tensor_image = tensor_image.permute({2, 0, 1});
  // tensor_image = tensor_image.toType(torch::kFloat);
  tensor_image = tensor_image.div(255.);
  // tensor_image = tensor_image.sub(0.5);
  // tensor_image = tensor_image.div(0.5);

  tensor_image = tensor_image.unsqueeze(0);

  // 运行模型
  torch::Tensor output = module.forward({tensor_image}).toTensor();

  // 结果处理
  int result = output.argmax().item<int>();
  std::cout << "The classifiction index is: " << result << std::endl;
  return 0;
}

Step 6:构建运行验证

我们先来写一下
CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)

find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)

add_executable(resnet50  test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")

set_property(TARGET resnet50  PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

现在我们的工程目录
demo

下有以下文件:

CMakeLists.txt  export_jit_model.py  kitten.jpg  libtorch  pytorch_test.py  resnet50_jit.pth  test_model.cpp

然后开始用 CMake 构建工程:

mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make

整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50

在工程目录
demo

下。

接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:

./build/resnet50 resnet50_jit.pth kitten.jpg

输出:

The classifiction index is: 282

运行成功并且结果正确。

Ref:

https://www.jianshu.com/p/7cddc09ca7a4

https://blog.csdn.net/cxx654/article/details/115916275

https://zhuanlan.zhihu.com/p/370455320

相关推荐
进击的小小学生7 分钟前
2024年第45周ETF周报
大数据·人工智能
TaoYuan__44 分钟前
机器学习【激活函数】
人工智能·机器学习
TaoYuan__1 小时前
机器学习的常用算法
人工智能·算法·机器学习
正义的彬彬侠1 小时前
协方差矩阵及其计算方法
人工智能·机器学习·协方差·协方差矩阵
致Great1 小时前
Invar-RAG:基于不变性对齐的LLM检索方法提升生成质量
人工智能·大模型·rag
华奥系科技1 小时前
智慧安防丨以科技之力,筑起防范人贩的铜墙铁壁
人工智能·科技·安全·生活
槿花Hibiscus1 小时前
C++基础:Pimpl设计模式的实现
c++·设计模式
ZPC82102 小时前
OpenCV—颜色识别
人工智能·opencv·计算机视觉
Mr.简锋2 小时前
vs2022搭建opencv开发环境
人工智能·opencv·计算机视觉
weixin_443290692 小时前
【论文阅读】InstructPix2Pix: Learning to Follow Image Editing Instructions
论文阅读·人工智能·计算机视觉