用 PyTorch 搭建 CIFAR10 线性分类器:从数据加载到模型推理全流程解析

在深度学习入门过程中,图像分类是最经典的任务之一,而 CIFAR10 数据集则是入门图像分类的 "练手神器"。

一、前置知识:CIFAR10 数据集是什么?

1.1 CIFAR10 核心参数

CIFAR10(Canadian Institute for Advanced Research 10)是由加拿大高级研究所发布的图像分类数据集,包含10 个类别的彩色图像,具体信息如下:

类别 包含内容 样本数量(训练集) 样本数量(测试集)
0 飞机(airplane) 5000 1000
1 汽车(automobile) 5000 1000
2 鸟类(bird) 5000 1000
3 猫(cat) 5000 1000
4 鹿(deer) 5000 1000
5 狗(dog) 5000 1000
6 青蛙(frog) 5000 1000
7 马(horse) 5000 1000
8 船(ship) 5000 1000
9 卡车(truck) 5000 1000

1.2 图像尺寸与格式

CIFAR10 的每张图像都是3 通道彩色图像 (RGB),尺寸固定为32×32 像素(即高度 32、宽度 32)。

  • 通道数(C):3(R 红、G 绿、B 蓝)
  • 高度(H):32
  • 宽度(W):32
  • 单张图像展平后特征数:3×32×32 = 3072(这是后续线性层输入维度的关键依据)

二、代码拆解:从导入库到数据加载

代码的第一部分是 "数据准备",核心是将 CIFAR10 数据集加载到 PyTorch 中,并按批次处理。我们逐行解析:

2.1 导入必要库

python 复制代码
import torch          # PyTorch核心库(张量操作、自动求导等)
import torchvision    # PyTorch视觉库(数据集、图像变换、预训练模型等)
from torch import nn  # PyTorch神经网络模块(含线性层、卷积层等)
from torch.utils.data import DataLoader  # 数据加载器(按批次加载数据)

这是 PyTorch 视觉任务的 "标准开头",每个库的作用必须明确:

  • torch:所有操作的基础,比如张量(Tensor)的创建和计算。
  • torchvision:专门为计算机视觉设计,提供了 CIFAR10 等常用数据集,以及图像预处理工具。
  • torch.nn:搭建神经网络的核心,比如nn.Linear(线性层)、nn.Conv2d(卷积层)都在这里。
  • DataLoader:将数据集按批次分割,支持多线程加载,是训练时高效喂数据的关键。

2.2 加载 CIFAR10 测试集

python 复制代码
dataset = torchvision.datasets.CIFAR10(
    root='./data',          # 数据集保存路径(当前目录下的data文件夹)
    train=False,            # 是否为训练集:False表示加载测试集,True表示加载训练集
    download=True,          # 如果root路径下没有数据集,是否自动下载
    transform=torchvision.transforms.ToTensor()  # 图像变换:将PIL图像转为Tensor
)

(1)train=False的意义

  • 当train=True时,加载的是50000 张图像的训练集 (用于模型训练);
  • 当train=False时,加载的是10000 张图像的测试集 (用于验证模型性能);
  • 我们这里用测试集做演示,后续实际训练时需要切换为train=True

(2)transform=ToTensor()的作用

图像在计算机中原始存储格式是PIL 图像 (或 numpy 数组),像素值范围是[0, 255](整数),但 PyTorch 模型要求输入是Tensor 格式 ,且像素值归一化到[0, 1](浮点数)。ToTensor()做了两件事:

  1. 将 PIL 图像转为形状为[C, H, W]的 Tensor(注意:PIL 图像默认是[H, W, C],这里会自动转置通道顺序);
  2. 将像素值从[0, 255]除以 255,归一化到[0, 1]

举个例子:一张 PIL 格式的 CIFAR10 图像(32×32×3),经过ToTensor()后会变成[3, 32, 32]的 Tensor,每个元素值在 0~1 之间。

2.3 用 DataLoader 按批次加载数据

python 复制代码
dataloader = DataLoader(dataset, batch_size=64)

DataLoader的核心作用是将dataset(10000 张测试集图像)按batch_size=64分割成多个批次,方便模型批量处理(批量处理能提高计算效率,且符合梯度下降的原理)。

  • 总批次数量:10000 ÷ 64 ≈ 157(最后一个批次不足 64 张,实际为 10000 - 156×64 = 16 张);
  • 每个批次的数据格式:(imgs, targets),其中imgs是图像张量,targets是类别标签张量。

三、核心:搭建线性分类器(Prayer 类)

这部分是神经网络的 "骨架",我们用线性层(全连接层) 搭建一个最简单的分类器,理解模型的输入、输出和前向传播过程。

3.1 类的定义与初始化(__init__方法)

python 复制代码
class Prayer(nn.Module):
    def __init__(self):
        super(Prayer, self).__init__()  # 继承nn.Module的初始化方法
        # 定义线性层:输入维度3072,输出维度10
        self.linear1 = nn.Linear(3072, 10)

这里有三个必须掌握的关键点:

(1)继承nn.Module的意义

nn.Module是 PyTorch 中所有神经网络模块的基类,自定义模型必须继承它。它的核心作用包括:

  • 自动管理模型中的可训练参数(比如线性层的权重和偏置);
  • 支持前向传播(forward方法)和反向传播(自动求导);
  • 提供模型保存、加载、移动到 GPU 等便捷功能。

(2)super(Prayer, self).__init__()的作用

这行代码是 "子类调用父类初始化方法" 的标准写法,目的是让父类nn.Module完成自身的初始化(比如初始化参数列表、计算设备等)。如果不写这行,模型会缺少必要的属性,后续调用时会报错。

(3)线性层nn.Linear(3072, 10)的参数含义

nn.Linear(in_features, out_features)是线性层的定义,本质是实现一个线性变换:y = x × W + b,其中:

  • in_features(输入维度):3072 → 对应 CIFAR10 图像展平后的特征数(3×32×32);
  • out_features(输出维度):10 → 对应 CIFAR10 的 10 个类别(每个输出值代表模型对该类别的 "置信度");
  • 线性层的可训练参数:
    • 权重W:形状为[out_features, in_features] → 这里是[10, 3072]
    • 偏置b:形状为[out_features] → 这里是[10]

3.2 前向传播(forward 方法)

python 复制代码
def forward(self, input):
    output = self.linear1(input)  # 将输入传入线性层,得到输出
    return output

forward方法是模型的 "计算流程",定义了数据如何从输入经过模型层得到输出。在 PyTorch 中,不需要手动调用forward方法 ,只需将模型实例当作函数调用(比如prayer(output)),PyTorch 会自动触发forward方法。

举个例子:如果输入是一个形状为[64, 3072]的张量(64 个样本,每个样本 3072 个特征),经过self.linear1后,输出会是[64, 10]的张量(64 个样本,每个样本 10 个类别置信度)。

四、模型推理:数据流过模型的完整流程

代码的最后一部分是 "模型推理",即让加载好的批次数据通过模型,观察数据形状的变化(这是理解模型是否正确的关键)。

4.1 创建模型实例

python 复制代码
prayer = Prayer()  # 实例化Prayer类,得到模型对象prayer

这行代码会调用Prayer类的__init__方法,创建线性层并初始化权重和偏置(默认是随机初始化)。此时prayer就是一个可使用的线性分类器模型。

4.2 遍历 DataLoader,执行推理

python 复制代码
for data in dataloader:
    imgs, targets = data  # 拆分每个批次的数据:图像张量和标签张量
    print("原始图像形状:", imgs.shape)  # 打印原始图像形状
    # 展平操作:从第1维开始展平,保留批次维度
    output = torch.flatten(imgs, start_dim=1)
    print("展平后形状:", output.shape)  # 打印展平后形状
    output = prayer(output)  # 将展平后的特征传入模型,得到输出
    print("模型输出形状:", output.shape)  # 打印模型输出形状

我们逐句解析,并结合可视化图表理解数据形状的变化:

(1)原始图像形状:imgs.shape

每个批次的imgs是一个 4 维张量,形状为[batch_size, C, H, W]

  • 当batch_size=64时,形状为[64, 3, 32, 32];
  • 含义:64 张图像,每张图像 3 个通道,每个通道 32×32 像素。

可视化如下(用简化的维度图表示):

python 复制代码
原始图像张量:[64(批次), 3(通道), 32(高度), 32(宽度)]
├─ 第1张图:[3, 32, 32]
├─ 第2张图:[3, 32, 32]
├─ ...
└─ 第64张图:[3, 32, 32]

2)展平操作:torch.flatten(imgs, start_dim=1)

线性层nn.Linear要求输入是2 维张量[batch_size, in_features]),而原始imgs是 4 维张量,因此需要用torch.flatten将其展平(只保留批次维度,将通道、高度、宽度合并为 "特征维度")。

  • start_dim=1:表示从第 1 个维度(通道维度)开始展平,第 0 个维度(批次维度)保持不变;
  • 展平后形状:[64, 3×32×32] = [64, 3072]。

可视化展平过程:

python 复制代码
原始形状:[64, 3, 32, 32]
          ↓ 展平维度1~3(3×32×32=3072)
展平后形状:[64, 3072]
├─ 第1个样本:[3072个特征值](R通道32×32 + G通道32×32 + B通道32×32)
├─ 第2个样本:[3072个特征值]
├─ ...
└─ 第64个样本:[3072个特征值]

(3)模型输出形状:prayer(output).shape

将展平后的[64, 3072]张量传入模型,经过线性层nn.Linear(3072, 10)变换后,输出形状为[64, 10]

  • 含义:64 个样本,每个样本对应 10 个数值(分别代表模型对 10 个类别的置信度);
  • 后续步骤(未在代码中体现):通过torch.argmax(output, dim=1)取每个样本置信度最大的索引,即为模型预测的类别。

可视化模型输入输出:

python 复制代码
模型输入(展平后):[64, 3072]
          ↓ 经过线性变换 y = x×W + b(W: [10,3072], b: [10])
模型输出:[64, 10]
├─ 第1个样本:[置信度0, 置信度1, ..., 置信度9] → 预测类别=置信度最大的索引
├─ 第2个样本:[置信度0, 置信度1, ..., 置信度9]
├─ ...
└─ 第64个样本:[置信度0, 置信度1, ..., 置信度9]

4.3 实际运行输出结果

当你运行代码时,会看到如下输出(前两个批次为例):

python 复制代码
原始图像形状: torch.Size([64, 3, 32, 32])
展平后形状: torch.Size([64, 3072])
模型输出形状: torch.Size([64, 10])
原始图像形状: torch.Size([64, 3, 32, 32])
展平后形状: torch.Size([64, 3072])
模型输出形状: torch.Size([64, 10])
...
# 最后一个批次(不足64张)
原始图像形状: torch.Size([16, 3, 32, 32])
展平后形状: torch.Size([16, 3072])
模型输出形状: torch.Size([16, 10])

这个结果验证了模型和数据处理的正确性:每个批次的输入都能顺利通过模型,输出形状符合预期。

五、常见问题与拓展:让代码更完整

虽然当前代码能正常运行,但它只是 "推理流程",实际深度学习项目还需要训练、损失计算、评估等步骤。

5.1 为什么线性层输入维度不能是 196608?

在之前的错误中,曾将线性层输入维度设为 196608,导致RuntimeError: mat1 and mat2 shapes cannot be multiplied。原因是:

  • 196608 = 64×3×32×32 → 这是整个批次所有像素的总数(包含了批次维度);
  • 线性层需要的是单个样本的特征数(3072),而不是整个批次的总像素数;
  • 记住:线性层输入维度 = 单样本特征数,与批次大小无关。

5.2 如何添加训练逻辑?

当前代码只有推理,要让模型能学习,需要添加损失函数、优化器和训练循环:

python 复制代码
# 1. 定义损失函数(分类任务常用交叉熵损失)
loss_fn = nn.CrossEntropyLoss()
# 2. 定义优化器(常用Adam优化器,学习率0.001)
optimizer = torch.optim.Adam(prayer.parameters(), lr=0.001)
# 3. 训练循环(以10轮训练为例)
epochs = 10
for epoch in range(epochs):
    running_loss = 0.0  # 记录每轮的总损失
    prayer.train()  # 切换模型为训练模式(启用 dropout、批量归一化等训练特有的操作)
    for data in dataloader:
        imgs, targets = data
        
        # 步骤1:前向传播(数据过模型)
        output = torch.flatten(imgs, start_dim=1)
        pred = prayer(output)
        
        # 步骤2:计算损失(预测值与真实标签的差距)
        loss = loss_fn(pred, targets)
        
        # 步骤3:反向传播(计算梯度)
        optimizer.zero_grad()  # 清空上一轮的梯度(避免梯度累积)
        loss.backward()  # 从损失值反向计算各参数的梯度
        
        # 步骤4:参数更新(用梯度优化器更新模型权重)
        optimizer.step()
        
        # 累加损失值(用于打印日志)
        running_loss += loss.item() * imgs.size(0)  # loss.item()是单批次损失,乘以批次大小得到总损失
    
    # 计算每轮的平均损失
    epoch_loss = running_loss / len(dataset)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")

5.3 结果可视化:用 Matplotlib 展示预测效果​

为了更直观地理解模型的预测结果,我们可以用 Matplotlib 绘制 "图像 - 真实标签 - 预测标签" 的对应图,

六、完整代码

python 复制代码
import matplotlib.pyplot as plt
import numpy as np

# CIFAR10类别名称(与索引0-9对应)
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

def show_predictions(model, dataloader, num_images=5):
    model.eval()
    with torch.no_grad():
        # 取第一个批次的数据
        data_iter = iter(dataloader)
        imgs, targets = next(data_iter)
        
        # 前向传播得到预测结果
        output = torch.flatten(imgs, start_dim=1)
        pred = model(output)
        _, predicted = torch.max(pred, dim=1)
        
        # 转换图像格式(从[C, H, W]转为[H, W, C],方便Matplotlib显示)
        imgs = imgs.permute(0, 2, 3, 1).numpy()  # permute调整维度顺序
        imgs = imgs * 255  # 从[0,1]反归一化到[0,255](Matplotlib需要整数像素值)
        imgs = imgs.astype(np.uint8)  # 转为整数类型
        
        # 绘制图像
        plt.figure(figsize=(12, 4))
        for i in range(num_images):
            plt.subplot(1, num_images, i+1)
            plt.imshow(imgs[i])
            # 标题格式:真实标签 -> 预测标签(正确标绿,错误标红)
            true_label = classes[targets[i]]
            pred_label = classes[predicted[i]]
            color = 'green' if true_label == pred_label else 'red'
            plt.title(f"True: {true_label}\nPred: {pred_label}", color=color)
            plt.axis('off')  # 隐藏坐标轴
        plt.show()

# 调用可视化函数
show_predictions(prayer, dataloader, num_images=5)

可视化效果说明:​

运行代码后,会显示 5 张 CIFAR10 测试集图像,每张图像下方标注 "真实类别" 和 "预测类别":

  • 若预测正确,标题为绿色;
  • 若预测错误,标题为红色。

例如:​

  • 真实类别是 "cat",预测类别也是 "cat" → 绿色标题;
  • 真实类别是 "dog",预测类别是 "cat" → 红色标题。

通过可视化,你可以快速发现模型擅长预测哪些类别(如 "airplane""ship" 这类轮廓清晰的类别),以及容易混淆的类别(如 "cat" 和 "dog" 这类细节相似的类别)。​

七、常见问题与解决方案(FAQ)​

在实际运行代码时,你可能会遇到以下问题,这里提前给出解决方案:​

|-----------------------------------|--------------------------|-------------------------------------------------------------------------------------------|
| 常见问题​ | 错误原因​ | 解决方案​ |
| RuntimeError: CUDA out of memory​ | 显卡内存不足(模型或批次太大)​ | 1. 减小batch_size(如从 64 改为 32、16);2. 使用torch.cuda.empty_cache()清空缓存;3. 改用 CPU 训练(速度慢但不占显存)​ |
| 训练损失不下降,准确率始终 10% 左右​ | 模型未学习(可能是梯度消失或学习率不合适)​ | 1. 调整学习率(如从 0.001 改为 0.01 或 0.0001);2. 检查数据预处理是否正确(如是否忘记归一化);3. 增加训练轮次(epochs)​ |
| 评估准确率远低于训练准确率​ | 模型过拟合(在训练集上表现好,测试集上表现差)​ | 1. 增加训练数据(如数据增强,见下文拓展);2. 减少模型复杂度(如线性层改为更简单的结构);3. 添加正则化(如 L2 正则化)​ |

八、 进阶拓展:数据增强提升模型性能​

当前代码使用的是原始 CIFAR10 图像,若想进一步提升模型准确率,可以添加数据增强(通过随机变换图像,增加训练数据的多样性,减少过拟合)。修改数据加载代码如下:

python 复制代码
# 定义数据增强变换(训练集用增强,测试集不用)
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),  # 随机裁剪( padding=4表示先填充4像素,再裁剪32×32)
    torchvision.transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转(50%概率)
    torchvision.transforms.ToTensor()  # 转为Tensor并归一化
])

test_transform = torchvision.transforms.ToTensor()  # 测试集只做归一化,不做增强

# 加载训练集(用增强变换)
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

# 加载测试集(不用增强)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform
)

# 创建DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # 训练集打乱顺序
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)   # 测试集不打乱

数据增强的核心作用:​

  • 让模型看到更多 "变种" 图像(如裁剪后的局部图像、翻转后的图像);
  • 避免模型过度依赖图像的固定位置或方向(如只认识向左的猫,不认识向右的猫);
  • 通常能将 CIFAR10 线性分类器的准确率提升 10%-15%。
相关推荐
程序员杰哥3 小时前
UI自动化测试实战:从入门到精通
自动化测试·软件测试·python·selenium·测试工具·ui·职场和发展
SunnyRivers3 小时前
通俗易懂理解python yield
python
mortimer3 小时前
Python 进阶:彻底理解类属性、类方法与静态方法
后端·python
碱化钾3 小时前
Lipschitz连续及其常量
人工智能·机器学习
两万五千个小时4 小时前
LangChain 入门教程:06LangGraph工作流编排
人工智能·后端
渡我白衣4 小时前
深度学习进阶(六)——世界模型与具身智能:AI的下一次跃迁
人工智能·深度学习
人工智能技术咨询.4 小时前
【无标题】
人工智能·深度学习·transformer
云卓SKYDROID4 小时前
无人机激光避障技术概述
人工智能·无人机·航电系统·高科技·云卓科技
蜉蝣之翼❉4 小时前
图像处理之浓度(AI 调研)
图像处理·人工智能·机器学习