115_PyTorch 实战:从零搭建 CIFAR-10 完整训练与测试流水线

在掌握了神经网络的各个组件后,如何将它们组织成一个可运行、可监控、可保存的完整项目?本篇将通过 CIFAR-10 识别任务,拆解 PyTorch 训练的标准"套路"。

1. 训练全流程概览

一个标准的深度学习训练脚本通常包含以下几个固定环节:

  1. 准备数据集与加载器:Dataset & DataLoader。
  2. 搭建网络结构:定义模型类并实例化。
  3. 设置损失函数与优化器:选择合适的评价与优化算法。
  4. 训练循环 (Train Loop):前向传播、算损失、反向传播、更新参数。
  5. 测试/验证循环 (Test Loop):评估模型在未见过的数据上的表现。
  6. 可视化与保存:记录日志并持久化模型。

2. 核心代码实战

文件演示了如何通过多轮迭代(Epoch)来提升模型准确率。

① 网络模型定义 (model.py 逻辑)

这里使用了经典的 CIFAR-10 结构,通过 Sequential 串联。

Python

复制代码
import torch
from torch import nn

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

② 训练与测试核心循环

注意代码中 tudui.train()tudui.eval() 的切换,这是为了正确处理 Dropout 或 BatchNorm 等层。

Python

复制代码
# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(tudui.parameters(), lr=0.01)

# 开始训练
for i in range(epoch):
    print(f"-------第 {i+1} 轮训练开始-------")

    # 训练步骤
    tudui.train() # 切换到训练模式
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试步骤
    tudui.eval() # 切换到评估模式
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad(): # 测试时不需要梯度,节省计算资源
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            # 计算正确率
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    print(f"整体测试集上的Loss: {total_test_loss}")
    print(f"整体测试集上的正确率: {total_accuracy / test_data_size}")

    # 保存每一轮训练后的结果
    torch.save(tudui, f"tudui_{i}.pth")

3. 关键细节解析

如何计算正确率 (Accuracy)?

在测试代码中,这一行是精华:outputs.argmax(1) == targets

  • outputs.argmax(1):找到模型预测得分最高的那个类别索引。
  • == targets:与真实标签比对,返回布尔值。
  • .sum():统计正确的个数。

为什么需要 with torch.no_grad()

在测试阶段,我们不需要更新权重,也不需要存储计算图。加上这一行可以大幅减少显存占用并提升运行速度


4. 总结:训练套路的五个"一"

分析完这个文件,你可以记住这个简化的口诀:

  • 个模型类(定义结构)。
  • 个损失函数(定义目标)。
  • 个优化器(定义更新规则)。
  • 套训练代码(三步走:清零、反传、步进)。
  • 套验证代码(不计梯度、算准确率)。

💡 结语

到这里,你已经掌握了 PyTorch 开发的精髓。从最初的图片加载,到现在的模型评估,所有的模块都已经各就各位。

相关推荐
Ricky055313 小时前
YOLO-FCE:一种基于特征与聚类增强的物种分类目标检测模型(澳大利亚2026年研究)
图像处理·人工智能·yolo·目标检测·分类
学习中.........13 小时前
大语言模型的推理机制与工程应用
人工智能·语言模型·自然语言处理
一切皆是因缘际会13 小时前
从模型竞赛到全域智能的时代跃迁
人工智能·深度学习·ai·分布式系统
极光代码工作室13 小时前
基于NLP的招聘信息关键词分析系统
python·深度学习·自然语言处理·nlp
2601_9578885613 小时前
流量终局与信源争夺:GEO(生成式引擎优化)时代的爬虫分析与数据管道构建
人工智能·爬虫
名不经传的养虾人13 小时前
从0到1:企业级AI项目迭代日记 Vol.35|追问比演示重要——技术团队问出的五个工程缺口
人工智能·算法·机器学习·ai编程·ai工作流·企业ai
光芒Shine13 小时前
【机器学习-mediapipe】
人工智能·机器人
多米哇卡13 小时前
Figure 03 实测 200 小时稳定作业,人形机器人商业化落地提速
大数据·人工智能·机器人
锦鲤521413 小时前
深度学习与神经网络学习
深度学习·神经网络·学习
1892280486113 小时前
NQ486固态MT29F16T08GSLDHL8-QM:D
大数据·人工智能·科技·microsoft·缓存