PyTorch deep learning model inference and deployment: PyTorch/ONNX/TensorRT model conversion, with Python and C++ deployment

Table of Contents

[1. Inference with PyTorch](#1. Inference with PyTorch)

[2. Inference with ONNX](#2. Inference with ONNX)

[2.1 Converting PyTorch to ONNX](#2.1 Converting PyTorch to ONNX)

[2.2 ONNX inference](#2.2 ONNX inference)

[3. Inference with TensorRT (Python environment)](#3. Inference with TensorRT (Python environment))

[3.1 Converting ONNX to an engine file](#3.1 Converting ONNX to an engine file)

[3.2 TensorRT inference](#3.2 TensorRT inference)

[4. Inference with TensorRT (C++ environment)](#4. Inference with TensorRT (C++ environment))

[5. Inference with torch2trt (Python environment)](#5. Inference with torch2trt (Python environment))


With PyTorch it is easy to build, train, and save deep learning models. Once training is finished, how to deploy the model and run inference is the focus of this article. Using a ResNet model as the example, we walk through the deployment and inference methods below and compare their inference time and accuracy:

1) Inference with PyTorch (Python environment)

2) Inference with ONNX (Python environment)

3) Inference with TensorRT (Python environment)

4) Inference with TensorRT (C++ environment)

5) Inference with torch2trt (Python environment)


1. Inference with PyTorch

The first inference run includes one-time initialization and varies a lot from run to run, so when comparing inference times, look at the subsequent runs rather than the first.

'''
Author:   kk
Date:     2025.1.20
version:  v0.1
Function: resnet pytorch model inference
'''

import time
import json
from PIL import Image
import torch
from torchvision import models, transforms


# Download the ImageNet class labels (https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json)
with open("imagenet_labels.json", "r") as f:
    labels = json.load(f)

# choose device
device = torch.device("cuda:0")

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Use resnet152 (the weights= argument is the torchvision 0.13+ API)
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
model.to(device)
model.eval()

# Load the image
img = Image.open("cat.jpg")
img_t = transform(img)
img_t = img_t.unsqueeze(0).to(device)

# inference
with torch.no_grad():  # disable gradient tracking
    for _ in range(10):
        start_time = time.time()
        outputs = model(img_t)
        end_time = time.time()
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)  # convert logits to probabilities
        top_prob, top_catid = torch.topk(probabilities, 5)
        print(f"PyTorch inference time: {(end_time - start_time) * 1000:.2f} ms")
    for i in range(top_prob.size(0)):
        print(f"{labels[top_catid[i]]}: {top_prob[i].item():.4f}")


# Export to ONNX (a standalone export script is shown in Section 2.1)
torch.onnx.export(model, img_t, "resnet152.onnx", verbose=False, opset_version=11, input_names=["input0"], output_names=["output0"])

Output:

PyTorch inference time: 140.22 ms
PyTorch inference time: 12.74 ms
PyTorch inference time: 14.25 ms
PyTorch inference time: 12.80 ms
PyTorch inference time: 13.00 ms
PyTorch inference time: 12.51 ms
PyTorch inference time: 12.12 ms
PyTorch inference time: 12.14 ms
PyTorch inference time: 12.26 ms
PyTorch inference time: 12.17 ms
tabby cat: 0.4985
tiger cat: 0.2573
Egyptian Mau: 0.2003
lynx: 0.0281
remote control: 0.0016
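
One caveat about this timing method: CUDA kernels launch asynchronously, so wrapping model(img_t) with time.time() can under-report the actual GPU work. A more robust sketch (our addition, not part of the original benchmark) uses CUDA events, reusing model and img_t from the script above:

# Time the forward pass with CUDA events, which measure on-device time
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    starter.record()
    outputs = model(img_t)
    ender.record()
torch.cuda.synchronize()  # wait until the recorded events have completed
print(f"CUDA event time: {starter.elapsed_time(ender):.2f} ms")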

2. Inference with ONNX

ONNX (Open Neural Network Exchange) is an open format for neural networks designed to make models interoperable across deep learning frameworks. It provides a standardized model representation, allowing models to be converted and deployed across different tools and platforms.

2.1 Converting PyTorch to ONNX
'''
Author:   kk
Date:     2025.1.20
version:  v0.1
Function: save pytorch model to onnx
'''

from torchvision import models
import torch

# Use resnet152 (the weights= argument is the torchvision 0.13+ API)
resnet_ = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)

# Model conversion; see the documentation for the full parameter list
input_shape = (1, 3, 224, 224)
dummy_input = torch.randn(input_shape)
torch.onnx.export(resnet_, dummy_input, "resnet152.onnx", verbose=True, opset_version=11, input_names=["input0"], output_names=["output0"])
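
After exporting, it is worth validating the file before using it. A minimal sketch with the onnx package (assuming it is installed; this check is our addition, not part of the original workflow):

import onnx

# Structural check of the exported graph; raises an exception if it is malformed
onnx_model = onnx.load("resnet152.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model check passed")
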
2.2 ONNX inference
'''
Author:   kk
Date:     2025.1.20
version:  v0.1
Function: resnet onnx model inference
'''

import time
from torchvision import transforms
import torch
from PIL import Image
import onnxruntime as ort
import json


# Download the ImageNet class labels (https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json)
with open("imagenet_labels.json", "r") as f:
    labels = json.load(f)

# Build the execution providers (GPU first, CPU fallback)
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
        'gpu_mem_limit': 2 * 1024 * 1024 * 1024,
        'cudnn_conv_algo_search': 'EXHAUSTIVE',
        'do_copy_in_default_stream': True,
    }),
    'CPUExecutionProvider',
]

# Load the model
ort_session = ort.InferenceSession("resnet152.onnx", providers=providers)

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the image
img = Image.open("cat.jpg")
img_t = transform(img)
img_numpy = img_t.numpy()[None, :]  # add batch dimension

# inference
input_name = ort_session.get_inputs()[0].name  # look up the input name once, outside the timed loop
for i in range(10):
    start_time = time.time()
    outputs = ort_session.run(None, {input_name: img_numpy})[0]
    end_time = time.time()
    print(f"ONNX inference time: {(end_time - start_time) * 1000:.2f} ms")

outputs = torch.from_numpy(outputs[0])

probabilities = torch.nn.functional.softmax(outputs, dim=0)  # convert logits to probabilities
top_prob, top_catid = torch.topk(probabilities, 5)
for i in range(top_prob.size(0)):
    print(f"{labels[top_catid[i]]}: {top_prob[i].item():.4f}")

Output:

ONNX inference time: 2161.68 ms
ONNX inference time: 14.24 ms
ONNX inference time: 10.43 ms
ONNX inference time: 10.50 ms
ONNX inference time: 10.25 ms
ONNX inference time: 10.59 ms
ONNX inference time: 10.50 ms
ONNX inference time: 10.54 ms
ONNX inference time: 10.59 ms
ONNX inference time: 10.41 ms
tabby cat: 0.4984
tiger cat: 0.2573
Egyptian Mau: 0.2002
lynx: 0.0280
remote control: 0.0016

Compared with running inference directly in PyTorch, the speed improves slightly.
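
Beyond timing, it is worth confirming that the exported model agrees numerically with PyTorch. A minimal self-contained sketch (our addition; it re-runs both backends on a random input on CPU, assuming resnet152.onnx from Section 2.1 exists):

import numpy as np
import onnxruntime as ort
import torch
from torchvision import models

model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    pt_logits = model(x).numpy()

sess = ort.InferenceSession("resnet152.onnx", providers=["CPUExecutionProvider"])
onnx_logits = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})[0]

# An FP32 export should match PyTorch to tight tolerances
np.testing.assert_allclose(pt_logits, onnx_logits, rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match")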

3. Inference with TensorRT (Python environment)

TensorRT is a high-performance deep learning inference library from NVIDIA, designed specifically for NVIDIA GPUs. By optimizing the inference pass it significantly improves throughput and latency, making it an important tool for deploying deep learning applications. Its main strengths:

  • High-performance inference: kernel fusion, precision calibration (FP16/INT8), dynamic tensors, and efficient memory management speed up model execution.
  • Multi-precision support: FP32, FP16, and INT8; FP16 and INT8 markedly accelerate inference and reduce GPU memory usage.
  • Dynamic shape support: handles variable inputs such as non-fixed image resolutions or batch sizes (see the export sketch after this list).
  • Plugin mechanism: custom operators can be implemented for models with special requirements.
  • Cross-platform support: runs on servers, edge devices, and embedded platforms (e.g. NVIDIA Jetson).
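
To use dynamic shapes in practice, the ONNX export has to mark which dimensions are dynamic (the engine then also needs an optimization profile at build time). A minimal export sketch with a dynamic batch dimension; the axis name "batch" is our own choice:

import torch
from torchvision import models

model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1).eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, dummy, "resnet152_dynamic.onnx",
    opset_version=11,
    input_names=["input0"], output_names=["output0"],
    dynamic_axes={"input0": {0: "batch"}, "output0": {0: "batch"}},  # mark batch as dynamic
)
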
3.1 Converting ONNX to an engine file

To run inference with TensorRT, the ONNX model must first be converted into an engine file. trtexec is a command-line tool shipped with TensorRT for quickly testing, optimizing, and deploying models; it can run ONNX models or TensorRT engine files directly and reports inference performance metrics. The default precision is FP32; FP16 or INT8 can be selected instead, with INT8 requiring a calibration file.

trtexec --onnx=resnet152.onnx --saveEngine=resnet152_engine.trt
trtexec --onnx=resnet152.onnx --saveEngine=resnet152_fp16_engine.trt --fp16
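
For INT8, trtexec additionally expects a calibration cache produced beforehand by a calibration run; a sketch (the cache file name is a placeholder):

trtexec --onnx=resnet152.onnx --saveEngine=resnet152_int8_engine.trt --int8 --calib=calibration.cache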

Besides trtexec, the TensorRT API can be used from Python or C++ for finer control over the conversion and optimization process, which suits scenarios that need custom settings or operations.

'''
Author:   kk
Date:     2025.1.21
version:  v0.1
Function: save tensorrt engine file from onnx
'''
import tensorrt as trt


onnx_model_path = "resnet152.onnx"

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

with open(onnx_model_path, "rb") as model:
    if not parser.parse(model.read()):
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        raise RuntimeError("ONNX parsing failed")

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # max workspace size (1 GB); see the note below for newer TensorRT versions

if builder.platform_has_fast_fp16:
    print("Enabling FP16 mode")
    config.set_flag(trt.BuilderFlag.FP16)
else:
    print("GPU has no fast FP16 support; using FP32 mode")

serialized_engine = builder.build_serialized_network(network, config)


engine_file_path = "resnet152_from_program.engine"

with open(engine_file_path, "wb") as f:
    f.write(serialized_engine)

print(f"TensorRT engine saved to {engine_file_path}")
3.2 TensorRT inference
'''
Author:   kk
Date:     2025.1.20
version:  v0.1
Function: resnet tensorrt inference
'''

import numpy as np
import time
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes PyCUDA automatically
from PIL import Image
import torchvision.transforms as transforms
import json

# Download the ImageNet class labels (https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json)
with open("imagenet_labels.json", "r") as f:
    labels = json.load(f)

# Preprocessing for the input image
def preprocess_image(image_path):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # note: direct resize, no center crop as in the PyTorch script
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0)  # add batch dimension
    return input_tensor.numpy()

# Load a TensorRT engine file
def load_engine(engine_file_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

# Run inference with TensorRT (bindings API, TensorRT < 10)
def infer_with_tensorrt(engine, input_data):
    # Create the execution context
    context = engine.create_execution_context()

    # Look up the input and output bindings
    input_index = engine.get_binding_index("input0")
    output_index = engine.get_binding_index("output0")
    input_shape = engine.get_binding_shape(input_index)
    output_shape = engine.get_binding_shape(output_index)

    # Validate the input size
    assert input_data.shape[1:] == tuple(input_shape[1:]), f"input shape mismatch: {input_data.shape[1:]} != {tuple(input_shape[1:])}"

    # Allocate device buffers
    input_memory = cuda.mem_alloc(input_data.nbytes)
    output_memory = cuda.mem_alloc(int(np.prod(output_shape)) * input_data.dtype.itemsize)

    # Create a CUDA stream
    stream = cuda.Stream()

    # Copy the input data to device memory
    cuda.memcpy_htod_async(input_memory, input_data, stream)

    # Inference
    bindings = [int(input_memory), int(output_memory)]

    for i in range(10):

        start_time = time.time()
        context.execute_async_v2(bindings, stream_handle=stream.handle)
        stream.synchronize()  # wait for the computation to finish
        end_time = time.time()
        print(f"TensorRT inference time: {(end_time - start_time) * 1000:.2f} ms")

        # Copy the output back from device memory
        output_data = np.empty(output_shape, dtype=np.float32)
        cuda.memcpy_dtoh_async(output_data, output_memory, stream)
        stream.synchronize()

    return output_data


def main():
    engine_path = "resnet152_engine.trt"
    image_path = "cat.jpg"

    # Load the TensorRT engine
    engine = load_engine(engine_path)

    # Preprocess the input image
    input_data = preprocess_image(image_path)

    # Run inference and parse the result
    output = infer_with_tensorrt(engine, input_data)

    def softmax(logits):
        exp_logits = np.exp(logits - np.max(logits))  # subtract the max for numerical stability
        return exp_logits / np.sum(exp_logits)

    probabilities = softmax(output[0])

    # Print the top-5 classes
    top_indices = probabilities.argsort()[-5:][::-1]
    print("Predictions:")
    for idx in top_indices:
        print(f"{labels[idx]}: {probabilities[idx]:.4f}")

if __name__ == "__main__":
    main()
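
The script above uses the bindings API (get_binding_index, get_binding_shape, execute_async_v2), which was removed in TensorRT 10. A minimal sketch of the same steps with the named I/O tensor API (available from TensorRT 8.5), reusing engine, input_memory, output_memory, and stream from above:

# Named-tensor I/O (TensorRT >= 8.5); tensor names replace binding indices
context = engine.create_execution_context()
input_shape = engine.get_tensor_shape("input0")
output_shape = engine.get_tensor_shape("output0")

context.set_tensor_address("input0", int(input_memory))
context.set_tensor_address("output0", int(output_memory))
context.execute_async_v3(stream_handle=stream.handle)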

FP32 results:

TensorRT inference time: 46.79 ms
TensorRT inference time: 4.53 ms
TensorRT inference time: 4.53 ms
TensorRT inference time: 4.54 ms
TensorRT inference time: 4.55 ms
TensorRT inference time: 4.52 ms
TensorRT inference time: 4.56 ms
TensorRT inference time: 4.53 ms
TensorRT inference time: 4.52 ms
TensorRT inference time: 4.56 ms
Predictions:
tabby cat: 0.4988
tiger cat: 0.2572
Egyptian Mau: 0.2000
lynx: 0.0281
remote control: 0.0016

FP16 results:

TensorRT inference time: 42.77 ms
TensorRT inference time: 1.68 ms
TensorRT inference time: 1.66 ms
TensorRT inference time: 1.66 ms
TensorRT inference time: 1.67 ms
TensorRT inference time: 1.71 ms
TensorRT inference time: 1.71 ms
TensorRT inference time: 1.72 ms
TensorRT inference time: 1.68 ms
TensorRT inference time: 1.66 ms
Predictions:
tabby cat: 0.5110
tiger cat: 0.2510
Egyptian Mau: 0.1955
lynx: 0.0271
remote control: 0.001

As the numbers show, TensorRT inference is considerably faster than PyTorch: roughly 2-3x for FP32 and 6-8x for FP16. The FP16 predictions differ slightly from the earlier results.

4. Inference with TensorRT (C++ environment)

/*------------------------
Author:    kk
Date:      2025.1.21
Version:   v0.1
function:  tensorrt inference with c++

 -------------------------*/

#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>  // std::sort, std::max_element
#include <cmath>      // std::exp
#include <opencv2/opencv.hpp>

using namespace nvinfer1;

// Logger for TensorRT info/warning/errors
class Logger : public ILogger {
public:
    void log(Severity severity, const char* msg) noexcept override {
        // Suppress info-level messages
        if (severity != Severity::kINFO)
            std::cout << "[TensorRT] " << msg << std::endl;
    }
};

static Logger logger;

// image preprocessing
void preprocessImage(const std::string& imagePath, float* inputData, int inputH, int inputW) {
    cv::Mat img = cv::imread(imagePath, cv::IMREAD_COLOR);
    if (img.empty()) {
        std::cerr << "Failed to load image: " << imagePath << std::endl;
        exit(1);
    }

    // 1. Resize to the network input size (cv::Size takes width first, then height)
    cv::Mat resized_image;
    cv::resize(img, resized_image, cv::Size(inputW, inputH));

    // 2. Convert to RGB (OpenCV loads images as BGR)
    cv::Mat rgb_image;
    cv::cvtColor(resized_image, rgb_image, cv::COLOR_BGR2RGB);

    // 3. Convert to float and scale to [0, 1]
    cv::Mat float_image;
    rgb_image.convertTo(float_image, CV_32F, 1.0 / 255.0);

    // 4. Normalize per channel with the ImageNet mean and std
    float mean[] = {0.485f, 0.456f, 0.406f};
    float stddev[] = {0.229f, 0.224f, 0.225f};

    // Split the color channels
    std::vector<cv::Mat> channels(3);
    cv::split(float_image, channels);

    // Normalize each channel
    for (int i = 0; i < 3; i++) {
        channels[i] = (channels[i] - mean[i]) / stddev[i];
    }

    // Flatten into the CHW layout expected by the engine
    int count = 0;
    for (size_t c = 0; c < channels.size(); ++c) {
        for (int h = 0; h < channels[c].rows; ++h) {
            for (int w = 0; w < channels[c].cols; ++w) {
                inputData[count] = channels[c].at<float>(h, w);
                count += 1;
            }
        }
    }
}

// softmax (numerically stable: subtract the max before exponentiating)
std::vector<float> softmax(const std::vector<float>& input) {
    std::vector<float> output(input.size());
    float max_val = *std::max_element(input.begin(), input.end());
    float sum_exp = 0.0f;

    // Exponentiate each shifted element and accumulate the sum
    for (size_t i = 0; i < input.size(); ++i) {
        output[i] = std::exp(input[i] - max_val);
        sum_exp += output[i];
    }

    // Normalize
    for (size_t i = 0; i < input.size(); ++i) {
        output[i] /= sum_exp;
    }

    return output;
}


// result process
void postprocessResults(const float* output, int outputSize, const std::vector<std::string>& labels) {
    
    std::vector<float> out_vec(output, output + outputSize);
    std::vector<float> output_softmax = softmax(out_vec);
    
    std::vector<std::pair<float, int>> scores;
    for (int i = 0; i < outputSize; ++i) {
        scores.emplace_back(output_softmax[i], i);
    }
    std::sort(scores.begin(), scores.end(), std::greater<>());

    // Print top-5 predictions
    std::cout << "Top-5 Predictions:" << std::endl;
    for (int i = 0; i < 5; ++i) {
        std::cout << labels[scores[i].second] << ": " << scores[i].first << std::endl;
    }
}

// load ImageNet labels; the JSON file is read line by line, so each label keeps
// its surrounding quotes and trailing comma (visible in the output below)
std::vector<std::string> loadLabels(const std::string& labelFile) {
    std::ifstream file(labelFile);
    std::vector<std::string> labels;
    std::string line;
    while (std::getline(file, line)) {
        labels.push_back(line);
    }
    return labels;
}

// inference with tensorrt
void runInference(const std::string& enginePath, const std::string& imagePath, const std::string& labelPath) {

    // Load engine
    std::ifstream engineFile(enginePath, std::ios::binary);
    if (!engineFile) {
        std::cerr << "Failed to open engine file: " << enginePath << std::endl;
        exit(1);
    }
    engineFile.seekg(0, std::ios::end);
    size_t engineSize = engineFile.tellg();
    engineFile.seekg(0, std::ios::beg);
    std::vector<char> engineData(engineSize);
    engineFile.read(engineData.data(), engineSize);

    IRuntime* runtime = createInferRuntime(logger);
    ICudaEngine* engine = runtime->deserializeCudaEngine(engineData.data(), engineSize); // TensorRT < 7.0 takes an extra nullptr argument
    IExecutionContext* context = engine->createExecutionContext();

    // Prepare input and output buffers
    const int inputIndex = engine->getBindingIndex("input0");
    const int outputIndex = engine->getBindingIndex("output0");

    int inputH = engine->getBindingDimensions(inputIndex).d[2];
    int inputW = engine->getBindingDimensions(inputIndex).d[3];
    int outputSize = engine->getBindingDimensions(outputIndex).d[1];

    void* buffers[2];
    cudaMalloc(&buffers[inputIndex], 3 * inputH * inputW * sizeof(float));
    cudaMalloc(&buffers[outputIndex], outputSize * sizeof(float));

    // Set input and output bindings (named-tensor API, TensorRT >= 8.5)
    context->setBindingDimensions(inputIndex, Dims4{1, 3, inputH, inputW});
    context->setTensorAddress("input0", buffers[inputIndex]);
    context->setTensorAddress("output0", buffers[outputIndex]);

    // Create CUDA stream
    cudaStream_t stream;
    cudaStreamCreate(&stream);


    for (int i = 0; i < 10; ++i)
    {

        float* inputHost = new float[3 * inputH * inputW];
        float* outputHost = new float[outputSize];

        // Preprocess image
        preprocessImage(imagePath, inputHost, inputH, inputW);
        // cudaMemcpy(buffers[inputIndex], inputHost, 3 * inputH * inputW * sizeof(float), cudaMemcpyHostToDevice);
        cudaMemcpyAsync(buffers[inputIndex], inputHost, 3 * inputH * inputW * sizeof(float), cudaMemcpyHostToDevice, stream);

        // Run inference, timing the GPU work with CUDA events
        float inferenceTime;
        cudaEvent_t start, stop;

        cudaEventCreate(&start);
        cudaEventCreate(&stop);

        cudaEventRecord(start, stream);
        // context->enqueueV2(buffers, 0, nullptr);
        context->enqueueV3(stream);

        cudaEventRecord(stop, stream);
        cudaEventSynchronize(stop);

        cudaEventElapsedTime(&inferenceTime, start, stop);
        std::cout << "Inference Time: " << inferenceTime << " ms" << std::endl;
        cudaEventDestroy(start);
        cudaEventDestroy(stop);

        // cudaMemcpy(outputHost, buffers[outputIndex], outputSize * sizeof(float), cudaMemcpyDeviceToHost);
        cudaMemcpyAsync(outputHost, buffers[outputIndex], outputSize * sizeof(float), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);

        // Postprocess results on the last iteration
        if (i == 9)
        {
            auto labels = loadLabels(labelPath);
            postprocessResults(outputHost, outputSize, labels);
        }

        delete[] inputHost;
        delete[] outputHost;
    }

    // Cleanup
    cudaStreamDestroy(stream);
    cudaFree(buffers[inputIndex]);
    cudaFree(buffers[outputIndex]);
    delete context;
    delete engine;
    delete runtime;
}


// main function
int main() {
    const std::string enginePath = "/home/wangyl/program/pylearn/resnet_onnx_trt_test/resnet152_engine.trt";
    const std::string imagePath = "/home/wangyl/program/pylearn/resnet_onnx_trt_test/cat.jpg";
    const std::string labelPath = "/home/wangyl/program/pylearn/resnet_onnx_trt_test/imagenet_labels.json";

    runInference(enginePath, imagePath, labelPath);
    return 0;
}

CMakeLists.txt:

cmake_minimum_required(VERSION 3.10)
project(TensorRT_Inference)

set(CMAKE_BUILD_TYPE Debug)  # Debug build for development; prefer Release when benchmarking host-side code (GPU times from CUDA events are largely unaffected)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -O0")

find_package(CUDA REQUIRED)
find_package(OpenCV REQUIRED)


include_directories(${CUDA_INCLUDE_DIRS} /usr/include/opencv4)

add_executable(resnet50_trt_inference resnet50_trt_inference.cpp)
target_link_libraries(resnet50_trt_inference ${CUDA_LIBRARIES} nvinfer cudart opencv_core opencv_imgproc opencv_highgui)

FP32 results:

Inference Time: 45.3356 ms
Inference Time: 4.49808 ms
Inference Time: 4.50038 ms
Inference Time: 4.5016 ms
Inference Time: 4.52179 ms
Inference Time: 4.49683 ms
Inference Time: 4.50618 ms
Inference Time: 4.49936 ms
Inference Time: 4.50445 ms
Inference Time: 4.49485 ms
Top-5 Predictions:
"tabby cat",: 0.517677
"Egyptian Mau",: 0.295606
"tiger cat",: 0.157792
"lynx",: 0.0208484
"Persian cat",: 0.00111315

FP16 results:

Inference Time: 40.4666 ms
Inference Time: 1.6351 ms
Inference Time: 1.6624 ms
Inference Time: 1.63142 ms
Inference Time: 1.63312 ms
Inference Time: 1.63485 ms
Inference Time: 1.63613 ms
Inference Time: 1.63539 ms
Inference Time: 1.63667 ms
Inference Time: 1.63373 ms
Top-5 Predictions:
"tabby cat",: 0.507069
"Egyptian Mau",: 0.29809
"tiger cat",: 0.165912
"lynx",: 0.0207663
"Persian cat",: 0.00112668

The timings show that inference times in C++ and Python are very close, because only the GPU compute time is measured; counting pre- and post-processing, the C++ pipeline would be faster end to end. The predictions also differ slightly from the Python results, mainly because the pre- and post-processing implementations differ slightly.

5. Inference with torch2trt (Python environment)

torch2trt is a lightweight NVIDIA tool that converts PyTorch models into TensorRT engines to optimize inference performance. It automatically maps PyTorch operations to their TensorRT equivalents and supports several optimization modes (such as FP16 and INT8) to increase speed and reduce latency, making it well suited to deep learning applications on NVIDIA GPUs.

'''
Author:   kk
Date:     2025.1.20
version:  v0.2
Function: resnet pytorch model inference and torch2trt optimization
'''

import time
import json
from PIL import Image
import torch
from torchvision import models, transforms
from torch2trt import torch2trt, TRTModule  # import torch2trt

# Download the ImageNet class labels (https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json)
with open("imagenet_labels.json", "r") as f:
    labels = json.load(f)

# choose device
device = torch.device("cuda:0")

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Use resnet152
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1).eval().to(device)

# Load the image
img = Image.open("cat.jpg")
img_t = transform(img)
img_t = img_t.unsqueeze(0).to(device)

# Convert the model with torch2trt
print("\n== Converting to TensorRT ==")
model_trt = torch2trt(model, [img_t], max_batch_size=1, fp16_mode=True)  # FP16 mode; set fp16_mode=False for the FP32 results below

# TensorRT inference
print("\n== TensorRT Inference ==")
with torch.no_grad():
    for _ in range(10):
        start_time = time.time()
        outputs_trt = model_trt(img_t)
        end_time = time.time()
        probabilities = torch.nn.functional.softmax(outputs_trt[0], dim=0)
        top_prob, top_catid = torch.topk(probabilities, 5)
        print(f"TensorRT inference time: {(end_time - start_time) * 1000:.2f} ms")
    for i in range(top_prob.size(0)):
        print(f"{labels[top_catid[i]]}: {top_prob[i].item():.4f}")

# Save the TensorRT model
print("\n== Saving TensorRT Model ==")
torch.save(model_trt.state_dict(), "resnet152_trt.pth")
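
The saved state dict can later be reloaded without repeating the conversion, using torch2trt's TRTModule (already imported above); a minimal sketch:

# Reload the serialized TensorRT module; no re-conversion needed
model_trt = TRTModule()
model_trt.load_state_dict(torch.load("resnet152_trt.pth"))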

FP32 results:

TensorRT inference time: 1.48 ms
TensorRT inference time: 1.30 ms
TensorRT inference time: 0.95 ms
TensorRT inference time: 0.75 ms
TensorRT inference time: 0.74 ms
TensorRT inference time: 0.76 ms
TensorRT inference time: 3.35 ms
TensorRT inference time: 4.00 ms
TensorRT inference time: 4.01 ms
TensorRT inference time: 4.01 ms
tabby cat: 0.4988
tiger cat: 0.2572
Egyptian Mau: 0.2000
lynx: 0.0281
remote control: 0.0016

FP16 results:

== TensorRT Inference ==
TensorRT inference time: 1.41 ms
TensorRT inference time: 1.08 ms
TensorRT inference time: 0.73 ms
TensorRT inference time: 0.56 ms
TensorRT inference time: 0.55 ms
TensorRT inference time: 0.55 ms
TensorRT inference time: 0.67 ms
TensorRT inference time: 0.66 ms
TensorRT inference time: 0.65 ms
TensorRT inference time: 0.57 ms
tabby cat: 0.5102
tiger cat: 0.2506
Egyptian Mau: 0.1967
lynx: 0.0270
remote control: 0.0015

torch2trt appears to report even lower inference times; we have not investigated the cause in detail. One likely factor is that these timings use time.time() without torch.cuda.synchronize(), so for an asynchronous TensorRT module they may largely measure launch overhead rather than the full GPU execution.
