(补)CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务的应用

一、代码核心定位:承接训练,实现单图预测

前文CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务-CSDN博客

已完成 CIFAR10 模型的三大核心步骤:

  1. 定义了Prayer卷积神经网络结构(model.py);
  2. 完成了 10 轮训练,得到了prayer_0.pthprayer_9.pth等训练好的模型文件;
  3. 验证了模型在测试集上的正确率最终达到约 55.6%。

而当前代码的核心目标是:

用训练好的模型(如prayer_29.pth),对一张自定义的图像(如dog.png)进行类别预测

把 "离线训练的模型" 转化为 "可实时预测的工具"。

二、代码逐段详解:从图像到预测结果的全流程

1. 前置准备:导入库与定义模型

这部分是模型推理的 "基础保障",确保代码能调用 PyTorch 工具和匹配训练时的模型结构。

python 复制代码
import torch                # PyTorch核心库,负责张量运算和模型推理
import torchvision          # 提供图像预处理工具
from PIL import Image       # 读取和处理图像的经典库
from torch import nn        # 神经网络模块,用于定义模型结构
  • 模型类Prayer的重复定义 :这里重新定义了与model.py完全一致的Prayer类,原因是torch.load加载完整模型时,需要当前环境中有对应的模型类定义(否则无法解析模型结构)。
  • 核心是保证推理时的模型结构与训练时完全一致,从输入通道(3)、卷积 / 池化层级,到全连接层维度(最终输出 10 类),均和训练阶段完全匹配。

2. 图像预处理:让输入符合模型要求

CIFAR10 训练时,图像是 "32×32 像素的 RGB 彩色图 + 张量格式",因此自定义图像必须经过相同预处理,否则模型无法识别。

  • 步骤 1:读取图像 image = Image.open(image_path).convert('RGB')

    • PIL.Image读取图像文件;
    • convert('RGB')强制转为 3 通道彩色图,避免灰度图(1 通道)或透明图(4 通道)导致通道数不匹配。
  • 步骤 2:标准化预处理

python 复制代码
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),  # 缩放到32×32,匹配训练输入尺寸
    torchvision.transforms.ToTensor()         # 转为Tensor:像素值从[0,255]→[0,1],维度从(HWC)→(CHW)
])
image = transform(image)  # 处理后形状:(3, 32, 32)
  • 这一步是关键匹配项 :如果图像尺寸、格式与训练数据不一致,模型会因输入维度错误直接报错。

3. 调整输入形状:适配模型的批量推理逻辑

  • 模型训练时处理的是 "批量数据"(如batch_size=64,输入形状为(64, 3, 32, 32)),即使推理单张图像,也需要调整为 "批量维度为 1" 的格式。
  • image = torch.reshape(image, (1, 3, 32, 32)):将(3, 32, 32)转为(1, 3, 32, 32),其中1代表 "当前批量只有 1 张图"。

模型在训练时已经 "习惯了" 接收4 个维度的输入 (批量大小 + 通道 + 高 + 宽)。就像你去自动售货机买水,机器的投币口只接受 "竖着插卡",如果你横着插,即使卡是对的,机器也不认 ------ 模型也有这样的 "输入格式洁癖"。

比如:

  • 模型的第一层是卷积层 nn.Conv2d(in_channels=3, ...),它要求输入必须是 4 维张量 (批量大小 ×3×32×32);
  • 如果你直接输入单张图的 3 维张量 (3, 32, 32),模型会 "困惑":"第一个维度应该是批量大小,怎么没有了?" 然后直接报错。

4. 模型加载与推理:核心预测环节

这部分是连接 "训练成果" 与 "预测结果" 的桥梁。

  • 加载训练好的模型

  • model = torch.load("prayer_29.pth", map_location='cpu', weights_only=False)

    • prayer_29.pth:训练保存的模型文件(前文训练 10 轮,此处文件名可能为示例,实际对应某一轮训练结果);
    • map_location='cpu':指定在 CPU 上推理(无需 GPU 也能运行,兼容更多环境);
    • weights_only=False:允许加载 "完整模型"(包含结构 + 权重),适配前文torch.save(prayer, ...)的保存方式。
  • 切换模型为评估模式 model.eval():将模型从 "训练模式" 切换为 "评估模式",关闭 Dropout(此处模型未用,但为通用规范)、固定 BatchNorm 等层的参数,确保推理结果稳定。

  • 无梯度推理

    python 复制代码
    with torch.no_grad():
        output = model(image)  # 模型输出:(1, 10)的张量
  • with torch.no_grad():关闭梯度计算,减少内存占用、加快推理速度(推理阶段无需更新参数,梯度无用);

  • output形状为(1, 10):对应 1 个样本、10 个类别的 "预测分数"(非概率,数值越大代表模型认为属于该类的可能性越高)。


5. 输出预测结果:解读模型输出

  • 打印预测分数print(output)输出 10 个类别的原始分数,
  • 例如某类分数为2.5,另一类为-1.2,分数越高概率越大。
  • 打印预测类别索引print(output.argmax(1))
    • argmax(1):在 "类别维度"(第 1 维,对应 10 个类别)上取最大值的索引,结果为0-9中的一个;
    • 该索引对应 CIFAR10 的类别(如0=飞机1=汽车3=猫5=狗等,需对照 CIFAR10 类别表解读)。

CIFAR10 类别索引 - 名称映射表

类别索引 对应类别名称 英文名称
0 飞机 airplane
1 汽车 automobile
2 bird
3 cat
4 鹿 deer
5 dog
6 青蛙 frog
7 horse
8 ship
9 卡车 truck
相关推荐
C7211BA4 小时前
世界模型和大语言模型的区别
人工智能·语言模型·自然语言处理
paid槮5 小时前
深度学习复习汇总
人工智能·深度学习
Light605 小时前
深度学习 × 计算机视觉 × Kaggle(上):从理论殿堂起步 ——像素、特征与模型的进化之路
人工智能·深度学习·计算机视觉·卷积神经网络·transformer·特征学习
天外飞雨道沧桑5 小时前
JS/CSS实现元素样式隔离
前端·javascript·css·人工智能·ai
深圳UMI5 小时前
UMI无忧秘书智脑:实现生活与工作全面智能化服务
大数据·人工智能
Antonio9155 小时前
【图像处理】图像形态学操作
图像处理·人工智能·opencv
Theodore_10226 小时前
机器学习(7)逻辑回归及其成本函数
人工智能·机器学习
AKAMAI6 小时前
Akamai与Bitmovin:革新直播与点播视频流服务
人工智能·云原生·云计算
文火冰糖的硅基工坊6 小时前
[人工智能-大模型-54]:模型层技术 - 数据结构+算法 = 程序
数据结构·人工智能·算法