115_PyTorch 实战:从零搭建 CIFAR-10 完整训练与测试流水线

在掌握了神经网络的各个组件后,如何将它们组织成一个可运行、可监控、可保存的完整项目?本篇将通过 CIFAR-10 识别任务,拆解 PyTorch 训练的标准"套路"。

1. 训练全流程概览

一个标准的深度学习训练脚本通常包含以下几个固定环节:

  1. 准备数据集与加载器:Dataset & DataLoader。
  2. 搭建网络结构:定义模型类并实例化。
  3. 设置损失函数与优化器:选择合适的评价与优化算法。
  4. 训练循环 (Train Loop):前向传播、算损失、反向传播、更新参数。
  5. 测试/验证循环 (Test Loop):评估模型在未见过的数据上的表现。
  6. 可视化与保存:记录日志并持久化模型。

2. 核心代码实战

文件演示了如何通过多轮迭代(Epoch)来提升模型准确率。

① 网络模型定义 (model.py 逻辑)

这里使用了经典的 CIFAR-10 结构,通过 Sequential 串联。

Python

复制代码
import torch
from torch import nn

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

② 训练与测试核心循环

注意代码中 tudui.train()tudui.eval() 的切换,这是为了正确处理 Dropout 或 BatchNorm 等层。

Python

复制代码
# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(tudui.parameters(), lr=0.01)

# 开始训练
for i in range(epoch):
    print(f"-------第 {i+1} 轮训练开始-------")

    # 训练步骤
    tudui.train() # 切换到训练模式
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试步骤
    tudui.eval() # 切换到评估模式
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad(): # 测试时不需要梯度,节省计算资源
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            # 计算正确率
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    print(f"整体测试集上的Loss: {total_test_loss}")
    print(f"整体测试集上的正确率: {total_accuracy / test_data_size}")

    # 保存每一轮训练后的结果
    torch.save(tudui, f"tudui_{i}.pth")

3. 关键细节解析

如何计算正确率 (Accuracy)?

在测试代码中,这一行是精华:outputs.argmax(1) == targets

  • outputs.argmax(1):找到模型预测得分最高的那个类别索引。
  • == targets:与真实标签比对,返回布尔值。
  • .sum():统计正确的个数。

为什么需要 with torch.no_grad()

在测试阶段,我们不需要更新权重,也不需要存储计算图。加上这一行可以大幅减少显存占用并提升运行速度


4. 总结:训练套路的五个"一"

分析完这个文件,你可以记住这个简化的口诀:

  • 个模型类(定义结构)。
  • 个损失函数(定义目标)。
  • 个优化器(定义更新规则)。
  • 套训练代码(三步走:清零、反传、步进)。
  • 套验证代码(不计梯度、算准确率)。

💡 结语

到这里,你已经掌握了 PyTorch 开发的精髓。从最初的图片加载,到现在的模型评估,所有的模块都已经各就各位。

相关推荐
Veggie261 小时前
【Java深度学习】PyTorch On Java 系列课程 第八章 17 :模型评估【AI Infra 3.0】[PyTorch Java 硕士研一课程]
java·人工智能·深度学习
链上杯子2 小时前
《2026 LangChain零基础入门:用AI应用框架快速搭建智能助手》第8课(完结篇):小项目实战 + 部署 —— 构建网页版个人知识库 AI 助手
人工智能·langchain
东方不败之鸭梨的测试笔记2 小时前
AI生成测试用例方案
人工智能·测试用例
笨手笨脚の3 小时前
AI 基础概念
人工智能·大模型·prompt·agent·tool
飞睿科技3 小时前
解析 ESP-AirPuff 泡芙一号的 ESP32-P4 大模型 AI 智能体方案
人工智能
gregmankiw3 小时前
Nemotron架构(Mamba3+Transformer+Moe)
android·深度学习·transformer
云烟成雨TD3 小时前
Spring AI Alibaba 1.x 系列【4】ReAct 范式与 ReactAgent 核心设计
java·人工智能·spring
乐分启航4 小时前
SliMamba:十余K参数量刷新SOTA!高光谱分类的“降维打击“来了
java·人工智能·深度学习·算法·机器学习·分类·数据挖掘
_codemonster4 小时前
被子植物门 —— 纲、目、科详细梳理 + 分类依据
人工智能·分类·数据挖掘