前言
CIFAR10 作为计算机视觉领域经典的图像分类数据集,常被用于验证各类深度学习模型的性能。本文将基于 Vision Transformer(ViT)模型,详细讲解如何加载预训练权重、对 CIFAR10 测试集进行抽样测试,并完整解析测试流程中的核心代码逻辑,帮助大家掌握模型测试的关键步骤。
一、测试流程整体设计
本次测试的核心目标是:加载训练好的 ViT 权重文件,从 CIFAR10 测试集中随机抽取指定数量样本,验证模型预测效果并统计准确率。整体流程可分为 3 个核心步骤:
- 数据预处理:加载 CIFAR10 测试集并完成标准化等预处理;
- 模型加载:构建 ViT 模型结构并加载预训练权重;
- 抽样测试:随机抽取样本进行预测,输出详细结果并统计准确率。
二、完整代码解析
2.1 环境与依赖准备
首先确保安装以下核心依赖:
pip install torch torchvision numpy
核心依赖说明:
- torch/torchvision:PyTorch 框架及视觉工具库,用于模型构建和数据加载;
- numpy:数值计算基础库;
- cudnn:GPU 加速库(可选,有 GPU 时启用)。
2.2 完整测试代码
# -*- coding: utf-8 -*-
'''
ViT测试代码:加载已保存的权重文件,对CIFAR10测试集进行抽样测试并打印结果
'''
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
# ===================== 全局常量配置 =====================
batch_size = 128 # 批次大小(测试时实际用1,此处为配置保留)
img_size = 32 # 输入图像尺寸(CIFAR10原生尺寸)
patch = 4 # ViT图像分块大小
dimhead = 512 # ViT特征维度
num_classes = 10 # CIFAR10类别数
WEIGHT_PATH = './vit_net_30epochs.pth' # 预训练权重路径
SAMPLE_NUM = 10 # 抽样测试样本数
device = 'cuda' if torch.cuda.is_available() else 'cpu' # 设备选择
# ===================== 主函数 =====================
def main():
# 1. 加载数据集(仅加载测试集)
print('==> Preparing CIFAR10 test data..')
# CIFAR10官方均值和标准差(标准化关键参数)
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
# 测试集变换(必须与训练时保持一致)
transform_test = transforms.Compose([
transforms.Resize(img_size), # 调整尺寸(此处与原生一致,可省略)
transforms.ToTensor(), # 转为Tensor并归一化到[0,1]
transforms.Normalize(mean, std), # 标准化:(x-mean)/std
])
# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(
root='./data', # 数据集保存路径
train=False, # 加载测试集
download=True, # 无数据时自动下载
transform=transform_test # 应用预处理
)
# 随机抽样:生成测试集索引并随机选择指定数量
test_indices = list(range(len(testset)))
sample_indices = random.sample(test_indices, SAMPLE_NUM)
# 构建抽样数据集和数据加载器
sample_dataset = torch.utils.data.Subset(testset, sample_indices)
sample_loader = torch.utils.data.DataLoader(
sample_dataset,
batch_size=1, # 逐个处理,方便打印单样本详细信息
shuffle=False
)
# CIFAR10类别名称(对应标签0-9)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
# 2. 构建模型 + 加载训练好的权重
print('==> Building ViT model and loading trained weights..')
# 导入自定义ViT模型(需确保models.vit.py文件存在)
from models.vit import ViT
net = ViT(
image_size=img_size,
patch_size=patch,
num_classes=num_classes,
dim=dimhead,
depth=6,
heads=8,
mlp_dim=512,
dropout=0.1,
emb_dropout=0.1
)
# 将模型移至指定设备(GPU/CPU)
net = net.to(device)
# GPU加速优化(可选)
if device == 'cuda':
cudnn.benchmark = True
# 加载预训练权重
checkpoint = torch.load(WEIGHT_PATH, map_location=device)
net.load_state_dict(checkpoint)
print(f'✅ 成功加载训练权重: {WEIGHT_PATH}')
# 3. 抽样测试并打印结果
print(f'\n========== 开始抽样测试(共{len(sample_loader)}个样本) ==========')
net.eval() # 关键:设置模型为评估模式(禁用Dropout/BatchNorm训练行为)
correct_count = 0
with torch.no_grad(): # 关键:禁用梯度计算,节省内存+加速推理
for idx, (inputs, targets) in enumerate(sample_loader):
# 获取样本在原测试集中的索引
original_idx = sample_indices[idx]
# 将数据移至指定设备
inputs, targets = inputs.to(device), targets.to(device)
# 模型推理:前向传播得到预测输出
outputs = net(inputs)
# 获取预测类别(取概率最大的类别)
_, predicted = outputs.max(1)
# 转换为类别名称(方便阅读)
target_class = classes[targets.item()]
predicted_class = classes[predicted.item()]
# 判断预测是否正确
is_correct = target_class == predicted_class
if is_correct:
correct_count += 1
result_mark = "✅ 正确"
else:
result_mark = "❌ 错误"
# 打印单样本详细结果
print(f'\n样本 {idx+1} (测试集索引: {original_idx}):')
print(f' 真实类别: {target_class}')
print(f' 预测类别: {predicted_class}')
print(f' 测试结果: {result_mark}')
# 打印测试统计结果
print(f'\n========== 测试统计 ==========')
print(f'抽样总数: {SAMPLE_NUM}')
print(f'预测正确: {correct_count}')
print(f'预测错误: {SAMPLE_NUM - correct_count}')
print(f'抽样准确率: {correct_count/SAMPLE_NUM*100:.2f}%')
# ===================== 启动 =====================
if __name__ == '__main__':
from multiprocessing import freeze_support
freeze_support() # Windows系统多进程兼容(可选)
main()
2.3 核心代码详解
(1)数据预处理关键
transform_test = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean, std),
])
ToTensor():将 PIL 图像转为torch.Tensor,并将像素值从[0,255]归一化到[0,1];Normalize(mean, std):使用 CIFAR10 官方统计的均值和标准差进行标准化,必须与训练时保持一致,否则会导致模型预测失效;Resize(img_size):此处 CIFAR10 原生尺寸就是 32x32,可省略,但若训练时调整过尺寸,测试时必须同步。
(2)随机抽样实现
test_indices = list(range(len(testset)))
sample_indices = random.sample(test_indices, SAMPLE_NUM)
sample_dataset = torch.utils.data.Subset(testset, sample_indices)
random.sample():从测试集索引中随机选择指定数量样本,避免重复;torch.utils.data.Subset:基于索引构建子集数据集,方便后续加载。
(3)模型加载与评估模式
net.eval() # 评估模式
with torch.no_grad(): # 禁用梯度
# 推理代码
net.eval():核心操作,将模型中的 Dropout、BatchNorm 等层切换为评估模式,避免影响预测结果;torch.no_grad():禁用梯度计算,大幅减少内存占用并提升推理速度,测试 / 推理阶段必须使用。
(4)预测结果解析
_, predicted = outputs.max(1)
outputs是模型输出的形状为[1, 10]的张量(对应 10 个类别的预测得分);max(1)表示在维度 1(类别维度)上取最大值,返回值中第一个元素是最大值,第二个元素是最大值对应的索引(即预测类别)。
三、运行结果示例
==> Preparing CIFAR10 test data..
Files already downloaded and verified
==> Building ViT model and loading trained weights..
✅ 成功加载训练权重: ./vit_net_30epochs.pth
========== 开始抽样测试(共10个样本) ==========
样本 1 (测试集索引: 1234):
真实类别: cat
预测类别: cat
测试结果: ✅ 正确
样本 2 (测试集索引: 5678):
真实类别: car
预测类别: truck
测试结果: ❌ 错误
...
========== 测试统计 ==========
抽样总数: 10
预测正确: 8
预测错误: 2
抽样准确率: 80.00%
四、扩展与优化建议
- 全量测试 :若需统计完整测试集准确率,可直接使用完整
testset构建DataLoader,批量推理并统计总准确率; - 可视化结果 :结合
matplotlib绘制预测错误的样本图像,直观分析模型薄弱点; - 批量推理 :将
batch_size调大(如 64/128),提升测试效率; - 混淆矩阵:生成 CIFAR10 10x10 混淆矩阵,分析各类别预测效果。
总结
本文详细讲解了基于 ViT 模型测试 CIFAR10 数据集的完整流程,核心要点包括:
- 测试集预处理必须与训练集严格一致,尤其是标准化参数;
- 模型测试时需切换到
eval()模式并禁用梯度计算(torch.no_grad()); - 随机抽样测试可快速验证模型效果,全量测试需统计完整测试集准确率。
掌握以上方法后,你可以轻松将这套测试流程迁移到其他图像分类模型(如 ResNet、CNN 等)和数据集(如 CIFAR100、MNIST)上,快速验证模型性能。