PyTorch模型训练全流程详解

完整且标准的 PyTorch 模型训练模板

python 复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter
from model import * # 从外部 model.py 导入自定义模型 Tudui
from torch import nn
from torch.utils.data import DataLoader

# --- 1. 准备数据集 ---
# 加载 CIFAR10 训练集,transform 确保图片转为 Tensor
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 加载 CIFAR10 测试集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 获取数据集长度,用于后续打印和计算准确率
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# --- 2. 利用 DataLoader 分批次加载数据 ---
# batch_size=64 表示每次训练输入 64 张图片
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# --- 3. 初始化模型、损失函数和优化器 ---
tudui = Tudui() # 实例化模型

# 交叉熵损失函数,常用于多分类任务
loss_fn = nn.CrossEntropyLoss()

# 学习率:$1e-2 = 0.01$
learning_rate = 1e-2
# 随机梯度下降优化器,负责更新模型权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# --- 4. 设置训练计数器与参数 ---
total_train_step = 0 # 记录训练总步数(总图片数/batch_size)
total_test_step = 0  # 记录测试总次数(通常等于 epoch 数)
epoch = 10           # 训练 10 轮

# 初始化 TensorBoard 写入器
writer = SummaryWriter("../logs_train")

# --- 5. 核心训练与测试循环 ---
for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # --- 训练步骤 ---
    tudui.train() # 将模型设置为训练模式(影响 Dropout 和 BatchNorm 层)
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)     # 前向传播:计算输出
        loss = loss_fn(outputs, targets) # 计算损失

        # 优化器"标准三步走":
        optimizer.zero_grad() # 1. 梯度清零
        loss.backward()       # 2. 反向传播,计算梯度
        optimizer.step()      # 3. 根据梯度更新权重

        total_train_step += 1
        # 每训练 100 次打印一次状态并记录 TensorBoard
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # --- 测试/验证步骤 ---
    tudui.eval() # 将模型设置为评估模式
    total_test_loss = 0
    total_accuracy = 0
    # 测试时不需要计算梯度,节省内存和算力
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item() # 累加测试集损失
            
            # 准确率计算:argmax(1) 找出预测得分最高的索引,与标签比对
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    # 打印一轮训练后的整体表现
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
    
    # 记录测试结果到 TensorBoard
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step += 1

    # 保存每一轮训练后的模型
    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")

writer.close()
1. 训练 vs 验证周期

代码在每轮训练(epoch)结束后都会在完整的测试集上跑一遍。这有助于我们观察模型是否过拟合(即训练集 Loss 下降,但测试集正确率不再上升甚至下降)。

2. 前向传播与反向传播
  • Forward : outputs = tudui(imgs)。数据流过层层神经元。

  • Backward : loss.backward()。这是根据误差方向计算各层权重的"责任"大小。

  • Step : optimizer.step()。就像在损失的山坡上向下迈一小步。

3. 核心计算技巧:argmax(1)

由于分类任务的输出是一个包含 10 个类别的概率向量(Logits),argmax(1) 的作用是锁定得分最高的那一个位置。比如 [0.1, 0.8, 0.1]argmax1

4. tudui.train()tudui.eval()

这两行经常被新手忽略。虽然在这个简单的模型中可能不影响结果,但如果你的模型包含 Dropout (随机丢弃神经元)或 Batch Normalization(批归一化),这两行代码能确保模型在训练时"随机学习",在测试时"全力以赴"。

相关推荐
To_OC2 小时前
搞懂 Token 和 Embedding 后,我终于明白大模型是怎么 "读" 文字的
人工智能·llm·agent
冬奇Lab5 小时前
每日一个开源项目(第139篇):Voicebox - 本地运行的开源 ElevenLabs 替代品
人工智能·开源·资讯
冬奇Lab5 小时前
Skill 系列(03):Skill 设计范式——5 个模式让输出从混沌到可预测
人工智能·开源·agent
IT_陈寒7 小时前
Python搞不定字符串编码?这破玩意坑我两小时!
前端·人工智能·后端
大模型真好玩8 小时前
什么是Loop Engineering?最通俗易懂的Loop Engineering核心概念
人工智能·agent·deepseek
叁两9 小时前
前端转型AI Agent该如何学习?(前置篇)
前端·人工智能·node.js
LaiYoung_9 小时前
🎁 送你一套超好用超实用的 FE AI-Coding Skills
前端·人工智能·开源
ZzT11 小时前
怎么做才不会被 AI 替代?
人工智能·程序员
道友可好11 小时前
从今天开始:你的第一个 Harness Engineering 实践
前端·人工智能·后端