PyTorch之Torch Script的简单使用

一、参考资料

TorchScript 简介
Torch Script
Loading a TorchScript Model in C++
TorchScript 解读(一):初识 TorchScript
libtorch教程(一)开发环境搭建:VS+libtorch和Qt+libtorch

二、Torch Script模型格式

1. Torch Script简介

Torch Script 是一种序列化和优化 PyTorch 模型的格式,在优化过程中,一个 torch.nn.Module 模型会被转换成 Torch Script 的 torch.jit.ScriptModule 模型。通常,TorchScript 被当成一种中间表示使用。

Torch Script 的主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。

模型格式 支持语言 适用场景
PyTorch model Python 模型训练
Torch Script C++ 模型推理,模型部署

2. 生成Torch Script模型

如何将PyTorch model格式转换为Torch Script,有两种方式: torch.jit.tracetorch.jit.script

As its name suggests, the primary interface to PyTorch is the Python programming language. While Python is a suitable and preferred language for many scenarios requiring dynamism and ease of iteration, there are equally many situations where precisely these properties of Python are unfavorable. One environment in which the latter often applies is production -- the land of low latencies and strict deployment requirements. For production scenarios, C++ is very often the language of choice, even if only to bind it into another language like Java, Rust or Go. The following paragraphs will outline the path PyTorch provides to go from an existing Python model to a serialized representation that can be loaded and executed purely from C++, with no dependency on Python.

A PyTorch model's journey from Python to C++ is enabled by Torch Script, a representation of a PyTorch model that can be understood, compiled and serialized by the Torch Script compiler. If you are starting out from an existing PyTorch model written in the vanilla "eager" API, you must first convert your model to Torch Script.

There exist two ways of converting a PyTorch model to Torch Script. The first is known as tracing, a mechanism in which the structure of the model is captured by evaluating it once using example inputs, and recording the flow of those inputs through the model. This is suitable for models that make limited use of control flow. The second approach is to add explicit annotations to your model that inform the Torch Script compiler that it may directly parse and compile your model code, subject to the constraints imposed by the Torch Script language.

2.1 trace跟踪模式

Converting to Torch Script via Tracing
How to convert your PyTorch model to TorchScript

功能:将不带控制流的模型转换为 Torch Script,并生成一个 ScriptModule 对象。

函数原型:torch.jit.trace

所谓 trace 指的是进行一次模型推理,在推理的过程中记录所有经过的计算,将这些记录整合成计算图,即模型的静态图。

trace跟踪模式的缺点是:无法识别出模型中的控制流(如循环)。

To convert a PyTorch model to Torch Script via tracing, you must pass an instance of your model along with an example input to the torch.jit.trace function. This will produce a torch.jit.ScriptModule object with the trace of your model evaluation embedded in the module's forward method.

python 复制代码
import torch
import torchvision


# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval model
model.eval()

# An example input you would normally provide to your model's forward() method.
dummy_input = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, dummy_input)

# Save the TorchScript model
traced_script_module.save("traced_resnet18_model.pt")

output = traced_script_module(torch.ones(1, 3, 224, 224))

# IR中间表示
print(traced_script_module.graph)
print(traced_script_module.code)

# 调用traced_cell会产生与 Python 模块相同的结果
print(model(x, h))
print(traced_script_module(x, h))

2.2 script记录模式(带控制流)

功能:将带控制流的模型转换为 Torch Script,并生成一个 ScriptModule 对象。

函数原型:torch.jit.script

script记录模式,通过解析模型来正确记录所有的控制流。script记录模式直接解析网络定义的 python 代码,生成抽象语法树 AST。

Because the forward method of this module uses control flow that is dependent on the input, it is not suitable for tracing. Instead, we can convert it to a ScriptModule. In order to convert the module to the ScriptModule, one needs to compile the module with torch.jit.script as follows.

python 复制代码
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

    
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

# Save the ScriptModule
sm.save("my_module_model.pt")

2.3 trace格式转换

python 复制代码
import torch
import torchvision
from unet import UNet


model = UNet(3, 2) 
model.load_state_dict(torch.load("best_weights.pth"))
model.eval()

example = torch.rand(1, 3, 320, 480) 
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

三、Loading a TorchScript Model in C++

1. C++加载 ScriptModule

创建一个简单的程序,目录结构如下所示:

bash 复制代码
example-app/
  CMakeLists.txt
  example-app.cpp

1.1 example-app.cpp

c++ 复制代码
#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }


  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";
}

The <torch/script.h> header encompasses all relevant includes from the LibTorch library necessary to run the example. Our application accepts the file path to a serialized PyTorch ScriptModule as its only command line argument and then proceeds to deserialize the module using the torch::jit::load() function, which takes this file path as input. In return we receive a torch::jit::script::Module object.

1.2 CMakeLists.txt

cmake 复制代码
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

find_package(Torch REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

1.3 编译执行

bash 复制代码
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
make

输出结果

bash 复制代码
root@4b5a67132e81:/example-app# mkdir build
root@4b5a67132e81:/example-app# cd build
root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app

1.4 执行结果

If we supply the path to the traced ResNet18 model traced_resnet_model.pt we created earlier to the resulting example-app binary, we should be rewarded with a friendly "ok". Please note, if try to run this example with my_module_model.pt you will get an error saying that your input is of an incompatible shape. my_module_model.pt expects 1D instead of 4D.

bash 复制代码
root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt
ok

2. C++推理ScriptModule

main()函数中添加模型推理的代码:

c++ 复制代码
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));

// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

编译执行

bash 复制代码
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
-0.2698 -0.0381  0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]
相关推荐
Hoper.J4 小时前
PyTorch 模型保存与加载的三种常用方式
人工智能·pytorch·python
没有余地 EliasJie5 小时前
Windows Ubuntu下搭建深度学习Pytorch训练框架与转换环境TensorRT
pytorch·windows·深度学习·ubuntu·pycharm·conda·tensorflow
被制作时长两年半的个人练习生7 小时前
【pytorch】权重为0的情况
人工智能·pytorch·深度学习
GarryLau10 小时前
使用pytorch进行迁移学习的两个步骤
pytorch·迁移学习·torchvision
醒了就刷牙16 小时前
56 门控循环单元(GRU)_by《李沐:动手学深度学习v2》pytorch版
pytorch·深度学习·gru
橙子小哥的代码世界16 小时前
【深度学习】05-RNN循环神经网络-02- RNN循环神经网络的发展历史与演化趋势/LSTM/GRU/Transformer
人工智能·pytorch·rnn·深度学习·神经网络·lstm·transformer
最近好楠啊19 小时前
Pytorch实现RNN实验
人工智能·pytorch·rnn
IFTICing19 小时前
【文献阅读】Attention Bottlenecks for Multimodal Fusion
人工智能·pytorch·python·神经网络·学习·模态融合
醒了就刷牙1 天前
67 自注意力_by《李沐:动手学深度学习v2》pytorch版
人工智能·pytorch·深度学习
涩即是Null1 天前
PyTorch实现卷积神经网络CNN
人工智能·pytorch·深度学习·神经网络·卷积神经网络