117_PyTorch 实战:利用训练好的模型进行单张图片验证

训练模型(Train)和测试模型(Test)最终都是为了应用(Inference) 。本篇将教你如何加载已经保存的 .pth 模型文件,并用一张外部图片来检验它的分类能力。

1. 验证流程的三大核心步骤

  1. 准备测试环境:包括加载模型、处理单张图片。
  2. 图像预处理 :输入的图片必须经过与训练时完全相同的缩放(Resize)和归一化(ToTensor)。
  3. 模型推理:将图片送入模型,并解析输出结果。

2. 代码实战:验证模型的分类结果

文件演示了如何加载一个在 CIFAR-10 上训练好的模型,并识别一张"狗"或"飞机"的图片。

Python

复制代码
import torch
import torchvision
from PIL import Image
from torch import nn

# 1. 读取外部图片
image_path = "dog.png" # 或者是你本地的图片路径
img = Image.open(image_path)
# 如果是 RGBA 格式(带透明度),需转为 RGB
img = img.convert('RGB')

# 2. 图像预处理(必须与训练时保持一致)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])
img = transform(img)
# 增加 Batch 维度,从 [3, 32, 32] 变为 [1, 3, 32, 32]
img = torch.reshape(img, (1, 3, 32, 32))

# 3. 加载训练好的模型
# 注意:如果模型是用方式一保存的,加载时需要能访问到网络定义类
model = torch.load("tudui_29.pth", map_location=torch.device('cpu'))

# 4. 进入验证模式
model.eval()
with torch.no_grad():
    output = model(img)

# 5. 解析输出
# output 是一个长度为 10 的向量,值最大的位置即为预测类别
print(output)
predict_idx = output.argmax(1).item()

# CIFAR-10 的类别映射(固定顺序)
classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
print(f"预测结果是:{classes[predict_idx]}")

3. 关键细节分析

为什么需要 img.convert('RGB')

PNG 图片通常包含 4 个通道(RGBA),其中 A 是透明度。而我们的模型是基于 RGB 3 通道训练的。使用 convert('RGB') 可以保证无论输入什么格式的图片,都能适配模型。

map_location 的妙用

如果你的模型是在 GPU 上训练保存的,但现在你想在只有 CPU 的电脑上运行验证,加载时必须加上 map_location=torch.device('cpu'),否则会报错。

argmax(1) 的逻辑

模型输出的是图片属于这 10 个类别的"概率得分"。通过 argmax(1),我们能直接提取出得分最高的那个位置的索引(例如 5 代表狗)。


4. 总结:从模型到应用

分析完这个文件,我们就完成了从数据采集到实战部署的全流程:

  1. 训练集/测试集->训练出高准确率模型。
  2. 保存模型-> 持久化存储。
  3. 单张验证->将模型应用到真实场景中。

💡 学习小结

当你能成功地输入一张从网上下载的图片,并让模型正确报出"cat"或"dog"时,你就真正完成了深度学习的闭环。

相关推荐
程序员cxuan2 小时前
人麻了,谁把我 ssh 干没了
人工智能·后端·程序员
数据皮皮侠2 小时前
中国城市间地理距离矩阵(2024)
大数据·数据库·人工智能·算法·制造
枫叶林FYL2 小时前
【乳腺癌早期筛查(钼靶X光图像AI识别)】第一章:钼靶AI核心算法架构演进——从2D全视野到3D断层合成与视觉Transformer
人工智能·深度学习
Lethehong2 小时前
Python Selenium全栈指南:从自动化入门到企业级实战
python·selenium·测试工具·自动化
TK云大师-KK2 小时前
TikTok自动化直播遇到内容重复问题?这套技术方案了解一下
大数据·运维·人工智能·矩阵·自动化·新媒体运营·流量运营
姚青&2 小时前
大语言模型与私有部署
人工智能·语言模型·chatgpt
WeeJot嵌入式2 小时前
Meta LSP无数据训练深度解析:语言自我对弈的数学原理与实现
人工智能·机器学习·里氏替换原则
foundbug9992 小时前
基于卡尔曼滤波的背景建模与车辆检测(OpenCV实现)
人工智能·opencv·计算机视觉