PyTorch深度学习笔记(二十)(模型验证测试)

前言

到这一章节为止,依据小土堆课程的PyTorch深度学习笔记基础部分已经完结了,接下来将依据李沐动手学深度学习课程进行PyTorch深度学习笔记的进阶部分

预测图片

完整的模型验证(测试,demo)套路,利用已经训练好的模型,然后给它提供输入。

输入狗的图片,并打开

python 复制代码
image_path = "imgs/dog.png"
image = Image.open(image_path)

4通道的RGBA转为3通道的RGB图片

python 复制代码
image = image.convert("RGB")

转换图像格式并设定网络

python 复制代码
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),   
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x

GPU上训练的东西映射到CPU上

python 复制代码
model = torch.load("model/tudui_29.pth",map_location=torch.device('cpu'))

转为四维,符合网络输入需求

python 复制代码
image = torch.reshape(image,(1,3,32,32))

将模型转为测试类型

python 复制代码
model.eval()

不进行梯度计算,减少内存计算

python 复制代码
with torch.no_grad():
    output = model(image)

概率最大类别的输出

python 复制代码
print(output.argmax(1))

完整代码

python 复制代码
import torchvision
from PIL import Image
from torch import nn
import torch

image_path = "imgs/dog.png"
image = Image.open(image_path)
image = image.convert("RGB") 
print(image)

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),   
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x

model = torch.load("model/tudui_29.pth",map_location=torch.device('cpu'))
print(model)
image = torch.reshape(image,(1,3,32,32)) 
model.eval()
with torch.no_grad():
    output = model(image)
output = model(image)
print(output)
print(output.argmax(1)) 
相关推荐
一只乔哇噻2 分钟前
java后端工程师+AI大模型进修ing(研一版‖day55)
人工智能
小毅&Nora30 分钟前
【AI微服务】【Spring AI Alibaba】② Agent 深度实战:构建可记忆、可拦截、可流式的智能体系统
人工智能·微服务·spring-ai
陈天伟教授1 小时前
基于学习的人工智能(7)机器学习基本框架
人工智能·学习
Ccjf酷儿1 小时前
操作系统 蒋炎岩 3.硬件视角的操作系统
笔记
千里念行客2401 小时前
昂瑞微正式启动科创板IPO发行
人工智能·科技·信息与通信·射频工程
习习.y1 小时前
python笔记梳理以及一些题目整理
开发语言·笔记·python
撸码猿2 小时前
《Python AI入门》第10章 拥抱AIGC——OpenAI API调用与Prompt工程实战
人工智能·python·aigc
在逃热干面2 小时前
(笔记)自定义 systemd 服务
笔记
双翌视觉2 小时前
双翌全自动影像测量仪:以微米精度打造智能化制造
人工智能·机器学习·制造
编程小白_正在努力中3 小时前
神经网络深度解析:从神经元到深度学习的进化之路
人工智能·深度学习·神经网络·机器学习