115_PyTorch 实战:从零搭建 CIFAR-10 完整训练与测试流水线

在掌握了神经网络的各个组件后,如何将它们组织成一个可运行、可监控、可保存的完整项目?本篇将通过 CIFAR-10 识别任务,拆解 PyTorch 训练的标准"套路"。

1. 训练全流程概览

一个标准的深度学习训练脚本通常包含以下几个固定环节:

  1. 准备数据集与加载器:Dataset & DataLoader。
  2. 搭建网络结构:定义模型类并实例化。
  3. 设置损失函数与优化器:选择合适的评价与优化算法。
  4. 训练循环 (Train Loop):前向传播、算损失、反向传播、更新参数。
  5. 测试/验证循环 (Test Loop):评估模型在未见过的数据上的表现。
  6. 可视化与保存:记录日志并持久化模型。

2. 核心代码实战

文件演示了如何通过多轮迭代(Epoch)来提升模型准确率。

① 网络模型定义 (model.py 逻辑)

这里使用了经典的 CIFAR-10 结构,通过 Sequential 串联。

Python

复制代码
import torch
from torch import nn

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

② 训练与测试核心循环

注意代码中 tudui.train()tudui.eval() 的切换,这是为了正确处理 Dropout 或 BatchNorm 等层。

Python

复制代码
# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(tudui.parameters(), lr=0.01)

# 开始训练
for i in range(epoch):
    print(f"-------第 {i+1} 轮训练开始-------")

    # 训练步骤
    tudui.train() # 切换到训练模式
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试步骤
    tudui.eval() # 切换到评估模式
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad(): # 测试时不需要梯度,节省计算资源
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            # 计算正确率
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    print(f"整体测试集上的Loss: {total_test_loss}")
    print(f"整体测试集上的正确率: {total_accuracy / test_data_size}")

    # 保存每一轮训练后的结果
    torch.save(tudui, f"tudui_{i}.pth")

3. 关键细节解析

如何计算正确率 (Accuracy)?

在测试代码中,这一行是精华:outputs.argmax(1) == targets

  • outputs.argmax(1):找到模型预测得分最高的那个类别索引。
  • == targets:与真实标签比对,返回布尔值。
  • .sum():统计正确的个数。

为什么需要 with torch.no_grad()

在测试阶段,我们不需要更新权重,也不需要存储计算图。加上这一行可以大幅减少显存占用并提升运行速度


4. 总结:训练套路的五个"一"

分析完这个文件,你可以记住这个简化的口诀:

  • 个模型类(定义结构)。
  • 个损失函数(定义目标)。
  • 个优化器(定义更新规则)。
  • 套训练代码(三步走:清零、反传、步进)。
  • 套验证代码(不计梯度、算准确率)。

💡 结语

到这里,你已经掌握了 PyTorch 开发的精髓。从最初的图片加载,到现在的模型评估,所有的模块都已经各就各位。

相关推荐
财迅通Ai8 小时前
商业航天概念领涨A股,航天ETF华安(159267.SZ)收盘上涨1.2%
大数据·人工智能·区块链·中国卫星·航天电子
齐齐大魔王9 小时前
智能语音技术(八)
人工智能·语音识别
许彰午9 小时前
零成本搭建RAG智能客服:Ollama + Milvus + DeepSeek全程实战
人工智能·语音识别·llama·milvus
ZPC82109 小时前
自定义action server 接收arm_controller 指令
人工智能·机器人
迷茫的启明星9 小时前
各职业在当前发展阶段,使用AI的舒适区与盲区
大数据·人工智能·职场和发展
Liqiuyue10 小时前
Transformer:现代AI革命背后的核心模型
人工智能·算法·机器学习
桂花饼10 小时前
AI 视频生成:sora-2 模型快速对接指南
人工智能·音视频·sora2·nano banana 2·claude-opus-4-6·gemini 3.1
GreenTea11 小时前
AI Agent 评测的下半场:从方法论到落地实践
前端·人工智能·后端
冬奇Lab12 小时前
一天一个开源项目(第73篇):Multica - 把 AI 编程智能体变成真正的团队成员
人工智能·开源·资讯