1. 简述
使用PyTorch执行训练,使用TensorRT进行部署有很多种方法,比较常用的是基于INetworkDefinition进行每一层的自定义,这样一来,会反向促使研究者能够对真个网络的细节有更深的理解。
另一种相对简便的方式就是通过ONNX中间转换的形式。本文主要针对该途径进行简单的脉络阐述。
2. 导出ONNX
如果使用的是PyTorch训练框架,可采用其自带的ONNX导出API。
torch.onnx.export()
3. 生成推理引擎
使用TensorRT自带的转换工具trtexec执行ONNX到推理引擎的转换工作。
4. 如何确定哪些OP是不被TRT支持的
执行ONNX到TensorRT推理引擎的转换工作时,难免遇到一些不支持的OP,此时可以通过日志等查看,推荐在转换时添加--verbose以获得更多的过程信息。
trtexec --onnx=<model_path>.onnx --verbose
5. 如何处理不被支持的OP
(1)修改模型
如果可能,尝试修改模型架构,使用TensorRT支持的OP代替不支持的OP。
(2)自定义层
TensorRT允许你创建自定义层(也称为插件)来实现那些原生不支持的OP。这需要一些额外的编程工作,但是可以提供一个解决方案来支持特定的操作。
(3)求助、等待支持
随着TensorRT的不断更新和发展,更多的OP将会得到支持。如果可能的话,你可以等待未来版本中对这些OP的支持
6. 如何自定义,并导入
(1)实现自定义插件
假设你需要实现一个自定义ReLU激活函数作为插件。首先,你需要实现插件类,这里给出一个非常基本的示例框架(不包含完整实现细节):
#include "NvInfer.h"
class CustomReLUPlugin: public nvinfer1::IPluginV2IOExt {
public:
// 构造函数、析构函数
CustomReLUPlugin();
~CustomReLUPlugin();
// 实现基类中的虚函数
int getNbOutputs() const override;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override;
int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;
// 序列化、反序列化等其他必要的方法
// ...
// 插件的克隆方法
IPluginV2IOExt* clone() const override;
// 插件类型和版本,非常重要,确保与模型中的OP名称对应
const char* getPluginType() const override;
const char* getPluginVersion() const override;
};
// 实现插件的创建函数,TensorRT在解析模型时会调用此函数来创建插件实例
extern "C" nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {
return new CustomReLUPlugin();
}
(2)注意事项
要确保ONNX模型中的特定操作(OP)能够与自定义插件正确关联,你需要在模型转换或创建时指定对应的插件。这通常涉及以下几点:
插件类型和版本:自定义插件类中的getPluginType 和getPluginVersion方法返回的字符串应与你在模型中指定的自定义OP名称相匹配。
模型转换工具:在某些情况下,可能需要使用如onnx2trt这样的工具将ONNX模型转换为TensorRT模型,并在此过程中指明哪些OP应当使用哪些插件。
对于直接使用trtexec的情况,确保ONNX模型中引用的自定义OP名称与插件的getPluginType方法返回的名称一致。这样,当trtexec解析模型时,它会自动查找并使用相应名称的插件处理指定的OP。
(3)编译自定义插件
利用如下编译指令,将上述自定义插件编译成共享库(.so文件)。如何,编译后得到CustomReLUPlugin.so。需要注意的是,编译时,需要链接TensorRT库。
g++ -std=c++11 -shared -o CustomReLUPlugin.so CustomReLUPlugin.cpp -I<path_to_tensorrt_include> -L<path_to_tensorrt_lib> -lnvinfer
(4)调用自定义插件
使用trtexec转换ONNX模型时,通过"--plugins"指定并调用自定义插件共享库。
trtexec --onnx=my_model.onnx --plugins=CustomReLUPlugin.so
(5)链接多个自定义插件
当有多个自定义插件需要链接时,可通过多次调用"--plugins"引用,也可以通过一个"--plugins"后面跟多个共享库,共享库之间由逗号隔开。
trtexec --onnx=my_model.onnx --plugins=CustomPluginA.so --plugins=CustomPluginB.so
trtexec --onnx=my_model.onnx --plugins=CustomPluginA.so,CustomPluginB.so
7. 部署、执行推理
通过trtexec生成的推理引擎文件一般是.plan或.engine文件后缀(也有其他自定义后缀)。此时可以通过TensorRT的C++接口进行加载和推理。
(1)包含必要的头文件
#include <NvInfer.h>
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
(2)创建运行时和引擎
使用运行时接口"IRuntime"来反序列化引擎文件,并创建推理引擎"ICudaEngine"。
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(gLogger);
assert(runtime != nullptr);
std::string engineFilePath = "path/to/your/engine.plan";
std::ifstream engineFile(engineFilePath, std::ios::binary);
if (!engineFile) {
std::cerr << "Failed to open engine file: " << engineFilePath << std::endl;
return;
}
engineFile.seekg(0, engineFile.end);
long int fsize = engineFile.tellg();
engineFile.seekg(0, engineFile.beg);
std::vector<char> engineData(fsize);
engineFile.read(engineData.data(), fsize);
if (!engineFile) {
std::cerr << "Failed to read engine file: " << engineFilePath << std::endl;
return;
}
nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);
assert(engine != nullptr);
(3)创建执行上下文
创建一个执行上下文"IexecutionContext",用于管理治理的执行。
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
assert(context != nullptr);
(4)准备输入输出缓冲区
在GPU上为输入和输出分配缓冲区。
void* buffers[engine->getNbBindings()];
for (int i = 0; i < engine->getNbBindings(); ++i) {
auto bindingSize = getSizeByDataType(engine->getBindingDataType(i)) * volume(engine->getBindingDimensions(i));
cudaMalloc(&buffers[i], bindingSize);
}
(5)执行推理
使用执行上下文执行推理任务。
context->executeV2(buffers);
(6)处理输出
推理完成后,从输出缓冲区读取网络推理结果,并根据实际情况,执行后处理、解析等操作。
(7)释放资源
整个进程执行结束,释放所有资源。
for (int i = 0; i < engine->getNbBindings(); ++i) {
cudaFree(buffers[i]);
}
context->destroy();
engine->destroy();
runtime->destroy();