libtorch的c++,加载*.pth

一、转换模型为TorchScript

前提:python只保存了参数,没存结构

要在C++中使用libtorch(PyTorch的C++接口),读取和加载通过torch.save保存的模型( torch.save(pdn.state_dict()这种方式,只保存了参数,没存结构),需要转换模型为TorchScript。在python下实现。

复制代码
def get_pdn_small(out_channels=384, padding=False):
    pad_mult = 1 if padding else 0
    return nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),
        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3,
                  padding=1 * pad_mult),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=4)
    )

def get_pdn_medium(out_channels=384, padding=False):
    pad_mult = 1 if padding else 0
    return nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),
        nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3,
                  padding=1 * pad_mult),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=4),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                  kernel_size=1)
    )

import torch

# 假设你有一个已训练的模型
model = get_pdn_small()

# 加载模型的state_dict
model.load_state_dict(torch.load('teacher_small.pth'))
model.eval()  # 设置模型为评估模式

# 将模型转化为TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save('teacher_small.pt')

二、在C++中加载TorchScript模型

在C++中,你可以使用torch::jit::load来加载.pt文件,如下所示:

复制代码
#include <torch/script.h>  // One-stop header for loading TorchScript models
#include <iostream>
#include <memory>

int main() {
    // 加载TorchScript模型
    try {
        // 加载模型
                std::shared_ptr<torch::jit::Module> model = std::make_shared<torch::jit::Module>(torch::jit::load("teacher_small.pt"));


        std::cout << "Model loaded successfully!" << std::endl;

        // 你可以在这里使用模型进行推理,比如输入一个张量
        // 例如,如果输入是一个3x224x224的图像,你需要创建一个相应的Tensor
        torch::Tensor input = torch::randn({1, 3, 224, 224});  // 示例输入
        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(input);

        // 执行模型推理
        at::Tensor output = model->forward(inputs).toTensor();
        std::cout << "Output tensor: " << output << std::endl;
    }
    catch (const c10::Error& e) {
        std::cerr << "Error loading the model: " << e.what() << std::endl;
        return -1;
    }
}
相关推荐
loui robot1 分钟前
规划与控制之局部路径规划算法local_planner
人工智能·算法·自动驾驶
玄同7655 分钟前
Llama.cpp 全实战指南:跨平台部署本地大模型的零门槛方案
人工智能·语言模型·自然语言处理·langchain·交互·llama·ollama
格林威7 分钟前
Baumer相机金属焊缝缺陷识别:提升焊接质量检测可靠性的 7 个关键技术,附 OpenCV+Halcon 实战代码!
人工智能·数码相机·opencv·算法·计算机视觉·视觉检测·堡盟相机
独处东汉15 分钟前
freertos开发空气检测仪之按键输入事件管理系统设计与实现
人工智能·stm32·单片机·嵌入式硬件·unity
你大爷的,这都没注册了15 分钟前
AI提示词,zero-shot,few-shot 概念
人工智能
AC赳赳老秦16 分钟前
DeepSeek 辅助科研项目申报:可行性报告与经费预算框架的智能化撰写指南
数据库·人工智能·科技·mongodb·ui·rabbitmq·deepseek
瑞华丽PLM24 分钟前
国产PLM软件源头厂家的AI技术应用与智能化升级
人工智能·plm·国产plm·瑞华丽plm·瑞华丽
koo36430 分钟前
pytorch深度学习笔记19
pytorch·笔记·深度学习
xixixi7777733 分钟前
基于零信任架构的通信
大数据·人工智能·架构·零信任·通信·个人隐私
玄同76536 分钟前
LangChain v1.0+ Prompt 模板完全指南:构建精准可控的大模型交互
人工智能·语言模型·自然语言处理·langchain·nlp·交互·知识图谱