PyTorch训练,TensorRT部署的简要步骤(采用ONNX中转的方式)

1. 简述

使用PyTorch执行训练,使用TensorRT进行部署有很多种方法,比较常用的是基于INetworkDefinition进行每一层的自定义,这样一来,会反向促使研究者能够对真个网络的细节有更深的理解。

另一种相对简便的方式就是通过ONNX中间转换的形式。本文主要针对该途径进行简单的脉络阐述。

2. 导出ONNX

如果使用的是PyTorch训练框架,可采用其自带的ONNX导出API。

复制代码
torch.onnx.export()

3. 生成推理引擎

使用TensorRT自带的转换工具trtexec执行ONNX到推理引擎的转换工作。

4. 如何确定哪些OP是不被TRT支持的

执行ONNX到TensorRT推理引擎的转换工作时,难免遇到一些不支持的OP,此时可以通过日志等查看,推荐在转换时添加--verbose以获得更多的过程信息。

复制代码
trtexec --onnx=<model_path>.onnx --verbose

5. 如何处理不被支持的OP

(1)修改模型

如果可能,尝试修改模型架构,使用TensorRT支持的OP代替不支持的OP。

(2)自定义层

TensorRT允许你创建自定义层(也称为插件)来实现那些原生不支持的OP。这需要一些额外的编程工作,但是可以提供一个解决方案来支持特定的操作。

(3)求助、等待支持

随着TensorRT的不断更新和发展,更多的OP将会得到支持。如果可能的话,你可以等待未来版本中对这些OP的支持

6. 如何自定义,并导入

(1)实现自定义插件

假设你需要实现一个自定义ReLU激活函数作为插件。首先,你需要实现插件类,这里给出一个非常基本的示例框架(不包含完整实现细节):

复制代码
#include "NvInfer.h"


class CustomReLUPlugin: public nvinfer1::IPluginV2IOExt {

public:

    // 构造函数、析构函数

    CustomReLUPlugin();

    ~CustomReLUPlugin();

    // 实现基类中的虚函数

    int getNbOutputs() const override;

    nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override;

    int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;

    // 序列化、反序列化等其他必要的方法

    // ...

    // 插件的克隆方法

    IPluginV2IOExt* clone() const override;

    // 插件类型和版本,非常重要,确保与模型中的OP名称对应

    const char* getPluginType() const override;

    const char* getPluginVersion() const override;

};

// 实现插件的创建函数,TensorRT在解析模型时会调用此函数来创建插件实例

extern "C" nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) {

    return new CustomReLUPlugin();

}

(2)注意事项

要确保ONNX模型中的特定操作(OP)能够与自定义插件正确关联,你需要在模型转换或创建时指定对应的插件。这通常涉及以下几点:

插件类型和版本:自定义插件类中的getPluginTypegetPluginVersion方法返回的字符串应与你在模型中指定的自定义OP名称相匹配。

模型转换工具:在某些情况下,可能需要使用如onnx2trt这样的工具将ONNX模型转换为TensorRT模型,并在此过程中指明哪些OP应当使用哪些插件。

对于直接使用trtexec的情况,确保ONNX模型中引用的自定义OP名称与插件的getPluginType方法返回的名称一致。这样,当trtexec解析模型时,它会自动查找并使用相应名称的插件处理指定的OP。

(3)编译自定义插件

利用如下编译指令,将上述自定义插件编译成共享库(.so文件)。如何,编译后得到CustomReLUPlugin.so。需要注意的是,编译时,需要链接TensorRT库。

复制代码
g++ -std=c++11 -shared -o CustomReLUPlugin.so CustomReLUPlugin.cpp -I<path_to_tensorrt_include> -L<path_to_tensorrt_lib> -lnvinfer

(4)调用自定义插件

使用trtexec转换ONNX模型时,通过"--plugins"指定并调用自定义插件共享库。

复制代码
trtexec --onnx=my_model.onnx --plugins=CustomReLUPlugin.so

(5)链接多个自定义插件

当有多个自定义插件需要链接时,可通过多次调用"--plugins"引用,也可以通过一个"--plugins"后面跟多个共享库,共享库之间由逗号隔开。

复制代码
trtexec --onnx=my_model.onnx --plugins=CustomPluginA.so --plugins=CustomPluginB.so

trtexec --onnx=my_model.onnx --plugins=CustomPluginA.so,CustomPluginB.so

7. 部署、执行推理

通过trtexec生成的推理引擎文件一般是.plan或.engine文件后缀(也有其他自定义后缀)。此时可以通过TensorRT的C++接口进行加载和推理。

(1)包含必要的头文件

复制代码
#include <NvInfer.h>

#include <NvInferRuntime.h>

#include <NvInferRuntimeCommon.h>

(2)创建运行时和引擎

使用运行时接口"IRuntime"来反序列化引擎文件,并创建推理引擎"ICudaEngine"。

复制代码
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(gLogger);

assert(runtime != nullptr);


std::string engineFilePath = "path/to/your/engine.plan";

std::ifstream engineFile(engineFilePath, std::ios::binary);

if (!engineFile) {

    std::cerr << "Failed to open engine file: " << engineFilePath << std::endl;

    return;

}

engineFile.seekg(0, engineFile.end);

long int fsize = engineFile.tellg();

engineFile.seekg(0, engineFile.beg);

std::vector<char> engineData(fsize);

engineFile.read(engineData.data(), fsize);

if (!engineFile) {

    std::cerr << "Failed to read engine file: " << engineFilePath << std::endl;

    return;

}

nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);

assert(engine != nullptr);

(3)创建执行上下文

创建一个执行上下文"IexecutionContext",用于管理治理的执行。

复制代码
nvinfer1::IExecutionContext* context = engine->createExecutionContext();

assert(context != nullptr);

(4)准备输入输出缓冲区

在GPU上为输入和输出分配缓冲区。

复制代码
void* buffers[engine->getNbBindings()];

for (int i = 0; i < engine->getNbBindings(); ++i) {

    auto bindingSize = getSizeByDataType(engine->getBindingDataType(i)) * volume(engine->getBindingDimensions(i));

    cudaMalloc(&buffers[i], bindingSize);

}

(5)执行推理

使用执行上下文执行推理任务。

复制代码
context->executeV2(buffers);

(6)处理输出

推理完成后,从输出缓冲区读取网络推理结果,并根据实际情况,执行后处理、解析等操作。

(7)释放资源

整个进程执行结束,释放所有资源。

复制代码
for (int i = 0; i < engine->getNbBindings(); ++i) {

    cudaFree(buffers[i]);

}

context->destroy();

engine->destroy();

runtime->destroy();
相关推荐
AI小白的Python之路8 分钟前
机器学习-集成学习
人工智能·机器学习·集成学习
小和尚同志12 分钟前
10k star!各大 AI 应用系统提示词集合
人工智能·开源·aigc
刘媚-海外13 分钟前
Go语言开发AI应用
开发语言·人工智能·golang·go
Blossom.11825 分钟前
从“能写”到“能干活”:大模型工具调用(Function-Calling)的工程化落地指南
数据库·人工智能·python·深度学习·机器学习·计算机视觉·oracle
Memene摸鱼日报1 小时前
「Memene 摸鱼日报 2025.9.12」前OpenAI CTO 公司发布首篇技术博客,Qwen-Next 80B 发布,Kimi 开源轻量级中间件
人工智能·agi
飞哥数智坊1 小时前
CodeBuddy CLI 实测:比 Claude Code 稚嫩,但我感觉值得期待
人工智能·ai编程
电商软件开发 小银1 小时前
本地生活服务平台创新模式观察:积分体系如何重塑消费生态?
大数据·人工智能·数字化转型·私域运营·消费者心理学
扬帆起航131 小时前
亚马逊新品推广破局指南:从手动试错到智能闭环的系统化路径
大数据·数据库·人工智能
小王爱学人工智能1 小时前
利用OpenCV进行指纹识别的案例
人工智能·opencv·计算机视觉
代码AI弗森1 小时前
DPO 深度解析:从公式到工程,从偏好数据到可复用训练管线
人工智能