在掌握了神经网络的各个组件后,如何将它们组织成一个可运行、可监控、可保存的完整项目?本篇将通过 CIFAR-10 识别任务,拆解 PyTorch 训练的标准"套路"。
1. 训练全流程概览
一个标准的深度学习训练脚本通常包含以下几个固定环节:
- 准备数据集与加载器:Dataset & DataLoader。
- 搭建网络结构:定义模型类并实例化。
- 设置损失函数与优化器:选择合适的评价与优化算法。
- 训练循环 (Train Loop):前向传播、算损失、反向传播、更新参数。
- 测试/验证循环 (Test Loop):评估模型在未见过的数据上的表现。
- 可视化与保存:记录日志并持久化模型。
2. 核心代码实战
文件演示了如何通过多轮迭代(Epoch)来提升模型准确率。
① 网络模型定义 (model.py 逻辑)
这里使用了经典的 CIFAR-10 结构,通过 Sequential 串联。
Python
import torch
from torch import nn
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, 1, 2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64*4*4, 64),
nn.Linear(64, 10)
)
def forward(self, x):
return self.model(x)
② 训练与测试核心循环
注意代码中 tudui.train() 和 tudui.eval() 的切换,这是为了正确处理 Dropout 或 BatchNorm 等层。
Python
# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(tudui.parameters(), lr=0.01)
# 开始训练
for i in range(epoch):
print(f"-------第 {i+1} 轮训练开始-------")
# 训练步骤
tudui.train() # 切换到训练模式
for data in train_dataloader:
imgs, targets = data
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
# 优化器优化模型
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 测试步骤
tudui.eval() # 切换到评估模式
total_test_loss = 0
total_accuracy = 0
with torch.no_grad(): # 测试时不需要梯度,节省计算资源
for data in test_dataloader:
imgs, targets = data
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
total_test_loss += loss.item()
# 计算正确率
accuracy = (outputs.argmax(1) == targets).sum()
total_accuracy += accuracy
print(f"整体测试集上的Loss: {total_test_loss}")
print(f"整体测试集上的正确率: {total_accuracy / test_data_size}")
# 保存每一轮训练后的结果
torch.save(tudui, f"tudui_{i}.pth")
3. 关键细节解析
如何计算正确率 (Accuracy)?
在测试代码中,这一行是精华:outputs.argmax(1) == targets。
outputs.argmax(1):找到模型预测得分最高的那个类别索引。== targets:与真实标签比对,返回布尔值。.sum():统计正确的个数。
为什么需要 with torch.no_grad()?
在测试阶段,我们不需要更新权重,也不需要存储计算图。加上这一行可以大幅减少显存占用并提升运行速度。
4. 总结:训练套路的五个"一"
分析完这个文件,你可以记住这个简化的口诀:
- 一个模型类(定义结构)。
- 一个损失函数(定义目标)。
- 一个优化器(定义更新规则)。
- 一套训练代码(三步走:清零、反传、步进)。
- 一套验证代码(不计梯度、算准确率)。
💡 结语
到这里,你已经掌握了 PyTorch 开发的精髓。从最初的图片加载,到现在的模型评估,所有的模块都已经各就各位。