PyTorch模型训练全流程详解

完整且标准的 PyTorch 模型训练模板

python 复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter
from model import * # 从外部 model.py 导入自定义模型 Tudui
from torch import nn
from torch.utils.data import DataLoader

# --- 1. 准备数据集 ---
# 加载 CIFAR10 训练集,transform 确保图片转为 Tensor
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 加载 CIFAR10 测试集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 获取数据集长度,用于后续打印和计算准确率
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# --- 2. 利用 DataLoader 分批次加载数据 ---
# batch_size=64 表示每次训练输入 64 张图片
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# --- 3. 初始化模型、损失函数和优化器 ---
tudui = Tudui() # 实例化模型

# 交叉熵损失函数,常用于多分类任务
loss_fn = nn.CrossEntropyLoss()

# 学习率:$1e-2 = 0.01$
learning_rate = 1e-2
# 随机梯度下降优化器,负责更新模型权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# --- 4. 设置训练计数器与参数 ---
total_train_step = 0 # 记录训练总步数(总图片数/batch_size)
total_test_step = 0  # 记录测试总次数(通常等于 epoch 数)
epoch = 10           # 训练 10 轮

# 初始化 TensorBoard 写入器
writer = SummaryWriter("../logs_train")

# --- 5. 核心训练与测试循环 ---
for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # --- 训练步骤 ---
    tudui.train() # 将模型设置为训练模式(影响 Dropout 和 BatchNorm 层)
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)     # 前向传播:计算输出
        loss = loss_fn(outputs, targets) # 计算损失

        # 优化器"标准三步走":
        optimizer.zero_grad() # 1. 梯度清零
        loss.backward()       # 2. 反向传播,计算梯度
        optimizer.step()      # 3. 根据梯度更新权重

        total_train_step += 1
        # 每训练 100 次打印一次状态并记录 TensorBoard
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # --- 测试/验证步骤 ---
    tudui.eval() # 将模型设置为评估模式
    total_test_loss = 0
    total_accuracy = 0
    # 测试时不需要计算梯度,节省内存和算力
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item() # 累加测试集损失
            
            # 准确率计算:argmax(1) 找出预测得分最高的索引,与标签比对
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    # 打印一轮训练后的整体表现
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
    
    # 记录测试结果到 TensorBoard
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step += 1

    # 保存每一轮训练后的模型
    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")

writer.close()
1. 训练 vs 验证周期

代码在每轮训练(epoch)结束后都会在完整的测试集上跑一遍。这有助于我们观察模型是否过拟合(即训练集 Loss 下降,但测试集正确率不再上升甚至下降)。

2. 前向传播与反向传播
  • Forward : outputs = tudui(imgs)。数据流过层层神经元。

  • Backward : loss.backward()。这是根据误差方向计算各层权重的"责任"大小。

  • Step : optimizer.step()。就像在损失的山坡上向下迈一小步。

3. 核心计算技巧:argmax(1)

由于分类任务的输出是一个包含 10 个类别的概率向量(Logits),argmax(1) 的作用是锁定得分最高的那一个位置。比如 [0.1, 0.8, 0.1]argmax1

4. tudui.train()tudui.eval()

这两行经常被新手忽略。虽然在这个简单的模型中可能不影响结果,但如果你的模型包含 Dropout (随机丢弃神经元)或 Batch Normalization(批归一化),这两行代码能确保模型在训练时"随机学习",在测试时"全力以赴"。

相关推荐
码农的神经元2 小时前
从零搭建一个带 GUI 的机器学习建模系统:多模型切换、遗传算法优化与可视化实战复盘
人工智能·机器学习
一楼的猫2 小时前
茄子小说AI辅助智能写作助手:10倍速创作神器
人工智能·学习·机器学习·学习方法·ai写作·迁移学习·集成学习
独隅2 小时前
PyTorch模型转TensorFlow Lite的Android部署全流程指南
android·pytorch·tensorflow
懂AI的老郑2 小时前
人工智能手机的构建思路:从架构到实现
人工智能·智能手机·架构
思绪无限2 小时前
YOLOv5至YOLOv12升级:交通信号灯识别系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·交通信号灯识别·yolov12·yolo全家桶
gjhave2 小时前
强化学习论文(Double-DQN)
人工智能·机器学习
rADu REME2 小时前
rust web框架actix和axum比较
前端·人工智能·rust
Mark-Han2 小时前
AI产品的定价是一门玄学
人工智能
BizViewStudio2 小时前
GEO vs SEO vs SEM:2026 年品牌流量获取的三元格局分析
大数据·运维·网络·人工智能·ai