Pytorch完整的模型训练流程

一个完整的 PyTorch 模型训练流程通常遵循一个标准的"套路"。为了让你一目了然,我将通过一个经典的 MNIST 手写数字识别 任务,为你展示从数据准备到模型保存的全流程。

这个流程可以概括为 6 个核心步骤

1. 导入库与设置设备

首先,我们需要导入必要的库,并检查是否有 GPU 可用。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 设置设备:优先使用 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

2. 数据准备 (Data Preparation)

这是"喂食"给模型的数据。我们需要定义数据预处理方式,并加载数据集。

python 复制代码
# 定义数据预处理:将图片转为张量并进行归一化
transform = transforms.Compose([
    transforms.ToTensor(), # 转为 Tensor 并归一化到 [0,1]
    transforms.Normalize((0.1307,), (0.3081,)) # MNIST 的均值和标准差
])

# 加载训练集和测试集
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器 (DataLoader)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)

3. 定义模型 (Model Definition)

我们定义一个简单的卷积神经网络(CNN)。

python 复制代码
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1) # 输入通道1(灰度图),输出32
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout2d(0.25)
        self.fc1 = nn.Linear(9216, 128) # 全连接层
        self.fc2 = nn.Linear(128, 10)  # 输出10类 (0-9)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2) # 降维
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1) # 展平
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# 实例化模型并移动到设备 (GPU/CPU)
model = SimpleCNN().to(device)

4. 配置损失函数与优化器

这是模型的"指南针"和"导航员"。

python 复制代码
# 损失函数:交叉熵损失,适用于分类任务
criterion = nn.CrossEntropyLoss()

# 优化器:Adam,自动调整学习率
optimizer = optim.Adam(model.parameters(), lr=0.001)

5. 训练与验证循环 (Training Loop)

这是最核心的部分,包含前向传播、反向传播和参数更新。

python 复制代码
def train(epoch):
    model.train() # 切换到训练模式 (开启 Dropout 等)
    for batch_idx, (data, target) in enumerate(train_loader):
        # 1. 数据移到设备
        data, target = data.to(device), target.to(device)
        
        # 2. 梯度清零
        optimizer.zero_grad()
        
        # 3. 前向传播
        output = model(data)
        loss = criterion(output, target)
        
        # 4. 反向传播与参数更新
        loss.backward()
        optimizer.step()
        
        # 打印日志
        if batch_idx % 100 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                  f"({100. * batch_idx / len(train_loader):.0f}%)] Loss: {loss.item():.6f}")

def test():
    model.eval() # 切换到评估模式 (关闭 Dropout)
    correct = 0
    with torch.no_grad(): # 禁用梯度计算,节省内存
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # 统计正确率
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f"\nTest Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n")
    return accuracy

6. 执行训练与保存模型

启动训练,并保存最终结果。

python 复制代码
# 开始训练
epochs = 5
best_acc = 0
for epoch in range(1, epochs + 1):
    train(epoch)
    acc = test()
    # 保存最佳模型
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "mnist_best_model.pth")

print(f"训练完成!最佳准确率: {best_acc:.2f}%")

📌 流程总结图

为了方便记忆,你可以将这个流程看作一个循环:

  1. 准备数据Dataset + DataLoader
  2. 构建模型 :继承 nn.Module
  3. 设置损失与优化器nn.CrossEntropyLoss() + optim.Adam
  4. 循环体
    • optimizer.zero_grad()
    • loss = model(data) (前向)
    • loss.backward() (反向)
    • optimizer.step() (更新)
  5. 验证与保存torch.save(model.state_dict())

这就是 PyTorch 最标准的训练模板,你可以将这个结构套用到大多数图像分类任务中。

相关推荐
GinoInterpreter1 小时前
什么是翻译的去中心化?
人工智能·自然语言处理·去中心化·区块链·机器翻译·机器翻译模型·机器翻译引擎
码农小白AI1 小时前
IACheck AI报告文档审核:高端制造合规新助力,保障标准引用报告质量
大数据·人工智能·制造
_YiFei2 小时前
哪个降论文AI率工具最好用?
人工智能·深度学习·神经网络
放下华子我只抽RuiKe52 小时前
机器学习全景指南-直觉篇——基于距离的 K-近邻 (KNN) 算法
人工智能·gpt·算法·机器学习·语言模型·chatgpt·ai编程
kisshuan123962 小时前
[特殊字符]【深度学习】DA3METRIC-LARGE单目深度估计算法详解
人工智能·深度学习·算法
sali-tec2 小时前
C# 基于OpenCv的视觉工作流-章33-Blod分析
图像处理·人工智能·opencv·算法·计算机视觉
老星*2 小时前
Trae-cn一句话安装OpenClaw:AI智能体框架快速部署指南
人工智能·编辑器
昨夜见军贴06162 小时前
IACheck结合AI报告审核:轨道扣件横向阻力检测报告确保无误差
人工智能
Qt学视觉2 小时前
AI2-Paddle环境搭建
c++·人工智能·python·opencv·paddle
泰迪智能科技2 小时前
分享|高校必备三大实训管理平台,助力高校人工智能、大数据、商务数据分析人才培养
大数据·人工智能·数据分析