11-pytorch-使用自己的数据集测试

b站小土堆pytorch教程学习笔记

python 复制代码
import torch
import torchvision
from PIL import Image
from torch import nn

img_path= '../imgs/dog.png'
image=Image.open(img_path)
print(image)
# image=image.convert('RGB')

transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
                                          torchvision.transforms.ToTensor()])
image=transform(image)
print(image.shape)

#加载模型
class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

model=torch.load('../han_9.pth',map_location=torch.device('cpu'))#将GPU上运行的模型转移到CPU
print(model)

#对图片进行reshap
image=torch.reshape(image,(-1,3,32,32))

#将模型转化为测试类型
model.eval()
with torch.no_grad():#节约内存
    output=model(image)
print(output)


print(output.argmax(1))

<PIL.PngImagePlugin.PngImageFile image mode=RGB size=306x283 at 0x250B0006EE0>
torch.Size(3, 32, 32)
Han(
(model): Sequential(
(0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Flatten(start_dim=1, end_dim=-1)
(7): Linear(in_features=1024, out_features=64, bias=True)
(8): Linear(in_features=64, out_features=10, bias=True)
)
)
tensor(\[-2.0302, -0.6256, 0.7483, 1.5765, 0.2651, 2.2243, -0.7037, -0.5262, -1.4401, -0.6563])
tensor(5)
Process finished with exit code 0

预测正确!

相关推荐
魔术师Grace几秒前
真正值钱的 AI 小工具,可能只是帮人少打一遍字
前端·人工智能
韦东东3 分钟前
究极方案:油猴脚本实现RAG问答前端图片流式体验
人工智能·大模型·油猴脚本·rag·tampermonkey·userscript
云布道师4 分钟前
阿里云 OSS 向量 Bucket 正式商业化,提升 AI 应用效能
人工智能·阿里云·云计算
珠***格5 分钟前
边缘计算——“云-边-端”协同架构解析
大数据·人工智能·分布式·架构·能源·边缘计算
YJlio7 分钟前
OpenClaw v2026.5.26-beta.1 / beta.2 预发布解读:Gateway 加速、transcript 路径统一、多通道修复、语音增强与安装更新链路加固
人工智能·windows·python·ui·缓存·gateway·outlook
Cosolar1 小时前
AutoGen:微软开源的多Agent对话框架详解
人工智能·系统架构·大模型·agent·rag
Urbano1 小时前
一条休闲束脚裤的工业化诞生科普 八道自动化缝纫工序拆解
人工智能
陕西企来客5 小时前
企来客科技来客 GEO 优化系统深度解析:核心技术与原因分析
大数据·人工智能·科技·搜索引擎
来让爷抱一个8 小时前
MonkeyCode 多模型切换技巧:什么时候用 Claude/GPT/DeepSeek
人工智能·ai编程
李白你好8 小时前
AI Agent 架构的自动化渗透测试工具
运维·人工智能·自动化