PyTorch深度学习笔记(二十)(模型验证测试)

前言

到这一章节为止,依据小土堆课程的PyTorch深度学习笔记基础部分已经完结了,接下来将依据李沐动手学深度学习课程进行PyTorch深度学习笔记的进阶部分

预测图片

完整的模型验证(测试,demo)套路,利用已经训练好的模型,然后给它提供输入。

输入狗的图片,并打开

python 复制代码
image_path = "imgs/dog.png"
image = Image.open(image_path)

4通道的RGBA转为3通道的RGB图片

python 复制代码
image = image.convert("RGB")

转换图像格式并设定网络

python 复制代码
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),   
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x

GPU上训练的东西映射到CPU上

python 复制代码
model = torch.load("model/tudui_29.pth",map_location=torch.device('cpu'))

转为四维,符合网络输入需求

python 复制代码
image = torch.reshape(image,(1,3,32,32))

将模型转为测试类型

python 复制代码
model.eval()

不进行梯度计算,减少内存计算

python 复制代码
with torch.no_grad():
    output = model(image)

概率最大类别的输出

python 复制代码
print(output.argmax(1))

完整代码

python 复制代码
import torchvision
from PIL import Image
from torch import nn
import torch

image_path = "imgs/dog.png"
image = Image.open(image_path)
image = image.convert("RGB") 
print(image)

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),   
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()        
        self.model1 = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
        
    def forward(self, x):
        x = self.model1(x)
        return x

model = torch.load("model/tudui_29.pth",map_location=torch.device('cpu'))
print(model)
image = torch.reshape(image,(1,3,32,32)) 
model.eval()
with torch.no_grad():
    output = model(image)
output = model(image)
print(output)
print(output.argmax(1)) 
相关推荐
NewsMash2 小时前
PyTorch之父发离职长文,告别Meta
人工智能·pytorch·python
IT_陈寒2 小时前
Python 3.12新特性实测:10个让你的代码提速30%的隐藏技巧 🚀
前端·人工智能·后端
Ztop2 小时前
GPT-5.1 已确认!OpenAI下一步推理升级?对决 Gemini 3 在即
人工智能·gpt·chatgpt
qq_436962182 小时前
奥威BI:打破数据分析的桎梏,让决策更自由
人工智能·数据挖掘·数据分析
金融Tech趋势派2 小时前
金融机构如何用企业微信实现客户服务优化?
大数据·人工智能·金融·企业微信·企业微信scrm
大模型真好玩3 小时前
LangChain1.0速通指南(三)——LangChain1.0 create_agent api 高阶功能
人工智能·langchain·mcp
汉克老师3 小时前
CCF--LMCC大语言模型能力认证官方样题(第一赛(青少年组)第二部分 程序题 (21--25))
人工智能·语言模型·自然语言处理·lmcc
视界先声3 小时前
如何挑选出色的展厅机器人
人工智能·机器人
沫儿笙3 小时前
YASKAWA机器人焊机气体省气
人工智能·机器人