117_PyTorch 实战:利用训练好的模型进行单张图片验证

训练模型(Train)和测试模型(Test)最终都是为了应用(Inference) 。本篇将教你如何加载已经保存的 .pth 模型文件,并用一张外部图片来检验它的分类能力。

1. 验证流程的三大核心步骤

  1. 准备测试环境:包括加载模型、处理单张图片。
  2. 图像预处理 :输入的图片必须经过与训练时完全相同的缩放(Resize)和归一化(ToTensor)。
  3. 模型推理:将图片送入模型,并解析输出结果。

2. 代码实战:验证模型的分类结果

文件演示了如何加载一个在 CIFAR-10 上训练好的模型,并识别一张"狗"或"飞机"的图片。

Python

复制代码
import torch
import torchvision
from PIL import Image
from torch import nn

# 1. 读取外部图片
image_path = "dog.png" # 或者是你本地的图片路径
img = Image.open(image_path)
# 如果是 RGBA 格式(带透明度),需转为 RGB
img = img.convert('RGB')

# 2. 图像预处理(必须与训练时保持一致)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])
img = transform(img)
# 增加 Batch 维度,从 [3, 32, 32] 变为 [1, 3, 32, 32]
img = torch.reshape(img, (1, 3, 32, 32))

# 3. 加载训练好的模型
# 注意:如果模型是用方式一保存的,加载时需要能访问到网络定义类
model = torch.load("tudui_29.pth", map_location=torch.device('cpu'))

# 4. 进入验证模式
model.eval()
with torch.no_grad():
    output = model(img)

# 5. 解析输出
# output 是一个长度为 10 的向量,值最大的位置即为预测类别
print(output)
predict_idx = output.argmax(1).item()

# CIFAR-10 的类别映射(固定顺序)
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
print(f"预测结果是:{classes[predict_idx]}")

3. 关键细节分析

为什么需要 img.convert('RGB')

PNG 图片通常包含 4 个通道(RGBA),其中 A 是透明度。而我们的模型是基于 RGB 3 通道训练的。使用 convert('RGB') 可以保证无论输入什么格式的图片,都能适配模型。

map_location 的妙用

如果你的模型是在 GPU 上训练保存的,但现在你想在只有 CPU 的电脑上运行验证,加载时必须加上 map_location=torch.device('cpu'),否则会报错。

argmax(1) 的逻辑

模型输出的是图片属于这 10 个类别的"概率得分"。通过 argmax(1),我们能直接提取出得分最高的那个位置的索引(例如 5 代表狗)。


4. 总结:从模型到应用

分析完这个文件,我们就完成了从数据采集到实战部署的全流程:

  1. 训练集/测试集->训练出高准确率模型。
  2. 保存模型-> 持久化存储。
  3. 单张验证->将模型应用到真实场景中。

💡 学习小结

当你能成功地输入一张从网上下载的图片,并让模型正确报出"cat"或"dog"时,你就真正完成了深度学习的闭环。

相关推荐
Irene199110 小时前
Python 卸载与安装(以卸载3.13.3,装3.13.13为例)
python
予早10 小时前
使用 pyrasite-ng 和 guppy3 做内存分析
python·内存分析
hef28815 小时前
如何生成特定SQL的AWR报告_@awrsqrpt.sql深度剖析单条语句性能
jvm·数据库·python
小程故事多_8015 小时前
Agent+Milvus,告别静态知识库,打造具备动态记忆的智能AI助手
人工智能·深度学习·ai编程·milvus
code_pgf16 小时前
Llama 3详解
人工智能·llama
ComputerInBook16 小时前
数字图像处理(4版)——第 3 章——(图像的)强度变换和空间滤波(Rafael C.Gonzalez&Richard E. Woods)
图像处理·人工智能·计算机视觉·强度变换和空间滤波
爱写代码的小朋友16 小时前
生成式人工智能(AIGC)在开放式教育问答系统中的知识表征与推理机制研究
人工智能·aigc
Jinkxs16 小时前
从语法纠错到项目重构:Python+Copilot 的全流程开发效率提升指南
python·重构·copilot
技术专家16 小时前
Stable Diffusion系列的详细讨论 / Detailed Discussion of the Stable Diffusion Series
人工智能·python·算法·推荐算法·1024程序员节
m0_4889130116 小时前
万字长文带你梳理Llama开源家族:从Llama-1到Llama-3,看这一篇就够了!
人工智能·学习·机器学习·大模型·产品经理·llama·uml