117_PyTorch 实战:利用训练好的模型进行单张图片验证

训练模型(Train)和测试模型(Test)最终都是为了应用(Inference) 。本篇将教你如何加载已经保存的 .pth 模型文件,并用一张外部图片来检验它的分类能力。

1. 验证流程的三大核心步骤

  1. 准备测试环境:包括加载模型、处理单张图片。
  2. 图像预处理 :输入的图片必须经过与训练时完全相同的缩放(Resize)和归一化(ToTensor)。
  3. 模型推理:将图片送入模型,并解析输出结果。

2. 代码实战:验证模型的分类结果

文件演示了如何加载一个在 CIFAR-10 上训练好的模型,并识别一张"狗"或"飞机"的图片。

Python

复制代码
import torch
import torchvision
from PIL import Image
from torch import nn

# 1. 读取外部图片
image_path = "dog.png" # 或者是你本地的图片路径
img = Image.open(image_path)
# 如果是 RGBA 格式(带透明度),需转为 RGB
img = img.convert('RGB')

# 2. 图像预处理(必须与训练时保持一致)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])
img = transform(img)
# 增加 Batch 维度,从 [3, 32, 32] 变为 [1, 3, 32, 32]
img = torch.reshape(img, (1, 3, 32, 32))

# 3. 加载训练好的模型
# 注意:如果模型是用方式一保存的,加载时需要能访问到网络定义类
model = torch.load("tudui_29.pth", map_location=torch.device('cpu'))

# 4. 进入验证模式
model.eval()
with torch.no_grad():
    output = model(img)

# 5. 解析输出
# output 是一个长度为 10 的向量,值最大的位置即为预测类别
print(output)
predict_idx = output.argmax(1).item()

# CIFAR-10 的类别映射(固定顺序)
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
print(f"预测结果是:{classes[predict_idx]}")

3. 关键细节分析

为什么需要 img.convert('RGB')

PNG 图片通常包含 4 个通道(RGBA),其中 A 是透明度。而我们的模型是基于 RGB 3 通道训练的。使用 convert('RGB') 可以保证无论输入什么格式的图片,都能适配模型。

map_location 的妙用

如果你的模型是在 GPU 上训练保存的,但现在你想在只有 CPU 的电脑上运行验证,加载时必须加上 map_location=torch.device('cpu'),否则会报错。

argmax(1) 的逻辑

模型输出的是图片属于这 10 个类别的"概率得分"。通过 argmax(1),我们能直接提取出得分最高的那个位置的索引(例如 5 代表狗)。


4. 总结:从模型到应用

分析完这个文件,我们就完成了从数据采集到实战部署的全流程:

  1. 训练集/测试集->训练出高准确率模型。
  2. 保存模型-> 持久化存储。
  3. 单张验证->将模型应用到真实场景中。

💡 学习小结

当你能成功地输入一张从网上下载的图片,并让模型正确报出"cat"或"dog"时,你就真正完成了深度学习的闭环。

相关推荐
ZhengEnCi9 小时前
09bad-斯坦福CS336作业一-构建优化器
人工智能
ZhengEnCi10 小时前
09bac-斯坦福CS336作业一-实现训练损失计算
人工智能
冬奇Lab10 小时前
Skill 系列(01):Skill 评测体系——如何量化一个 AI Skill 的质量
人工智能
兵慌码乱12 小时前
基于 MediaPipe 与 PySide2 的手势交互音乐控制系统实现:轻量化视觉交互全流程解析
python·opencv·计算机视觉·人机交互·手势识别·mediapipe·pyside2
IT_陈寒13 小时前
Redis内存爆了,原来我漏掉了这个致命配置
前端·人工智能·后端
luckdewei15 小时前
FastAPI 资产管理系统实战:复杂 ORM 关联、Alembic 迁移与 N+1 查询优化
python
用户35218024547515 小时前
🎆从 Prompt 到 Skill:让 Spring AI Agent 学会"装新技能"
人工智能·spring boot·ai编程
米小虾15 小时前
手把手教你搭建第一个生产级AI Agent:从选型到实战的完整指南
人工智能·agent
任沫15 小时前
Agent之Function Call
javascript·人工智能·go
米小虾15 小时前
2026年AI Agent全面爆发:从开源生态到企业级应用的进化之路
人工智能·agent