PyTorch模型训练全流程详解

完整且标准的 PyTorch 模型训练模板

python 复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter
from model import * # 从外部 model.py 导入自定义模型 Tudui
from torch import nn
from torch.utils.data import DataLoader

# --- 1. 准备数据集 ---
# 加载 CIFAR10 训练集,transform 确保图片转为 Tensor
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 加载 CIFAR10 测试集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 获取数据集长度,用于后续打印和计算准确率
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# --- 2. 利用 DataLoader 分批次加载数据 ---
# batch_size=64 表示每次训练输入 64 张图片
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# --- 3. 初始化模型、损失函数和优化器 ---
tudui = Tudui() # 实例化模型

# 交叉熵损失函数,常用于多分类任务
loss_fn = nn.CrossEntropyLoss()

# 学习率:$1e-2 = 0.01$
learning_rate = 1e-2
# 随机梯度下降优化器,负责更新模型权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# --- 4. 设置训练计数器与参数 ---
total_train_step = 0 # 记录训练总步数(总图片数/batch_size)
total_test_step = 0  # 记录测试总次数(通常等于 epoch 数)
epoch = 10           # 训练 10 轮

# 初始化 TensorBoard 写入器
writer = SummaryWriter("../logs_train")

# --- 5. 核心训练与测试循环 ---
for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # --- 训练步骤 ---
    tudui.train() # 将模型设置为训练模式(影响 Dropout 和 BatchNorm 层)
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)     # 前向传播:计算输出
        loss = loss_fn(outputs, targets) # 计算损失

        # 优化器"标准三步走":
        optimizer.zero_grad() # 1. 梯度清零
        loss.backward()       # 2. 反向传播,计算梯度
        optimizer.step()      # 3. 根据梯度更新权重

        total_train_step += 1
        # 每训练 100 次打印一次状态并记录 TensorBoard
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # --- 测试/验证步骤 ---
    tudui.eval() # 将模型设置为评估模式
    total_test_loss = 0
    total_accuracy = 0
    # 测试时不需要计算梯度,节省内存和算力
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item() # 累加测试集损失
            
            # 准确率计算:argmax(1) 找出预测得分最高的索引,与标签比对
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    # 打印一轮训练后的整体表现
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
    
    # 记录测试结果到 TensorBoard
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step += 1

    # 保存每一轮训练后的模型
    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")

writer.close()
1. 训练 vs 验证周期

代码在每轮训练(epoch)结束后都会在完整的测试集上跑一遍。这有助于我们观察模型是否过拟合(即训练集 Loss 下降,但测试集正确率不再上升甚至下降)。

2. 前向传播与反向传播
  • Forward : outputs = tudui(imgs)。数据流过层层神经元。

  • Backward : loss.backward()。这是根据误差方向计算各层权重的"责任"大小。

  • Step : optimizer.step()。就像在损失的山坡上向下迈一小步。

3. 核心计算技巧:argmax(1)

由于分类任务的输出是一个包含 10 个类别的概率向量(Logits),argmax(1) 的作用是锁定得分最高的那一个位置。比如 [0.1, 0.8, 0.1]argmax1

4. tudui.train()tudui.eval()

这两行经常被新手忽略。虽然在这个简单的模型中可能不影响结果,但如果你的模型包含 Dropout (随机丢弃神经元)或 Batch Normalization(批归一化),这两行代码能确保模型在训练时"随机学习",在测试时"全力以赴"。

相关推荐
weixin_468466852 分钟前
通义千问核心能力与实战表现深度评测
人工智能·深度学习·算法·ai·大模型
jerryinwuhan3 分钟前
marker BiBERTo解释
java·前端·人工智能
学习3人组4 分钟前
机器学习KNeighborsClassifier实现手写数字识别
人工智能·机器学习
掘金安东尼5 分钟前
如果你真能 7×24 小时运行最顶级的大模型,你会想用它来干嘛
人工智能
翼龙云_cloud5 分钟前
云服务器代理商:2026 年云计算趋势 AI 算力需求激增下的云服务器选择
服务器·人工智能·云计算·ai智能体
m沐沐5 分钟前
【机器学习】NLP---用 Python+TF-IDF 给《红楼梦》自动提取关键词
人工智能·python·机器学习·自然语言处理·nlp·中文分词·tf-idf
小脑斧1236 分钟前
自媒体内容工业化:基于AI Skills低代码实现穿搭账号矩阵自动化量产
人工智能·低代码·媒体·skills·openclaw·hermes·marvis
填满你的记忆6 分钟前
《为什么 MySQL 不适合做 AI 检索?》
数据库·人工智能·mysql·ai·向量数据库
威尔逊·柏斯科·希伯理8 分钟前
机器学习第二天(KNN)
人工智能·机器学习
书生的梦9 分钟前
《神经网络与深度学习》学习笔记(三):Transformer 模型
深度学习·神经网络·学习