PyTorch模型训练全流程详解

完整且标准的 PyTorch 模型训练模板

python 复制代码
import torchvision
from torch.utils.tensorboard import SummaryWriter
from model import * # 从外部 model.py 导入自定义模型 Tudui
from torch import nn
from torch.utils.data import DataLoader

# --- 1. 准备数据集 ---
# 加载 CIFAR10 训练集,transform 确保图片转为 Tensor
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 加载 CIFAR10 测试集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 获取数据集长度,用于后续打印和计算准确率
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# --- 2. 利用 DataLoader 分批次加载数据 ---
# batch_size=64 表示每次训练输入 64 张图片
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# --- 3. 初始化模型、损失函数和优化器 ---
tudui = Tudui() # 实例化模型

# 交叉熵损失函数,常用于多分类任务
loss_fn = nn.CrossEntropyLoss()

# 学习率:$1e-2 = 0.01$
learning_rate = 1e-2
# 随机梯度下降优化器,负责更新模型权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# --- 4. 设置训练计数器与参数 ---
total_train_step = 0 # 记录训练总步数(总图片数/batch_size)
total_test_step = 0  # 记录测试总次数(通常等于 epoch 数)
epoch = 10           # 训练 10 轮

# 初始化 TensorBoard 写入器
writer = SummaryWriter("../logs_train")

# --- 5. 核心训练与测试循环 ---
for i in range(epoch):
    print("-------第 {} 轮训练开始-------".format(i+1))

    # --- 训练步骤 ---
    tudui.train() # 将模型设置为训练模式(影响 Dropout 和 BatchNorm 层)
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)     # 前向传播:计算输出
        loss = loss_fn(outputs, targets) # 计算损失

        # 优化器"标准三步走":
        optimizer.zero_grad() # 1. 梯度清零
        loss.backward()       # 2. 反向传播,计算梯度
        optimizer.step()      # 3. 根据梯度更新权重

        total_train_step += 1
        # 每训练 100 次打印一次状态并记录 TensorBoard
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # --- 测试/验证步骤 ---
    tudui.eval() # 将模型设置为评估模式
    total_test_loss = 0
    total_accuracy = 0
    # 测试时不需要计算梯度,节省内存和算力
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item() # 累加测试集损失
            
            # 准确率计算:argmax(1) 找出预测得分最高的索引,与标签比对
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy

    # 打印一轮训练后的整体表现
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
    
    # 记录测试结果到 TensorBoard
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    total_test_step += 1

    # 保存每一轮训练后的模型
    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")

writer.close()
1. 训练 vs 验证周期

代码在每轮训练(epoch)结束后都会在完整的测试集上跑一遍。这有助于我们观察模型是否过拟合(即训练集 Loss 下降,但测试集正确率不再上升甚至下降)。

2. 前向传播与反向传播
  • Forward : outputs = tudui(imgs)。数据流过层层神经元。

  • Backward : loss.backward()。这是根据误差方向计算各层权重的"责任"大小。

  • Step : optimizer.step()。就像在损失的山坡上向下迈一小步。

3. 核心计算技巧:argmax(1)

由于分类任务的输出是一个包含 10 个类别的概率向量(Logits),argmax(1) 的作用是锁定得分最高的那一个位置。比如 [0.1, 0.8, 0.1]argmax1

4. tudui.train()tudui.eval()

这两行经常被新手忽略。虽然在这个简单的模型中可能不影响结果,但如果你的模型包含 Dropout (随机丢弃神经元)或 Batch Normalization(批归一化),这两行代码能确保模型在训练时"随机学习",在测试时"全力以赴"。

相关推荐
数据法师8 小时前
Sora退场,GPT Image 2.0封神!免费不限次还支持中文!
人工智能·gpt·计算机视觉
2601_957780848 小时前
GPT-5.5时代:从“指令集“到“任务契约“的Prompt工程范式迁移
大数据·人工智能·gpt·架构·prompt
扬帆破浪8 小时前
免费开源AI软件.桌面单机版,可移动的AI知识库,察元 AI桌面版:本地离线知识库的第一份 PDF 引用气泡是怎么连回原文的
人工智能·pdf
少许极端8 小时前
AI修炼记3-RAG
人工智能·ai·原型模式·rag
乔江seven8 小时前
【跟李沐学AI 】23 实战Kaggle:图象分类(CIFAR-10)
人工智能·深度学习·kaggle·cifar-10
2601_958352908 小时前
手撕环境噪音:双麦降噪模块AN-93上板实测,降噪36dB是真是假?
人工智能·音视频·嵌入式·降噪
乔江seven8 小时前
【跟李沐学AI】24 狗的品种识别(ImageNet Dogs)
人工智能·深度学习·计算机视觉·微调·imagenetdogs
AC赳赳老秦8 小时前
全链路自动化巡检:用 OpenClaw 实现服务器 - 应用 - 数据库全链路巡检,自动生成报告与整改建议
服务器·数据库·人工智能·深度学习·自动化·deepseek·openclaw
求学中--8 小时前
DeepSeek V4 API实战:从零搭建AI编程助手全流程
人工智能·ai编程