在深度学习入门过程中,图像分类是最经典的任务之一,而 CIFAR10 数据集则是入门图像分类的 "练手神器"。
一、前置知识:CIFAR10 数据集是什么?
1.1 CIFAR10 核心参数
CIFAR10(Canadian Institute for Advanced Research 10)是由加拿大高级研究所发布的图像分类数据集,包含10 个类别的彩色图像,具体信息如下:
类别 | 包含内容 | 样本数量(训练集) | 样本数量(测试集) |
---|---|---|---|
0 | 飞机(airplane) | 5000 | 1000 |
1 | 汽车(automobile) | 5000 | 1000 |
2 | 鸟类(bird) | 5000 | 1000 |
3 | 猫(cat) | 5000 | 1000 |
4 | 鹿(deer) | 5000 | 1000 |
5 | 狗(dog) | 5000 | 1000 |
6 | 青蛙(frog) | 5000 | 1000 |
7 | 马(horse) | 5000 | 1000 |
8 | 船(ship) | 5000 | 1000 |
9 | 卡车(truck) | 5000 | 1000 |
1.2 图像尺寸与格式
CIFAR10 的每张图像都是3 通道彩色图像 (RGB),尺寸固定为32×32 像素(即高度 32、宽度 32)。
- 通道数(C):3(R 红、G 绿、B 蓝)
- 高度(H):32
- 宽度(W):32
- 单张图像展平后特征数:3×32×32 = 3072(这是后续线性层输入维度的关键依据)
二、代码拆解:从导入库到数据加载
代码的第一部分是 "数据准备",核心是将 CIFAR10 数据集加载到 PyTorch 中,并按批次处理。我们逐行解析:
2.1 导入必要库
python
import torch # PyTorch核心库(张量操作、自动求导等)
import torchvision # PyTorch视觉库(数据集、图像变换、预训练模型等)
from torch import nn # PyTorch神经网络模块(含线性层、卷积层等)
from torch.utils.data import DataLoader # 数据加载器(按批次加载数据)
这是 PyTorch 视觉任务的 "标准开头",每个库的作用必须明确:
torch
:所有操作的基础,比如张量(Tensor)的创建和计算。torchvision
:专门为计算机视觉设计,提供了 CIFAR10 等常用数据集,以及图像预处理工具。torch.nn
:搭建神经网络的核心,比如nn.Linear
(线性层)、nn.Conv2d
(卷积层)都在这里。DataLoader
:将数据集按批次分割,支持多线程加载,是训练时高效喂数据的关键。
2.2 加载 CIFAR10 测试集
python
dataset = torchvision.datasets.CIFAR10(
root='./data', # 数据集保存路径(当前目录下的data文件夹)
train=False, # 是否为训练集:False表示加载测试集,True表示加载训练集
download=True, # 如果root路径下没有数据集,是否自动下载
transform=torchvision.transforms.ToTensor() # 图像变换:将PIL图像转为Tensor
)
(1)train=False
的意义
- 当train=True时,加载的是50000 张图像的训练集 (用于模型训练);
- 当train=False时,加载的是10000 张图像的测试集 (用于验证模型性能);
- 我们这里用测试集做演示,后续实际训练时需要切换为
train=True
。
(2)transform=ToTensor()
的作用
图像在计算机中原始存储格式是PIL 图像 (或 numpy 数组),像素值范围是[0, 255]
(整数),但 PyTorch 模型要求输入是Tensor 格式 ,且像素值归一化到[0, 1]
(浮点数)。ToTensor()
做了两件事:
- 将 PIL 图像转为形状为
[C, H, W]
的 Tensor(注意:PIL 图像默认是[H, W, C]
,这里会自动转置通道顺序); - 将像素值从
[0, 255]
除以 255,归一化到[0, 1]
。
举个例子:一张 PIL 格式的 CIFAR10 图像(32×32×3),经过ToTensor()
后会变成[3, 32, 32]
的 Tensor,每个元素值在 0~1 之间。
2.3 用 DataLoader 按批次加载数据
python
dataloader = DataLoader(dataset, batch_size=64)
DataLoader的核心作用是将dataset(10000 张测试集图像)按batch_size=64分割成多个批次,方便模型批量处理(批量处理能提高计算效率,且符合梯度下降的原理)。
- 总批次数量:10000 ÷ 64 ≈ 157(最后一个批次不足 64 张,实际为 10000 - 156×64 = 16 张);
- 每个批次的数据格式:
(imgs, targets)
,其中imgs
是图像张量,targets
是类别标签张量。
三、核心:搭建线性分类器(Prayer 类)
这部分是神经网络的 "骨架",我们用线性层(全连接层) 搭建一个最简单的分类器,理解模型的输入、输出和前向传播过程。
3.1 类的定义与初始化(__init__方法)
python
class Prayer(nn.Module):
def __init__(self):
super(Prayer, self).__init__() # 继承nn.Module的初始化方法
# 定义线性层:输入维度3072,输出维度10
self.linear1 = nn.Linear(3072, 10)
这里有三个必须掌握的关键点:
(1)继承nn.Module
的意义
nn.Module
是 PyTorch 中所有神经网络模块的基类,自定义模型必须继承它。它的核心作用包括:
- 自动管理模型中的可训练参数(比如线性层的权重和偏置);
- 支持前向传播(
forward
方法)和反向传播(自动求导); - 提供模型保存、加载、移动到 GPU 等便捷功能。
(2)super(Prayer, self).__init__()
的作用
这行代码是 "子类调用父类初始化方法" 的标准写法,目的是让父类nn.Module完成自身的初始化(比如初始化参数列表、计算设备等)。如果不写这行,模型会缺少必要的属性,后续调用时会报错。
(3)线性层nn.Linear(3072, 10)
的参数含义
nn.Linear(in_features, out_features)
是线性层的定义,本质是实现一个线性变换:y = x × W + b
,其中:
in_features
(输入维度):3072 → 对应 CIFAR10 图像展平后的特征数(3×32×32);out_features
(输出维度):10 → 对应 CIFAR10 的 10 个类别(每个输出值代表模型对该类别的 "置信度");- 线性层的可训练参数:
- 权重
W
:形状为[out_features, in_features]
→ 这里是[10, 3072]
; - 偏置
b
:形状为[out_features]
→ 这里是[10]
。
- 权重
3.2 前向传播(forward 方法)
python
def forward(self, input):
output = self.linear1(input) # 将输入传入线性层,得到输出
return output
forward
方法是模型的 "计算流程",定义了数据如何从输入经过模型层得到输出。在 PyTorch 中,不需要手动调用forward方法 ,只需将模型实例当作函数调用(比如prayer(output)),PyTorch 会自动触发forward方法。
举个例子:如果输入是一个形状为[64, 3072]
的张量(64 个样本,每个样本 3072 个特征),经过self.linear1
后,输出会是[64, 10]
的张量(64 个样本,每个样本 10 个类别置信度)。
四、模型推理:数据流过模型的完整流程
代码的最后一部分是 "模型推理",即让加载好的批次数据通过模型,观察数据形状的变化(这是理解模型是否正确的关键)。
4.1 创建模型实例
python
prayer = Prayer() # 实例化Prayer类,得到模型对象prayer
这行代码会调用Prayer
类的__init__
方法,创建线性层并初始化权重和偏置(默认是随机初始化)。此时prayer
就是一个可使用的线性分类器模型。
4.2 遍历 DataLoader,执行推理
python
for data in dataloader:
imgs, targets = data # 拆分每个批次的数据:图像张量和标签张量
print("原始图像形状:", imgs.shape) # 打印原始图像形状
# 展平操作:从第1维开始展平,保留批次维度
output = torch.flatten(imgs, start_dim=1)
print("展平后形状:", output.shape) # 打印展平后形状
output = prayer(output) # 将展平后的特征传入模型,得到输出
print("模型输出形状:", output.shape) # 打印模型输出形状
我们逐句解析,并结合可视化图表理解数据形状的变化:
(1)原始图像形状:imgs.shape
每个批次的imgs
是一个 4 维张量,形状为[batch_size, C, H, W]
。
- 当batch_size=64时,形状为[64, 3, 32, 32];
- 含义:64 张图像,每张图像 3 个通道,每个通道 32×32 像素。
可视化如下(用简化的维度图表示):
python
原始图像张量:[64(批次), 3(通道), 32(高度), 32(宽度)]
├─ 第1张图:[3, 32, 32]
├─ 第2张图:[3, 32, 32]
├─ ...
└─ 第64张图:[3, 32, 32]
2)展平操作:torch.flatten(imgs, start_dim=1)
线性层nn.Linear
要求输入是2 维张量 ([batch_size, in_features]
),而原始imgs
是 4 维张量,因此需要用torch.flatten将其展平(只保留批次维度,将通道、高度、宽度合并为 "特征维度")。
start_dim=1
:表示从第 1 个维度(通道维度)开始展平,第 0 个维度(批次维度)保持不变;- 展平后形状:[64, 3×32×32] = [64, 3072]。
可视化展平过程:
python
原始形状:[64, 3, 32, 32]
↓ 展平维度1~3(3×32×32=3072)
展平后形状:[64, 3072]
├─ 第1个样本:[3072个特征值](R通道32×32 + G通道32×32 + B通道32×32)
├─ 第2个样本:[3072个特征值]
├─ ...
└─ 第64个样本:[3072个特征值]
(3)模型输出形状:prayer(output).shape
将展平后的[64, 3072]
张量传入模型,经过线性层nn.Linear(3072, 10)
变换后,输出形状为[64, 10]
。
- 含义:64 个样本,每个样本对应 10 个数值(分别代表模型对 10 个类别的置信度);
- 后续步骤(未在代码中体现):通过
torch.argmax(output, dim=1)
取每个样本置信度最大的索引,即为模型预测的类别。
可视化模型输入输出:
python
模型输入(展平后):[64, 3072]
↓ 经过线性变换 y = x×W + b(W: [10,3072], b: [10])
模型输出:[64, 10]
├─ 第1个样本:[置信度0, 置信度1, ..., 置信度9] → 预测类别=置信度最大的索引
├─ 第2个样本:[置信度0, 置信度1, ..., 置信度9]
├─ ...
└─ 第64个样本:[置信度0, 置信度1, ..., 置信度9]
4.3 实际运行输出结果
当你运行代码时,会看到如下输出(前两个批次为例):
python
原始图像形状: torch.Size([64, 3, 32, 32])
展平后形状: torch.Size([64, 3072])
模型输出形状: torch.Size([64, 10])
原始图像形状: torch.Size([64, 3, 32, 32])
展平后形状: torch.Size([64, 3072])
模型输出形状: torch.Size([64, 10])
...
# 最后一个批次(不足64张)
原始图像形状: torch.Size([16, 3, 32, 32])
展平后形状: torch.Size([16, 3072])
模型输出形状: torch.Size([16, 10])
这个结果验证了模型和数据处理的正确性:每个批次的输入都能顺利通过模型,输出形状符合预期。
五、常见问题与拓展:让代码更完整
虽然当前代码能正常运行,但它只是 "推理流程",实际深度学习项目还需要训练、损失计算、评估等步骤。
5.1 为什么线性层输入维度不能是 196608?
在之前的错误中,曾将线性层输入维度设为 196608,导致RuntimeError: mat1 and mat2 shapes cannot be multiplied
。原因是:
- 196608 = 64×3×32×32 → 这是整个批次所有像素的总数(包含了批次维度);
- 线性层需要的是单个样本的特征数(3072),而不是整个批次的总像素数;
- 记住:线性层输入维度 = 单样本特征数,与批次大小无关。
5.2 如何添加训练逻辑?
当前代码只有推理,要让模型能学习,需要添加损失函数、优化器和训练循环:
python
# 1. 定义损失函数(分类任务常用交叉熵损失)
loss_fn = nn.CrossEntropyLoss()
# 2. 定义优化器(常用Adam优化器,学习率0.001)
optimizer = torch.optim.Adam(prayer.parameters(), lr=0.001)
# 3. 训练循环(以10轮训练为例)
epochs = 10
for epoch in range(epochs):
running_loss = 0.0 # 记录每轮的总损失
prayer.train() # 切换模型为训练模式(启用 dropout、批量归一化等训练特有的操作)
for data in dataloader:
imgs, targets = data
# 步骤1:前向传播(数据过模型)
output = torch.flatten(imgs, start_dim=1)
pred = prayer(output)
# 步骤2:计算损失(预测值与真实标签的差距)
loss = loss_fn(pred, targets)
# 步骤3:反向传播(计算梯度)
optimizer.zero_grad() # 清空上一轮的梯度(避免梯度累积)
loss.backward() # 从损失值反向计算各参数的梯度
# 步骤4:参数更新(用梯度优化器更新模型权重)
optimizer.step()
# 累加损失值(用于打印日志)
running_loss += loss.item() * imgs.size(0) # loss.item()是单批次损失,乘以批次大小得到总损失
# 计算每轮的平均损失
epoch_loss = running_loss / len(dataset)
print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")
5.3 结果可视化:用 Matplotlib 展示预测效果
为了更直观地理解模型的预测结果,我们可以用 Matplotlib 绘制 "图像 - 真实标签 - 预测标签" 的对应图,

六、完整代码
python
import matplotlib.pyplot as plt
import numpy as np
# CIFAR10类别名称(与索引0-9对应)
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
def show_predictions(model, dataloader, num_images=5):
model.eval()
with torch.no_grad():
# 取第一个批次的数据
data_iter = iter(dataloader)
imgs, targets = next(data_iter)
# 前向传播得到预测结果
output = torch.flatten(imgs, start_dim=1)
pred = model(output)
_, predicted = torch.max(pred, dim=1)
# 转换图像格式(从[C, H, W]转为[H, W, C],方便Matplotlib显示)
imgs = imgs.permute(0, 2, 3, 1).numpy() # permute调整维度顺序
imgs = imgs * 255 # 从[0,1]反归一化到[0,255](Matplotlib需要整数像素值)
imgs = imgs.astype(np.uint8) # 转为整数类型
# 绘制图像
plt.figure(figsize=(12, 4))
for i in range(num_images):
plt.subplot(1, num_images, i+1)
plt.imshow(imgs[i])
# 标题格式:真实标签 -> 预测标签(正确标绿,错误标红)
true_label = classes[targets[i]]
pred_label = classes[predicted[i]]
color = 'green' if true_label == pred_label else 'red'
plt.title(f"True: {true_label}\nPred: {pred_label}", color=color)
plt.axis('off') # 隐藏坐标轴
plt.show()
# 调用可视化函数
show_predictions(prayer, dataloader, num_images=5)
可视化效果说明:
运行代码后,会显示 5 张 CIFAR10 测试集图像,每张图像下方标注 "真实类别" 和 "预测类别":
- 若预测正确,标题为绿色;
- 若预测错误,标题为红色。
例如:
- 真实类别是 "cat",预测类别也是 "cat" → 绿色标题;
- 真实类别是 "dog",预测类别是 "cat" → 红色标题。
通过可视化,你可以快速发现模型擅长预测哪些类别(如 "airplane""ship" 这类轮廓清晰的类别),以及容易混淆的类别(如 "cat" 和 "dog" 这类细节相似的类别)。
七、常见问题与解决方案(FAQ)
在实际运行代码时,你可能会遇到以下问题,这里提前给出解决方案:
|-----------------------------------|--------------------------|-------------------------------------------------------------------------------------------|
| 常见问题 | 错误原因 | 解决方案 |
| RuntimeError: CUDA out of memory | 显卡内存不足(模型或批次太大) | 1. 减小batch_size(如从 64 改为 32、16);2. 使用torch.cuda.empty_cache()清空缓存;3. 改用 CPU 训练(速度慢但不占显存) |
| 训练损失不下降,准确率始终 10% 左右 | 模型未学习(可能是梯度消失或学习率不合适) | 1. 调整学习率(如从 0.001 改为 0.01 或 0.0001);2. 检查数据预处理是否正确(如是否忘记归一化);3. 增加训练轮次(epochs) |
| 评估准确率远低于训练准确率 | 模型过拟合(在训练集上表现好,测试集上表现差) | 1. 增加训练数据(如数据增强,见下文拓展);2. 减少模型复杂度(如线性层改为更简单的结构);3. 添加正则化(如 L2 正则化) |
八、 进阶拓展:数据增强提升模型性能
当前代码使用的是原始 CIFAR10 图像,若想进一步提升模型准确率,可以添加数据增强(通过随机变换图像,增加训练数据的多样性,减少过拟合)。修改数据加载代码如下:
python
# 定义数据增强变换(训练集用增强,测试集不用)
train_transform = torchvision.transforms.Compose([
torchvision.transforms.RandomCrop(32, padding=4), # 随机裁剪( padding=4表示先填充4像素,再裁剪32×32)
torchvision.transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转(50%概率)
torchvision.transforms.ToTensor() # 转为Tensor并归一化
])
test_transform = torchvision.transforms.ToTensor() # 测试集只做归一化,不做增强
# 加载训练集(用增强变换)
train_dataset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=train_transform
)
# 加载测试集(不用增强)
test_dataset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=test_transform
)
# 创建DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 训练集打乱顺序
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False) # 测试集不打乱
数据增强的核心作用:
- 让模型看到更多 "变种" 图像(如裁剪后的局部图像、翻转后的图像);
- 避免模型过度依赖图像的固定位置或方向(如只认识向左的猫,不认识向右的猫);
- 通常能将 CIFAR10 线性分类器的准确率提升 10%-15%。