用 PyTorch 实现 MNIST 手写数字识别:从入门到实践

手写数字识别是机器学习领域的经典入门案例,而 MNIST 数据集则是这个领域的 "Hello World"。本文将带你从零开始,使用 PyTorch 构建一个两层神经网络,完成 MNIST 手写数字的识别任务。无论你是机器学习新手还是想复习基础,这篇教程都能帮助你理解神经网络的基本原理和实现过程。

什么是 MNIST 数据集?

MNIST(Modified National Institute of Standards and Technology)数据集包含 60,000 个训练样本和 10,000 个测试样本,均为 28×28 像素的灰度手写数字图像(0-9)。它之所以成为入门经典,是因为:

  • 数据规模适中,不需要超级计算机也能训练
  • 任务明确(10 分类问题),评价指标简单(准确率)
  • 预处理简单,无需复杂的图像增强

环境准备与依赖库

本次实现基于 PyTorch 框架,需要安装以下依赖:

  • numpy:数值计算
  • torch:PyTorch 核心库
  • torchvision:包含 MNIST 数据集和图像处理工具
  • matplotlib:可视化工具
  • tensorboard:训练过程可视化

安装命令:

bash

复制代码
pip install numpy torch torchvision matplotlib tensorboard

实现步骤详解

1. 超参数配置

在开始之前,我们先集中定义所有可配置的超参数,方便后续调试和优化:

python

运行

复制代码
config = {
    "train_batch_size": 64,  # 训练批次大小
    "test_batch_size": 128,   # 测试批次大小
    "learning_rate": 0.01,    # 初始学习率
    "num_epochs": 20,         # 训练轮次
    "in_dim": 28 * 28,        # 输入维度(28x28像素)
    "n_hidden_1": 300,        # 第一个隐藏层神经元数
    "n_hidden_2": 100,        # 第二个隐藏层神经元数
    "out_dim": 10,            # 输出维度(10个数字类别)
    "log_dir": "logs",        # TensorBoard日志目录
    "data_root": "../data"    # 数据保存路径
}

超参数的选择对模型性能影响很大,后续可以通过调整这些参数来优化模型。

2. 数据加载与预处理

数据预处理是机器学习 pipeline 中至关重要的一步,直接影响模型性能:

python

运行

复制代码
def load_data(data_root, train_batch_size, test_batch_size):
    # 定义预处理流程
    transform = transforms.Compose([
        transforms.ToTensor(),  # 转换为Tensor并归一化到[0,1]
        transforms.Normalize([0.5], [0.5])  # 标准化到[-1,1]
    ])

    # 加载训练集和测试集
    train_dataset = MNIST(
        root=data_root,
        train=True,
        transform=transform,
        download=True
    )
    
    test_dataset = MNIST(
        root=data_root,
        train=False,
        transform=transform
    )

    # 数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True  # 训练时打乱数据
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_batch_size,
        shuffle=False  # 测试时无需打乱
    )
    
    return train_loader, test_loader

预处理说明

  • ToTensor():将图像从 PIL 格式转换为 PyTorch 张量,并将像素值从 [0,255] 缩放到 [0,1]
  • Normalize():标准化处理,公式为(x - mean) / std,这里将数据调整为均值 0、标准差 0.5,最终范围为 [-1,1]
  • DataLoader:提供批量加载、打乱数据、多线程加载等功能

3. 数据可视化

加载数据后,我们可以随机可视化几个样本,验证数据加载是否正确:

python

运行

复制代码
def visualize_samples(test_loader, num_samples=6):
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)
    
    fig = plt.figure(figsize=(8, 4))
    for i in range(num_samples):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(example_data[i][0], cmap='gray')
        plt.title(f"标签: {example_targets[i].item()}")
        plt.xticks([])
        plt.yticks([])
    plt.show()

4. 神经网络模型设计

我们将构建一个包含两个隐藏层的全连接神经网络:

python

运行

复制代码
class MNISTNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(MNISTNet, self).__init__()
        self.flatten = nn.Flatten()  # 展平层
        
        # 第一个隐藏层(带批归一化)
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.BatchNorm1d(n_hidden_1)
        )
        
        # 第二个隐藏层(带批归一化)
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.BatchNorm1d(n_hidden_2)
        )
        
        # 输出层
        self.out = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.flatten(x)  # 展平为1D向量
        x = F.relu(self.layer1(x))  # 第一层+ReLU激活
        x = F.relu(self.layer2(x))  # 第二层+ReLU激活
        x = F.softmax(self.out(x), dim=1)  # 输出层+softmax
        return x

模型设计要点

  • nn.Flatten():将 28×28 的二维图像转换为 784 维的一维向量
  • 批归一化(BatchNorm1d):加速训练收敛,提高稳定性
  • ReLU 激活函数:解决梯度消失问题,引入非线性
  • Softmax 输出:将最后一层输出转换为概率分布(总和为 1)

5. 训练与评估流程

训练过程是模型学习的核心,我们需要定义损失函数、优化器,并实现完整的训练循环:

python

运行

复制代码
def train_and_evaluate(config):
    # 加载数据
    train_loader, test_loader = load_data(**config)
    visualize_samples(test_loader)
    
    # 设备配置(自动选择GPU或CPU)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    # 初始化模型、损失函数和优化器
    model = MNISTNet(** config).to(device)
    criterion = nn.CrossEntropyLoss()  # 交叉熵损失
    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=0.9  # 动量加速收敛
    )
    
    # 训练循环
    writer = SummaryWriter(log_dir=config["log_dir"])  # TensorBoard日志
    losses, eval_acces = [], []
    
    for epoch in range(config["num_epochs"]):
        # 训练阶段
        model.train()  # 训练模式
        train_loss, train_acc = 0.0, 0.0
        
        # 学习率调整
        if epoch % 5 == 0 and epoch != 0:
            optimizer.param_groups[0]['lr'] *= 0.9
            print(f"学习率调整为: {optimizer.param_groups[0]['lr']:.6f}")
        
        for img, label in train_loader:
            img, label = img.to(device), label.to(device)
            
            # 前向传播
            output = model(img)
            loss = criterion(output, label)
            
            # 反向传播与优化
            optimizer.zero_grad()  # 清空梯度
            loss.backward()        # 计算梯度
            optimizer.step()       # 更新参数
            
            # 计算损失和准确率
            train_loss += loss.item()
            _, pred = torch.max(output, 1)
            train_acc += (pred == label).sum().item() / img.size(0)
        
        # 评估阶段
        model.eval()  # 评估模式
        eval_acc = 0.0
        with torch.no_grad():  # 禁用梯度计算
            for img, label in test_loader:
                img, label = img.to(device), label.to(device)
                output = model(img)
                _, pred = torch.max(output, 1)
                eval_acc += (pred == label).sum().item() / img.size(0)
        
        # 记录指标
        avg_train_loss = train_loss / len(train_loader)
        avg_train_acc = train_acc / len(train_loader)
        avg_eval_acc = eval_acc / len(test_loader)
        
        losses.append(avg_train_loss)
        eval_acces.append(avg_eval_acc)
        
        print(f"Epoch [{epoch+1}/{config['num_epochs']}] | "
              f"训练损失: {avg_train_loss:.4f}, 训练准确率: {avg_train_acc:.4f} | "
              f"测试准确率: {avg_eval_acc:.4f}")
    
    # 可视化训练结果
    visualize_training(losses, eval_acces)
    writer.close()

训练关键步骤解析

  1. 设备选择:自动检测并使用 GPU(如有),大幅加速训练
  2. 损失函数:使用交叉熵损失(CrossEntropyLoss),适合多分类问题
  3. 优化器:带动量的 SGD,动量有助于加速收敛并跳出局部最优
  4. 学习率调度:每 5 个 epoch 将学习率乘以 0.9,后期精细化优化
  5. 训练模式与评估模式model.train()model.eval()控制批归一化等层的行为
  6. 梯度管理optimizer.zero_grad()清空梯度,loss.backward()计算梯度,optimizer.step()更新参数

6. 结果可视化

训练结束后,我们可以可视化损失和准确率曲线,直观了解模型性能变化:

python

运行

复制代码
def visualize_training(losses, eval_acces):
    fig = plt.figure(figsize=(10, 4))
    
    # 损失曲线
    plt.subplot(1, 2, 1)
    plt.title('训练损失')
    plt.plot(np.arange(len(losses)), losses, 'b-')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    
    # 准确率曲线
    plt.subplot(1, 2, 2)
    plt.title('测试准确率')
    plt.plot(np.arange(len(eval_acces)), eval_acces, 'g-')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    
    plt.tight_layout()
    plt.show()

模型优化方向

如果想进一步提高准确率,可以尝试以下方法:

  1. 增加网络深度或宽度(但要注意防止过拟合)
  2. 使用更先进的优化器(如 Adam 替代 SGD)
  3. 调整学习率调度策略
  4. 添加 dropout 层防止过拟合
  5. 尝试数据增强(旋转、平移等)

总结

本文详细介绍了使用 PyTorch 实现 MNIST 手写数字识别的完整流程,包括数据加载与预处理、模型设计、训练循环和结果可视化。通过这个案例,我们可以掌握神经网络的基本原理和实现方法:

  • 数据预处理对模型性能的重要性
  • 神经网络的基本组成(线性层、激活函数、批归一化)
  • 训练过程的核心步骤(前向传播、损失计算、反向传播、参数更新)
  • 如何评估模型性能并进行可视化分析
相关推荐
链上日记44 分钟前
WEEX出席迪拜区块链生活2025,担任白金赞助商
人工智能·区块链·生活
灵途科技3 小时前
灵途科技亮相NEPCON ASIA 2025 以光电感知点亮具身智能未来
人工智能·科技·机器人
文火冰糖的硅基工坊4 小时前
[人工智能-大模型-125]:模型层 - RNN的隐藏层是什么网络,全连接?还是卷积?RNN如何实现状态记忆?
人工智能·rnn·lstm
IT90904 小时前
c#+ visionpro汽车行业,机器视觉通用检测程序源码 产品尺寸检测,机械手引导定位等
人工智能·计算机视觉·视觉检测
Small___ming5 小时前
【人工智能数学基础】多元高斯分布
人工智能·机器学习·概率论
渔舟渡简5 小时前
机器学习-回归分析概述
人工智能·机器学习
王哈哈^_^5 小时前
【数据集】【YOLO】目标检测游泳数据集 4481 张,溺水数据集,YOLO河道、海滩游泳识别算法实战训练教程。
人工智能·算法·yolo·目标检测·计算机视觉·分类·视觉检测
桂花饼5 小时前
Sora 2:从视频生成到世界模拟,OpenAI的“终极游戏”
人工智能·aigc·openai·sora 2
007php0075 小时前
某游戏大厂 Java 面试题深度解析(四)
java·开发语言·python·面试·职场和发展·golang·php
wwlsm_zql6 小时前
荣耀YOYO智能体:自动执行与任务规划,开启智能生活新篇章
人工智能·生活