训练模型(Train)和测试模型(Test)最终都是为了应用(Inference) 。本篇将教你如何加载已经保存的 .pth 模型文件,并用一张外部图片来检验它的分类能力。
1. 验证流程的三大核心步骤
- 准备测试环境:包括加载模型、处理单张图片。
- 图像预处理 :输入的图片必须经过与训练时完全相同的缩放(Resize)和归一化(ToTensor)。
- 模型推理:将图片送入模型,并解析输出结果。
2. 代码实战:验证模型的分类结果
文件演示了如何加载一个在 CIFAR-10 上训练好的模型,并识别一张"狗"或"飞机"的图片。
Python
import torch
import torchvision
from PIL import Image
from torch import nn
# 1. 读取外部图片
image_path = "dog.png" # 或者是你本地的图片路径
img = Image.open(image_path)
# 如果是 RGBA 格式(带透明度),需转为 RGB
img = img.convert('RGB')
# 2. 图像预处理(必须与训练时保持一致)
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()
])
img = transform(img)
# 增加 Batch 维度,从 [3, 32, 32] 变为 [1, 3, 32, 32]
img = torch.reshape(img, (1, 3, 32, 32))
# 3. 加载训练好的模型
# 注意:如果模型是用方式一保存的,加载时需要能访问到网络定义类
model = torch.load("tudui_29.pth", map_location=torch.device('cpu'))
# 4. 进入验证模式
model.eval()
with torch.no_grad():
output = model(img)
# 5. 解析输出
# output 是一个长度为 10 的向量,值最大的位置即为预测类别
print(output)
predict_idx = output.argmax(1).item()
# CIFAR-10 的类别映射(固定顺序)
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
print(f"预测结果是:{classes[predict_idx]}")
3. 关键细节分析
为什么需要 img.convert('RGB')?
PNG 图片通常包含 4 个通道(RGBA),其中 A 是透明度。而我们的模型是基于 RGB 3 通道训练的。使用 convert('RGB') 可以保证无论输入什么格式的图片,都能适配模型。
map_location 的妙用
如果你的模型是在 GPU 上训练保存的,但现在你想在只有 CPU 的电脑上运行验证,加载时必须加上 map_location=torch.device('cpu'),否则会报错。
argmax(1) 的逻辑
模型输出的是图片属于这 10 个类别的"概率得分"。通过 argmax(1),我们能直接提取出得分最高的那个位置的索引(例如 5 代表狗)。
4. 总结:从模型到应用
分析完这个文件,我们就完成了从数据采集到实战部署的全流程:
- 训练集/测试集->训练出高准确率模型。
- 保存模型-> 持久化存储。
- 单张验证->将模型应用到真实场景中。
💡 学习小结
当你能成功地输入一张从网上下载的图片,并让模型正确报出"cat"或"dog"时,你就真正完成了深度学习的闭环。