使用 PyTorch 实现 MNIST 手写数字分类的 CNN 模型

手写数字识别是计算机视觉领域的经典任务,而 MNIST 数据集提供了一个标准化的基准。本文将使用 PyTorch 框架构建一个标准的卷积神经网络(CNN),对 MNIST 数据集进行分类,并展示完整的训练和测试流程。

1. 环境准备与数据加载

在深度学习中,数据预处理是至关重要的步骤。MNIST 数据集包含 28×28 灰度图像,训练集 60000 张,测试集 10000 张。我们使用 torchvision 提供的工具来加载数据,并进行标准化处理:

复制代码
import torch
from torchvision import datasets, transforms

# 数据预处理:转换为 Tensor 并标准化到 [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

说明

  • ToTensor() 会将图像从 [0, 255] 转换为 [0, 1] 的浮点数张量。

  • Normalize(mean, std) 将张量标准化为 (x - mean) / std),有助于加快训练收敛。


2. 构建卷积神经网络

我们定义一个两层卷积的 CNN 模型,结构如下:

  • 卷积层1:输入 1 通道 → 输出 32 通道,卷积核 3×3

  • 卷积层2:输入 32 通道 → 输出 64 通道,卷积核 3×3

  • 池化层:使用 2×2 最大池化

  • 全连接层1:将特征映射展平后映射到 128 个神经元

  • 输出层:映射到 10 类数字

    import torch.nn as nn
    import torch.nn.functional as F

    class CNN(nn.Module):
    def init(self):
    super(CNN, self).init()
    self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
    self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
    self.pool = nn.MaxPool2d(2, 2)
    self.fc1 = nn.Linear(64 * 7 * 7, 128)
    self.fc2 = nn.Linear(128, 10)

    复制代码
      def forward(self, x):
          x = F.relu(self.conv1(x))
          x = self.pool(x)
          x = F.relu(self.conv2(x))
          x = self.pool(x)
          x = x.view(x.size(0), -1)
          x = F.relu(self.fc1(x))
          x = self.fc2(x)
          return x

说明

  • view(x.size(0), -1) 将二维特征图展平为向量,便于全连接层处理。

  • ReLU 激活函数增加非线性能力。


3. 定义训练环境

在训练前,需要将模型移动到 GPU(若可用),并定义损失函数与优化器:

复制代码
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

说明

  • CrossEntropyLoss 适用于多分类任务。

  • Adam 优化器在大多数场景下表现稳定且收敛速度快。


4. 训练函数

训练函数包括以下步骤:

  1. 将模型设为训练模式 model.train()

  2. 遍历训练数据

  3. 前向传播 → 计算损失 → 反向传播 → 更新参数

  4. 每 100 个 batch 输出一次训练信息

    def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    复制代码
         if batch_idx % 100 == 0:
             print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                   f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

5. 测试函数

测试阶段不更新模型参数,只计算平均损失和准确率:

复制代码
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')

说明

  • 使用 torch.no_grad() 禁止梯度计算,提高测试效率并减少显存占用。

  • argmax 获取预测类别。


6. 执行训练与测试

设定训练轮数 num_epochs,依次调用训练和测试函数:

复制代码
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)

训练与测试结果示例

复制代码
Train Epoch: 1 [0/60000 (0%)]	Loss: 0.384921
...
Test set: Average loss: 0.0543, Accuracy: 9812/10000 (98.12%)

7. 总结

本文使用 PyTorch 实现了一个标准 CNN 模型完成 MNIST 手写数字分类任务,模型结构简单、易于理解,适合作为入门和实验。经过 5 个训练周期,模型可以达到 约 98% 的测试准确率,显示出卷积神经网络在图像分类中的有效性。

后续可以进一步优化:

  • 增加卷积层或残差结构

  • 使用数据增强提升泛化能力

  • 尝试学习率调度器(LR scheduler)或其他优化器

相关推荐
骇城迷影1 天前
Makemore 核心面试题大汇总
人工智能·pytorch·python·深度学习·线性回归
mailangduoduo1 天前
零基础教学连接远程服务器部署项目——VScode版本
服务器·pytorch·vscode·深度学习·ssh·gpu算力
多恩Stone1 天前
【3D AICG 系列-6】OmniPart 训练流程梳理
人工智能·pytorch·算法·3d·aigc
爱吃泡芙的小白白1 天前
深入解析CNN中的BN层:从稳定训练到前沿演进
人工智能·神经网络·cnn·梯度爆炸·bn·稳定模型
水月wwww1 天前
【深度学习】卷积神经网络
人工智能·深度学习·cnn·卷积神经网络
酷酷的崽7981 天前
CANN 开源生态实战:端到端构建高效文本分类服务
分类·数据挖掘·开源
前端摸鱼匠2 天前
YOLOv8 环境配置全攻略:Python、PyTorch 与 CUDA 的和谐共生
人工智能·pytorch·python·yolo·目标检测
纤纡.2 天前
PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战
人工智能·pytorch·python
摘星编程2 天前
CANN ops-nn Pooling算子解读:CNN模型下采样与特征提取的核心
人工智能·神经网络·cnn
子榆.2 天前
CANN 与主流 AI 框架集成:从 PyTorch/TensorFlow 到高效推理的无缝迁移指南
人工智能·pytorch·tensorflow