Day 40 深度学习训练与测试的规范写法

在深度学习项目的开发中,随着模型复杂度的提升,编写结构清晰、易于维护的训练和测试代码变得至关重要。本篇笔记基于 MNIST 手写数字识别任务,详细解析了 PyTorch 中训练和测试流程的规范化写法。

1. 核心设计理念

在早期的简单脚本中,我们可能直接将训练循环写在主程序中。但在规范的工程实践中,我们将**训练(Train)测试(Test/Validation)**过程封装为独立的函数。这种设计带来了以下优势:

  1. 逻辑解耦:将模型的前向传播、反向传播、参数更新与数据加载、指标统计分离,代码逻辑更清晰。
  2. 参数隔离:函数参数(如 epoch, device, dataloader)明确,修改超参数时无需深入修改逻辑代码。
  3. 易于复用:标准化的训练函数可以轻松应用到不同的模型或数据集上。
  4. 状态管理 :明确区分 train 模式和 eval 模式,避免因 Dropout 或 Batch Normalization 行为不一致导致的错误。

2. 完整流程解析

2.1 环境设置与数据准备

在开始训练前,首先进行必要的环境配置和数据加载。

  • 设备选择 :自动检测是否可用 GPU (cuda),否则使用 CPU。

  • 随机种子 :设置 torch.manual_seed 确保实验结果可复现。

  • 数据预处理 :使用 transforms.Compose 将图像转换为 Tensor 并进行归一化。

  • DataLoader :使用 DataLoader 进行批量数据加载,训练集通常开启 shuffle=True 打乱数据。

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(42)

    数据转换

    transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])

    DataLoader

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

2.2 模型定义与展平操作

在定义 MLP(多层感知机)时,处理图像数据的一个关键步骤是展平(Flatten)

  • 输入维度 :图像数据通常是 (batch_size, channels, height, width),例如 (64, 1, 28, 28)
  • 全连接层要求 :全连接层 (Linear) 需要二维输入 (batch_size, input_features)
  • Flatten 的作用nn.Flatten() 将除 batch_size 以外的所有维度展平。例如 (64, 1, 28, 28) -> (64, 784)

注意:无论如何变换形状(Flatten, View, Reshape),第一个维度(Batch Size)通常保持不变。

复制代码
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()       # 展平层
        self.layer1 = nn.Linear(784, 128) # 隐藏层
        self.relu = nn.ReLU()             # 激活函数
        self.layer2 = nn.Linear(128, 10)  # 输出层

    def forward(self, x):
        x = self.flatten(x)
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x

2.3 规范化的训练函数 (train)

这是核心部分,负责模型的参数更新和过程监控。

关键步骤:

  1. model.train():将模型设置为训练模式。这对于包含 Dropout 或 Batch Normalization 的模型至关重要。

  2. 数据迁移data.to(device), target.to(device) 将数据移至 GPU。

  3. 梯度清零optimizer.zero_grad() 防止梯度累加。

  4. 反向传播loss.backward() 计算梯度。

  5. 参数更新optimizer.step() 更新模型权重。

  6. 指标记录

    • Iteration 级损失:记录每个 Batch 的损失,用于绘制精细的损失曲线,观察模型收敛的微观波动。
    • Epoch 级指标:计算整个 Epoch 的平均损失和准确率。

    def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):
    model.train() # 开启训练模式

    复制代码
     all_iter_losses = [] # 记录所有 Batch 的损失
     iter_indices = []
     
     for epoch in range(epochs):
         running_loss = 0.0
         correct = 0
         total = 0
         
         for batch_idx, (data, target) in enumerate(train_loader):
             data, target = data.to(device), target.to(device)
             
             optimizer.zero_grad()         # 1. 梯度清零
             output = model(data)          # 2. 前向传播
             loss = criterion(output, target) # 3. 计算损失
             loss.backward()               # 4. 反向传播
             optimizer.step()              # 5. 更新参数
             
             # 记录细粒度损失
             iter_loss = loss.item()
             all_iter_losses.append(iter_loss)
             iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
             
             # 统计累计指标
             running_loss += iter_loss
             _, predicted = output.max(1)
             total += target.size(0)
             correct += predicted.eq(target).sum().item()
             
             if (batch_idx + 1) % 100 == 0:
                 print(f'Epoch: {epoch+1} | Batch: {batch_idx+1} | Loss: {iter_loss:.4f}')
         
         # Epoch 结束后的验证
         epoch_acc = 100. * correct / total
         test_loss, test_acc = test(model, test_loader, criterion, device)
         print(f'Epoch {epoch+1} 训练准确率: {epoch_acc:.2f}% | 测试准确率: {test_acc:.2f}%')
         
     return test_acc

2.4 规范化的测试函数 (test)

测试函数用于评估模型性能,不涉及参数更新。

关键步骤:

  1. model.eval():将模型设置为评估模式。固定 Dropout 和 BN 层。

  2. with torch.no_grad():上下文管理器,关闭梯度计算。这可以显著减少显存占用并加速计算。

  3. 统计逻辑:累加损失值和正确预测数,最后计算平均值。

    def test(model, test_loader, criterion, device):
    model.eval() # 开启评估模式
    test_loss = 0
    correct = 0
    total = 0

    复制代码
     with torch.no_grad():  # 关闭梯度计算
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
             
             test_loss += criterion(output, target).item() # 累加 Loss
             _, predicted = output.max(1)
             total += target.size(0)
             correct += predicted.eq(target).sum().item() # 累加正确数
     
     avg_loss = test_loss / len(test_loader)
     accuracy = 100. * correct / total
     return avg_loss, accuracy

3. 常见问题与最佳实践 QA

Q1: 为什么要在训练循环中使用 loss.item()****?

  • A : loss 是一个包含计算图信息的 Tensor。如果直接累加 running_loss += loss,PyTorch 会保留整个计算图,导致显存迅速耗尽(Memory Leak)。使用 .item() 可以获取 Python 标量数值,切断计算图依赖。

Q2: **model.train()** **model.eval()**是必须的吗?

  • A : 对于简单的 MLP(没有 Dropout 和 BN),它们可能看起来没区别。但必须养成习惯。因为一旦模型加入了 Dropout(训练时随机丢弃,测试时全保留)或 Batch Normalization(训练时计算 Batch 均值,测试时使用全局均值),不切换模式会导致严重的性能下降。

Q3: 为什么测试时要用 torch.no_grad()****?

  • A: 测试阶段不需要反向传播更新参数,因此不需要构建计算图。关闭梯度计算可以节省大量内存(不需要保存中间激活值),并且略微提升推理速度。

Q4: 为什么要记录每个 Iteration 的损失?

  • A : Epoch 级别的平均损失可能会掩盖模型训练过程中的震荡或异常。通过绘制 Iteration 级别的 Loss 曲线,我们可以更直观地观察:
    • 学习率是否过大(Loss 剧烈震荡)。
    • 模型是否在某些 Batch 上难以收敛。
    • 训练初期的快速下降趋势。

4. 总结

规范化的 PyTorch 训练代码包含以下要素:

  1. 结构化:使用 Dataset/DataLoader 管理数据,使用 Class 管理模型。
  2. 模块化train()test() 函数分离,职责单一。
  3. 正确性 :正确使用 train/eval 模式切换,正确处理梯度清零和反向传播。
  4. 高效性 :使用 device 管理硬件加速,使用 no_grad 优化推理。
  5. 可观测性:详细记录 Loss 和 Accuracy,辅助调参。
相关推荐
音视频牛哥7 小时前
C#实战:如何开发设计毫秒级延迟、工业级稳定的Windows平台RTSP/RTMP播放器
人工智能·机器学习·机器人·c#·音视频·rtsp播放器·rtmp播放器
Blossom.1188 小时前
基于时序大模型+强化学习的虚拟电厂储能调度系统:从负荷预测到收益最大化的实战闭环
运维·人工智能·python·决策树·机器学习·自动化·音视频
深蓝海拓9 小时前
PySide6从0开始学习的笔记(四)QMainWindow
笔记·python·学习·pyqt
百胜软件@百胜软件9 小时前
重塑零售未来:百胜智能中台+胜券AI,赋能品牌零售撬动3100亿增量市场
大数据·人工智能·零售
深蓝海拓9 小时前
PySide6 的 QSettings简单应用学习笔记
python·学习·pyqt
Shawn_Shawn14 小时前
人工智能入门概念介绍
人工智能
极限实验室14 小时前
程序员爆哭!我们让 COCO AI 接管 GitLab 审查后,团队直接起飞:连 CTO 都说“这玩意儿比人靠谱多了
人工智能·gitlab
Maynor99616 小时前
Z-Image: 100% Free AI Image Generator
人工智能
码界奇点16 小时前
Python从0到100一站式学习路线图与实战指南
开发语言·python·学习·青少年编程·贴图