Resnet C ++ 部署 pytorch功能测试(一)

说明

最近在研究分类模型如何部署C++,先拿Resnet50 来练一练手,文章将 分为多篇,这一篇主要验证一下pytorch 模型输出是正确的,为后续tensort RT 模型输出提供验证。

1 官方权重下载

https://download.pytorch.org/models/resnet50-19c8e357.pth

2 测试代码

python 复制代码
import torchvision.models as models
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt


def main():
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cuda:0")
    # device = torch.device("cpu")
    model = models.resnet50()
    model = model.cuda()
    # num_classes = 10  # 修改为你自己的类别数量
    # model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    # model.load_state_dict(torch.load('params.pth', map_location=device)) #
    model.load_state_dict(torch.load('resnet50-19c8e357.pth'))
    model.eval()
    img_path = 'data/fox.png'
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        print(predict.shape)
        print(predict[270:280]) # 打印几个softmax之后的输出
        predict_cla = torch.argmax(predict).numpy() # 找出最大的序号
        print(predict_cla) # 打印出类别


if __name__ == '__main__':
    main()

3 测试图片

这是我找的几张图片


4 测试结果

跑一张狐狸的图片,输出序号277

查看预训练模型对应的类别编号

相关推荐
羊八井14 分钟前
使用 Earth2Studio 和 AI 模型进行全球天气预测:太阳辐照
pytorch·python·nvidia
向左转, 向右走ˉ24 分钟前
PyTorch随机擦除:提升模型抗遮挡能力
人工智能·pytorch·python·深度学习
大白的编程日记.1 小时前
【计算机基础理论知识】C++篇(二)
开发语言·c++·学习
网小鱼的学习笔记1 小时前
python中MongoDB操作实践:查询文档、批量插入文档、更新文档、删除文档
开发语言·python·mongodb
C语言小火车1 小时前
野指针:C/C++内存管理的“幽灵陷阱”与系统化规避策略
c语言·c++·学习·指针
Q_Q5110082851 小时前
python的保险业务管理与数据分析系统
开发语言·spring boot·python·django·flask·node.js·php
亮1111 小时前
Maven 编译过程中发生了 Java Heap Space 内存溢出(OutOfMemoryError)
java·开发语言·maven
凤年徐1 小时前
【数据结构】时间复杂度和空间复杂度
c语言·数据结构·c++·笔记·算法
Chef_Chen1 小时前
从0开始学习R语言--Day40--Kruskal-Wallis检验
开发语言·学习·r语言
鑫宇吖1 小时前
Polyspace作为MISRA-C合规性检查工具,其检查规则会根据目标C语言标准(C90或C99)动态调整限值要求。
c语言·嵌入式·c99·c90·polyspace·misra-c合规性检查