PyTorch深度学习笔记(二十)(模型验证测试)

前言

到这一章节为止,依据小土堆课程的PyTorch深度学习笔记基础部分已经完结了,接下来将依据李沐动手学深度学习课程进行PyTorch深度学习笔记的进阶部分

预测图片

完整的模型验证(测试,demo)套路,利用已经训练好的模型,然后给它提供输入。

输入狗的图片,并打开

python 复制代码
image_path = "imgs/dog.png"
image = Image.open(image_path)

4通道的RGBA转为3通道的RGB图片

python 复制代码
image = image.convert("RGB")

转换图像格式并设定网络

python 复制代码
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),   
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x

GPU上训练的东西映射到CPU上

python 复制代码
model = torch.load("model/tudui_29.pth",map_location=torch.device('cpu'))

转为四维,符合网络输入需求

python 复制代码
image = torch.reshape(image,(1,3,32,32))

将模型转为测试类型

python 复制代码
model.eval()

不进行梯度计算,减少内存计算

python 复制代码
with torch.no_grad():
    output = model(image)

概率最大类别的输出

python 复制代码
print(output.argmax(1))

完整代码

python 复制代码
import torchvision
from PIL import Image
from torch import nn
import torch

image_path = "imgs/dog.png"
image = Image.open(image_path)
image = image.convert("RGB") 
print(image)

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),   
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x

model = torch.load("model/tudui_29.pth",map_location=torch.device('cpu'))
print(model)
image = torch.reshape(image,(1,3,32,32)) 
model.eval()
with torch.no_grad():
    output = model(image)
output = model(image)
print(output)
print(output.argmax(1)) 
相关推荐
Hcoco_me10 分钟前
RNN(循环神经网络)
人工智能·rnn·深度学习
踏浪无痕18 分钟前
AI 时代架构师如何有效成长?
人工智能·后端·架构
AI 智能服务19 分钟前
第6课__本地工具调用(文件操作)
服务器·人工智能·windows·php
clorisqqq37 分钟前
人工智能现代方法笔记 第1章 绪论(1/2)
人工智能·笔记
kisshuan1239638 分钟前
YOLO11-RepHGNetV2实现甘蔗田杂草与作物区域识别详解
人工智能·计算机视觉·目标跟踪
charlie11451419138 分钟前
嵌入式现代C++教程: 构造函数优化:初始化列表 vs 成员赋值
开发语言·c++·笔记·学习·嵌入式·现代c++
焦耳热科技前沿44 分钟前
北京科技大学/理化所ACS Nano:混合价态Cu₂Sb金属间化合物实现高效尿素电合成
大数据·人工智能·自动化·能源·材料工程
C+-C资深大佬1 小时前
Creo 11.0 全功能解析:多体设计 + 仿真制造,机械设计效率翻倍下载安装
人工智能
浔川python社1 小时前
【维护期间重要提醒】请勿使用浔川 AI 翻译 v6.0 翻译违规内容
人工智能
CS创新实验室1 小时前
AI 与编程
人工智能·编程·编程语言