【深度学习保姆级教程】ViT 模型测试 CIFAR10 数据集:从权重加载到抽样验证全流程(附上资源)

前言

CIFAR10 作为计算机视觉领域经典的图像分类数据集,常被用于验证各类深度学习模型的性能。本文将基于 Vision Transformer(ViT)模型,详细讲解如何加载预训练权重、对 CIFAR10 测试集进行抽样测试,并完整解析测试流程中的核心代码逻辑,帮助大家掌握模型测试的关键步骤。

一、测试流程整体设计

本次测试的核心目标是:加载训练好的 ViT 权重文件,从 CIFAR10 测试集中随机抽取指定数量样本,验证模型预测效果并统计准确率。整体流程可分为 3 个核心步骤:

  1. 数据预处理:加载 CIFAR10 测试集并完成标准化等预处理;
  2. 模型加载:构建 ViT 模型结构并加载预训练权重;
  3. 抽样测试:随机抽取样本进行预测,输出详细结果并统计准确率。

二、完整代码解析

2.1 环境与依赖准备

首先确保安装以下核心依赖:

复制代码
pip install torch torchvision numpy

核心依赖说明:

  • torch/torchvision:PyTorch 框架及视觉工具库,用于模型构建和数据加载;
  • numpy:数值计算基础库;
  • cudnn:GPU 加速库(可选,有 GPU 时启用)。

2.2 完整测试代码

复制代码
# -*- coding: utf-8 -*-
'''
ViT测试代码:加载已保存的权重文件,对CIFAR10测试集进行抽样测试并打印结果
'''
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random

# ===================== 全局常量配置 =====================
batch_size = 128   # 批次大小(测试时实际用1,此处为配置保留)
img_size = 32      # 输入图像尺寸(CIFAR10原生尺寸)
patch = 4          # ViT图像分块大小
dimhead = 512      # ViT特征维度
num_classes = 10   # CIFAR10类别数
WEIGHT_PATH = './vit_net_30epochs.pth'  # 预训练权重路径
SAMPLE_NUM = 10    # 抽样测试样本数
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # 设备选择

# ===================== 主函数 =====================
def main():
    # 1. 加载数据集(仅加载测试集)
    print('==> Preparing CIFAR10 test data..')
    # CIFAR10官方均值和标准差(标准化关键参数)
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    # 测试集变换(必须与训练时保持一致)
    transform_test = transforms.Compose([
        transforms.Resize(img_size),  # 调整尺寸(此处与原生一致,可省略)
        transforms.ToTensor(),        # 转为Tensor并归一化到[0,1]
        transforms.Normalize(mean, std),  # 标准化:(x-mean)/std
    ])

    # 加载CIFAR10测试集
    testset = torchvision.datasets.CIFAR10(
        root='./data',        # 数据集保存路径
        train=False,          # 加载测试集
        download=True,        # 无数据时自动下载
        transform=transform_test  # 应用预处理
    )
    # 随机抽样:生成测试集索引并随机选择指定数量
    test_indices = list(range(len(testset)))
    sample_indices = random.sample(test_indices, SAMPLE_NUM)
    
    # 构建抽样数据集和数据加载器
    sample_dataset = torch.utils.data.Subset(testset, sample_indices)
    sample_loader = torch.utils.data.DataLoader(
        sample_dataset, 
        batch_size=1,  # 逐个处理,方便打印单样本详细信息
        shuffle=False
    )

    # CIFAR10类别名称(对应标签0-9)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck')

    # 2. 构建模型 + 加载训练好的权重
    print('==> Building ViT model and loading trained weights..')
    # 导入自定义ViT模型(需确保models.vit.py文件存在)
    from models.vit import ViT
    net = ViT(
        image_size=img_size,
        patch_size=patch,
        num_classes=num_classes,
        dim=dimhead,
        depth=6,
        heads=8,
        mlp_dim=512,
        dropout=0.1,
        emb_dropout=0.1
    )
    # 将模型移至指定设备(GPU/CPU)
    net = net.to(device)
    # GPU加速优化(可选)
    if device == 'cuda':
        cudnn.benchmark = True

    # 加载预训练权重
    checkpoint = torch.load(WEIGHT_PATH, map_location=device)
    net.load_state_dict(checkpoint)
    print(f'✅ 成功加载训练权重: {WEIGHT_PATH}')

    # 3. 抽样测试并打印结果
    print(f'\n========== 开始抽样测试(共{len(sample_loader)}个样本) ==========')
    net.eval()  # 关键:设置模型为评估模式(禁用Dropout/BatchNorm训练行为)
    correct_count = 0
    
    with torch.no_grad():  # 关键:禁用梯度计算,节省内存+加速推理
        for idx, (inputs, targets) in enumerate(sample_loader):
            # 获取样本在原测试集中的索引
            original_idx = sample_indices[idx]
            
            # 将数据移至指定设备
            inputs, targets = inputs.to(device), targets.to(device)
            # 模型推理:前向传播得到预测输出
            outputs = net(inputs)
            
            # 获取预测类别(取概率最大的类别)
            _, predicted = outputs.max(1)
            
            # 转换为类别名称(方便阅读)
            target_class = classes[targets.item()]
            predicted_class = classes[predicted.item()]
            
            # 判断预测是否正确
            is_correct = target_class == predicted_class
            if is_correct:
                correct_count += 1
                result_mark = "✅ 正确"
            else:
                result_mark = "❌ 错误"
            
            # 打印单样本详细结果
            print(f'\n样本 {idx+1} (测试集索引: {original_idx}):')
            print(f'  真实类别: {target_class}')
            print(f'  预测类别: {predicted_class}')
            print(f'  测试结果: {result_mark}')
    
    # 打印测试统计结果
    print(f'\n========== 测试统计 ==========')
    print(f'抽样总数: {SAMPLE_NUM}')
    print(f'预测正确: {correct_count}')
    print(f'预测错误: {SAMPLE_NUM - correct_count}')
    print(f'抽样准确率: {correct_count/SAMPLE_NUM*100:.2f}%')

# ===================== 启动 =====================
if __name__ == '__main__':
    from multiprocessing import freeze_support
    freeze_support()  # Windows系统多进程兼容(可选)
    main()

2.3 核心代码详解

(1)数据预处理关键
复制代码
transform_test = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
  • ToTensor():将 PIL 图像转为torch.Tensor,并将像素值从[0,255]归一化到[0,1]
  • Normalize(mean, std):使用 CIFAR10 官方统计的均值和标准差进行标准化,必须与训练时保持一致,否则会导致模型预测失效;
  • Resize(img_size):此处 CIFAR10 原生尺寸就是 32x32,可省略,但若训练时调整过尺寸,测试时必须同步。
(2)随机抽样实现
复制代码
test_indices = list(range(len(testset)))
sample_indices = random.sample(test_indices, SAMPLE_NUM)
sample_dataset = torch.utils.data.Subset(testset, sample_indices)
  • random.sample():从测试集索引中随机选择指定数量样本,避免重复;
  • torch.utils.data.Subset:基于索引构建子集数据集,方便后续加载。
(3)模型加载与评估模式
复制代码
net.eval()  # 评估模式
with torch.no_grad():  # 禁用梯度
    # 推理代码
  • net.eval()核心操作,将模型中的 Dropout、BatchNorm 等层切换为评估模式,避免影响预测结果;
  • torch.no_grad():禁用梯度计算,大幅减少内存占用并提升推理速度,测试 / 推理阶段必须使用。
(4)预测结果解析
复制代码
_, predicted = outputs.max(1)
  • outputs是模型输出的形状为[1, 10]的张量(对应 10 个类别的预测得分);
  • max(1)表示在维度 1(类别维度)上取最大值,返回值中第一个元素是最大值,第二个元素是最大值对应的索引(即预测类别)。

三、运行结果示例

复制代码
==> Preparing CIFAR10 test data..
Files already downloaded and verified
==> Building ViT model and loading trained weights..
✅ 成功加载训练权重: ./vit_net_30epochs.pth

========== 开始抽样测试(共10个样本) ==========

样本 1 (测试集索引: 1234):
  真实类别: cat
  预测类别: cat
  测试结果: ✅ 正确

样本 2 (测试集索引: 5678):
  真实类别: car
  预测类别: truck
  测试结果: ❌ 错误

...

========== 测试统计 ==========
抽样总数: 10
预测正确: 8
预测错误: 2
抽样准确率: 80.00%

四、扩展与优化建议

  1. 全量测试 :若需统计完整测试集准确率,可直接使用完整testset构建DataLoader,批量推理并统计总准确率;
  2. 可视化结果 :结合matplotlib绘制预测错误的样本图像,直观分析模型薄弱点;
  3. 批量推理 :将batch_size调大(如 64/128),提升测试效率;
  4. 混淆矩阵:生成 CIFAR10 10x10 混淆矩阵,分析各类别预测效果。

总结

本文详细讲解了基于 ViT 模型测试 CIFAR10 数据集的完整流程,核心要点包括:

  1. 测试集预处理必须与训练集严格一致,尤其是标准化参数;
  2. 模型测试时需切换到eval()模式并禁用梯度计算(torch.no_grad());
  3. 随机抽样测试可快速验证模型效果,全量测试需统计完整测试集准确率。

掌握以上方法后,你可以轻松将这套测试流程迁移到其他图像分类模型(如 ResNet、CNN 等)和数据集(如 CIFAR100、MNIST)上,快速验证模型性能。

相关推荐
思考的小屋2 小时前
Transformer001 介绍激活函数
人工智能
福客AI智能客服2 小时前
跨渠协同赋能:AI智能客服重构电商客服系统服务生态
大数据·人工智能
白日做梦Q2 小时前
EfficientNet解析:用复合缩放统一CNN架构
人工智能·架构·cnn
红尘炼丹客2 小时前
DeepSeek 新作 mHC 解读:用流形约束(Manifold Constraints)重构大模型残差连接
人工智能·深度学习·大模型·mhc
70asunflower2 小时前
RL(强化学习,Reinforcement Learning)
人工智能·机器学习
盼小辉丶2 小时前
Transformer实战(34)——多语言和跨语言Transformer模型
深度学习·语言模型·transformer
Σίσυφος19002 小时前
张正友标定法原理总结2
人工智能·数码相机·计算机视觉
乾元2 小时前
兵器谱——深度学习、强化学习与 NLP 在安全中的典型应用场景
运维·网络·人工智能·深度学习·安全·自然语言处理·自动化
张祥6422889042 小时前
GNSS单点定位方程推导笔记
人工智能·算法·机器学习