CIFAR-10 是一个经典的小型彩色图像分类数据集,广泛用于深度学习入门、模型验证和算法研究。下面我将用通俗易懂的方式为你全面讲解:
📦 一、什么是 CIFAR-10 数据集?
✅ 基本信息
表格
项目 内容
图像数量 60,000 张(50,000 训练 + 10,000 测试)
图像尺寸 32×32 像素(非常小!)
颜色通道 RGB 三通道(彩色图像)
类别数 10 类
每类图像数 6,000 张(训练 5,000 + 测试 1,000)
🖼️ 10 个类别(都是日常物体):
text
编辑
. airplane(飞机)
. automobile(汽车)
. bird(鸟)
. cat(猫)
. deer(鹿)
. dog(狗)
. frog(青蛙)
. horse(马)
. ship(船)
. truck(卡车)
💡 特点:图像小、类别均衡、标注准确,非常适合快速实验 CNN 模型。
🔧 二、如何在 PyTorch 中加载 CIFAR-10?
PyTorch 的 torchvision.datasets 提供了开箱即用的接口:
python
编辑
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
定义图像预处理(标准化是关键!)
transform = transforms.Compose([
transforms.ToTensor(),
CIFAR-10 的均值和标准差(官方推荐)
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
加载训练集和测试集
train_dataset = datasets.CIFAR10(
root='./data', # 下载路径
train=True, # 是否为训练集
download=True, # 自动下载
transform=transform # 预处理
)
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
print("✅ CIFAR-10 已加载!")
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
⚠️ 注意:第一次运行会自动下载约 170MB 数据。
🧠 三、用 CNN 训练 CIFAR-10(完整代码)
下面是一个简单但有效的 CNN 模型,包含卷积层、池化层、批归一化和全连接层。
步骤 1:定义 CNN 模型
python
编辑
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def init(self, num_classes=10):
super(SimpleCNN, self).init()
第一个卷积块
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.pool1 = nn.MaxPool2d(2, 2)
第二个卷积块
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.bn4 = nn.BatchNorm2d(128)
self.pool2 = nn.MaxPool2d(2, 2)
全连接层
self.fc1 = nn.Linear(128 * 8 * 8, 512)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(512, num_classes)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = self.pool1(x) # → 16x16
x = F.relu(self.bn3(self.conv3(x)))
x = F.relu(self.bn4(self.conv4(x)))
x = self.pool2(x) # → 8x8
x = x.view(-1, 128 * 8 * 8) # 展平
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
步骤 2:训练循环
python
编辑
import torch.optim as optim
初始化
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
训练
num_epochs = 20
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
测试
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
acc = 100 * correct / total
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.3f}, Acc: {acc:.2f}%")
scheduler.step()
📈 四、预期效果 & 优化建议
🎯 基础 CNN 性能
准确率:约 75%~80%
训练时间:CPU 上几分钟,GPU 上几十秒
🔧 如何提升到 90%+?
表格
方法 说明
更深网络 使用 ResNet-18(预训练或从头训练)
数据增强 添加 RandomHorizontalFlip, RandomCrop
学习率调度 用 ReduceLROnPlateau
正则化 增加 Dropout、权重衰减
优化器 Adam → SGD + momentum
示例:添加数据增强
python
编辑
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
🌟 五、为什么 CIFAR-10 如此重要?
入门友好:数据小,训练快,适合教学
基准测试:几乎所有 CNN 论文都会在 CIFAR-10 上验证
算法验证:可快速测试新想法(如新激活函数、正则化方法)
迁移学习起点:在 CIFAR-10 上预训练,再迁移到其他小数据集
✅ 总结
CIFAR-10 = 深度学习的"Hello World"
通过它,你可以:
理解 CNN 如何工作
掌握 PyTorch 数据加载流程
学习训练/验证/调参完整 pipeline
只需运行上面的代码,你就能在 10 分钟内 训练出一个能识别飞机、猫、船等物体的 AI 模型!
如果你想要:
- ResNet 版本代码
- 可视化训练过程(loss/acc 曲线)
- Grad-CAM 热力图(看模型关注哪里)
- 可以尝试基于上面思路自己实践下