一、总体总结
这个代码的核心任务就是:训练一个模型,让它能自动将输入的图片(混在一起的 10 类图片)正确地归类到对应的类别中,然后用没见过的测试集来评估分类的准确率,看看模型分得准不准。
该教程使用 PyTorch 和 torchvision 完成了一个完整的图像分类任务,主要步骤如下:
-
加载并预处理 CIFAR-10 数据集
-
将图片从 PIL 格式转为张量(Tensor),并归一化到 -1, 1 范围。
-
创建 DataLoader,用于批量加载训练和测试数据。
-
-
定义一个简单的卷积神经网络(CNN)
-
网络包含两个卷积层、两个池化层和三个全连接层。
-
输入为 3 通道(RGB)的 32×32 图像,输出为 10 个类别。
-
-
定义损失函数和优化器
-
使用交叉熵损失(CrossEntropyLoss)作为损失函数。
-
使用带动量(momentum)的随机梯度下降(SGD)作为优化器。
-
-
训练网络
-
迭代 2 个 epoch,每个 epoch 遍历全部训练数据。
-
在每个 mini-batch 上进行前向传播、计算损失、反向传播并更新参数。
-
每 2000 个 mini-batch 打印一次平均损失。
-
-
在测试集上评估网络
-
计算网络在整个测试集(10000 张图)上的总体准确率。
-
分别计算每个类别(如猫、汽车、飞机等)的分类准确率。
-
-
(可选)在 GPU 上训练
- 检测 CUDA 是否可用,将网络和数据迁移到 GPU 以加速训练。
二、实际代码运行操作(分步实现)
1. 导入必要的库
python
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
2. 加载并归一化 CIFAR-10 数据集
python
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

显示部分数据集看看
python
import matplotlib.pyplot as plt
import numpy as np
# 从 trainloader 中取一个 batch 的数据
dataiter = iter(trainloader)
images, labels = next(dataiter)
# 定义显示图片的函数(反归一化 + 转换格式)
def imshow(img):
img = img / 2 + 0.5 # 反归一化:原 Normalize(mean=0.5, std=0.5)
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0))) # 将 (C, H, W) 转为 (H, W, C)
plt.show()
# 显示图片(比如前 4 张)
# torchvision.utils.make_grid 可以将多张图片拼成一张网格图
grid_img = torchvision.utils.make_grid(images[:4])
imshow(grid_img)
# 打印对应的标签
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

3. 定义卷积神经网络
python
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) # 输入3通道,输出6通道,卷积核5x5
self.pool = nn.MaxPool2d(2, 2) # 2x2最大池化
self.conv2 = nn.Conv2d(6, 16, 5) # 输入6通道,输出16通道
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5) # 展平
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
4. 定义损失函数和优化器
python
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
5. 训练网络
python
for epoch in range(2): # 只训练2个epoch(可增加)
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print(f'[{epoch+1}, {i+1:5d}] loss: {running_loss/2000:.3f}')
running_loss = 0.0
print('Finished Training')

6. 在测试集上评估
python
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f} %')

7. 查看每个类别的准确率
python
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4): # 注意缩进
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
# 将 :2d 改为 :2.0f 或 :5.2f
print(f'Accuracy of {classes[i]:5s}: {100 * class_correct[i] / class_total[i]:2.0f} %')

8. (可选)在 GPU 上训练
python
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)
# 在训练和测试循环中,将 inputs, labels 也迁移到 device
# inputs, labels = inputs.to(device), labels.to(device)

三、重要部分解释
1. 数据预处理中的 Normalize 作用
python
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-
ToTensor()将 PIL 图片(范围 0,1)转为张量。 -
Normalize用均值和标准差对每个通道做标准化:output = (input - mean) / std。这里 mean=0.5, std=0.5,结果范围变为 -1, 1。
好处:使数据分布更对称,利于梯度下降收敛。
2. 卷积网络结构的关键参数
-
nn.Conv2d(3, 6, 5):输入通道 3(RGB),输出通道 6,卷积核大小 5×5。原始 CIFAR-10 图片大小为 32×32,经过两次 2×2 池化后,尺寸变为 32→16→8,再减去卷积核边界效应,最终特征图大小为 5×5,因此全连接层的输入为
16 * 5 * 5。 -
池化层
MaxPool2d(2,2)步长为 2,将尺寸减半。
3. 损失函数和优化器
-
CrossEntropyLoss:适合多分类问题,内部自动对网络输出做 softmax,然后计算负对数似然损失。 -
SGD+momentum:带动量的 SGD 可以加速收敛并减少震荡。
4. 训练循环中的关键点
-
optimizer.zero_grad():清除之前的梯度,否则梯度会累加。 -
loss.backward():反向传播计算梯度。 -
optimizer.step():更新网络参数。 -
running_loss:每 2000 个 batch 打印一次平均损失,用于监控训练过程。
5. 测试时的 torch.no_grad() 上下文
- 在评估阶段不需要计算梯度,使用
torch.no_grad()可以禁用梯度计算,减少内存消耗并加速。
6. 准确率计算
-
torch.max(outputs, 1)返回每行(每个样本)最大值及其索引,索引就是预测的类别号。 -
通过比较
predicted == labels获得正确预测的布尔张量,再求和得到正确数量。
7. GPU 迁移
-
需要同时将网络和每个 batch 的数据都调用
.to(device)移至 GPU。 -
确保网络在创建后、数据传入前就放到 GPU 上。
四、运行结果示例(原文输出)
-
训练过程损失逐渐下降(从约 2.187 到 1.286)。
-
测试集总体准确率约为 54%(随机猜测为 10%,可见网络学到了有用的特征)。
-
各类别准确率差异较大,例如
car和horse约 73~74%,而deer仅 18% ------ 这可能因数据不平衡或类别特征不易区分。
五、扩展建议
-
增加训练轮数(epoch):尝试 10~20 个 epoch,准确率可进一步提升(通常可到 60~70%)。
-
调整网络结构:增加卷积层深度或使用更现代的结构(如 ResNet)。
-
使用数据增强 :在
transform中加入随机翻转、裁剪等,可提高泛化能力。 -
学习率调度 :使用
torch.optim.lr_scheduler动态调整学习率。
按照上述步骤,你可以完整运行一个基于 PyTorch 的图像分类器,并理解每个环节的作用
python
# 直接从 testloader 中取一批数据,拿第一张图片来预测
dataiter = iter(testloader)
images, labels = next(dataiter) # images 是 4x3x32x32 的张量,取第一张
# 取第一张图片并添加 batch 维度(其实 images[0] 本身就是 3x32x32,要 unsqueeze(0))
img = images[0].unsqueeze(0) # 现在尺寸为 1x3x32x32
true_label = labels[0].item()
with torch.no_grad():
outputs = net(img)
_, predicted = torch.max(outputs, 1)
predicted_label = predicted[0].item()
print(f'真实标签:{classes[true_label]}')
print(f'预测标签:{classes[predicted_label]}')
# 如果想显示这张图片(需要反归一化)
import matplotlib.pyplot as plt
import numpy as np
# 反归一化:因为训练时用了 Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
def imshow(img):
img = img / 2 + 0.5 # 反归一化到 [0,1]
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
imshow(images[0]) # 显示图片
