学习pytorch20 pytorch完整的模型验证套路

pytorch完整的模型验证套路

B站小土堆pytorch学习视频 https://www.bilibili.com/video/BV1hE411t7RN/?p=32\&spm_id_from=pageDriver\&vd_source=9607a6d9d829b667f8f0ccaaaa142fcb

使用非数据集的测试数据,测试训练好模型的效果

复制代码
 测试:训练好的模型,提供对外真实数据的一个实际应用

从网上下载两张图片,整理图片的输入格式,输入模型测试模型效果

代码

py 复制代码
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import cv2

dog_path = './images/dog.jpg'
airplane_path = './images/airplane.jpg'
model_path = './images/net_epoch9_gpu.pth'

dog_pil = Image.open(dog_path)
airp_pil = Image.open(airplane_path)
print(dog_pil)  # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=258x174 at 0x237CC2CBE50>
# RGB 3通道 匹配模型输入的通道数
dog_pil = dog_pil.convert('RGB')  # def convert(self, mode=None, matrix=None, dither=None, palette=Palette.WEB, colors=256):
airp_pil = airp_pil.convert('RGB')
# dog_cv = cv2.imread(dog_path)  # numpy.array
# # print(dog_cv)
# img_trans = torchvision.transforms.ToTensor()  # 实例化转tensor的类
# dog_tensor = img_trans(dog_pil)
# dog_cv_tensor = img_trans(dog_cv)
# print(dog_tensor)
# print(dog_tensor.shape)
# print(dog_cv_tensor)
# 输入模型shape 需要是32*32大小的
transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor()])
dog_tensor = transform(dog_pil)
airp_tensor = transform(airp_pil)
# print(dog_tensor)
print(dog_tensor.shape, airp_tensor.shape)


class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x

# 加载模型要考虑是以哪种形式保存的模型  模型保存方式1:保存模型结构和参数 方式二:只保存模型参数
model = torch.load(model_path, map_location=torch.device('cpu'))
print(model)
dog_tensor = dog_tensor.reshape((1, 3, 32, 32))
airp_tensor = airp_tensor.reshape((1, 3, 32, 32))
model.eval() # 设置模型为测试状态 网络层有dropout batchNormal层不加eval函数会有问题
with torch.no_grad(): # 测试不做梯度计算 节省算力
    dog_output = model(dog_tensor)
    airp_output = model(airp_tensor)
print(dog_output)
print(dog_output.argmax())
print(dog_output.argmax(1))
print(airp_output)
print(airp_output.argmax(1))  # 概率值不便于解读 使用argmax 可以很方便的读出模型预测的是哪个类别

预测结果

sh 复制代码
tensor([[ 1.1317, -4.3441,  3.2116,  2.8930,  2.6749,  4.6079, -3.2860,  3.1357,
         -3.0432, -4.1703]])
tensor(5)
tensor([5])
tensor([[ 5.5993, -0.6140,  4.4758,  0.8463,  1.6311, -1.0217, -3.9990, -2.8343,
          1.1050, -1.6423]])
tensor([0])

预测结果和训练数据的标注一直,预测正确

解决报错

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x16 and 1024x64)

解决: dog_tensor = dog_tensor.reshape((1, 3, 32, 32)) 转换输入是4维的, 模型输入有一个batch-size维度

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

解决:model = torch.load(model_path, map_location=torch.device('cpu'))

在gpu上训练的模型,要在cpu上测试,模型加载的时候指定cpu设备

相关推荐
belldeep5 小时前
python:用 Flask 3 , mistune 2 和 mermaid.min.js 10.9 来实现 Markdown 中 mermaid 图表的渲染
javascript·python·flask
喵手5 小时前
Python爬虫实战:电商价格监控系统 - 从定时任务到历史趋势分析的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·电商价格监控系统·从定时任务到历史趋势分析·采集结果sqlite存储
喵手5 小时前
Python爬虫实战:京东/淘宝搜索多页爬虫实战 - 从反爬对抗到数据入库的完整工程化方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·京东淘宝页面数据采集·反爬对抗到数据入库·采集结果csv导出
B站_计算机毕业设计之家6 小时前
猫眼电影数据可视化与智能分析平台 | Python Flask框架 Echarts 推荐算法 爬虫 大数据 毕业设计源码
python·机器学习·信息可视化·flask·毕业设计·echarts·推荐算法
PPPPPaPeR.6 小时前
光学算法实战:深度解析镜片厚度对前后表面折射/反射的影响(纯Python实现)
开发语言·python·数码相机·算法
JaydenAI6 小时前
[拆解LangChain执行引擎] ManagedValue——一种特殊的只读虚拟通道
python·langchain
骇城迷影6 小时前
Makemore 核心面试题大汇总
人工智能·pytorch·python·深度学习·线性回归
长安牧笛6 小时前
反传统学习APP,摒弃固定课程顺序,根据用户做题正确性,学习速度,动态调整课程难度,比如某知识点学不会,自动推荐基础讲解和练习题,学习后再进阶,不搞一刀切。
python·编程语言
码界筑梦坊6 小时前
330-基于Python的社交媒体舆情监控系统
python·mysql·信息可视化·数据分析·django·毕业设计·echarts