使用 PyTorch 实现 MNIST 手写数字分类的 CNN 模型

手写数字识别是计算机视觉领域的经典任务,而 MNIST 数据集提供了一个标准化的基准。本文将使用 PyTorch 框架构建一个标准的卷积神经网络(CNN),对 MNIST 数据集进行分类,并展示完整的训练和测试流程。

1. 环境准备与数据加载

在深度学习中,数据预处理是至关重要的步骤。MNIST 数据集包含 28×28 灰度图像,训练集 60000 张,测试集 10000 张。我们使用 torchvision 提供的工具来加载数据,并进行标准化处理:

复制代码
import torch
from torchvision import datasets, transforms

# 数据预处理:转换为 Tensor 并标准化到 [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

说明

  • ToTensor() 会将图像从 [0, 255] 转换为 [0, 1] 的浮点数张量。

  • Normalize(mean, std) 将张量标准化为 (x - mean) / std),有助于加快训练收敛。


2. 构建卷积神经网络

我们定义一个两层卷积的 CNN 模型,结构如下:

  • 卷积层1:输入 1 通道 → 输出 32 通道,卷积核 3×3

  • 卷积层2:输入 32 通道 → 输出 64 通道,卷积核 3×3

  • 池化层:使用 2×2 最大池化

  • 全连接层1:将特征映射展平后映射到 128 个神经元

  • 输出层:映射到 10 类数字

    import torch.nn as nn
    import torch.nn.functional as F

    class CNN(nn.Module):
    def init(self):
    super(CNN, self).init()
    self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
    self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
    self.pool = nn.MaxPool2d(2, 2)
    self.fc1 = nn.Linear(64 * 7 * 7, 128)
    self.fc2 = nn.Linear(128, 10)

    复制代码
      def forward(self, x):
          x = F.relu(self.conv1(x))
          x = self.pool(x)
          x = F.relu(self.conv2(x))
          x = self.pool(x)
          x = x.view(x.size(0), -1)
          x = F.relu(self.fc1(x))
          x = self.fc2(x)
          return x

说明

  • view(x.size(0), -1) 将二维特征图展平为向量,便于全连接层处理。

  • ReLU 激活函数增加非线性能力。


3. 定义训练环境

在训练前,需要将模型移动到 GPU(若可用),并定义损失函数与优化器:

复制代码
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

说明

  • CrossEntropyLoss 适用于多分类任务。

  • Adam 优化器在大多数场景下表现稳定且收敛速度快。


4. 训练函数

训练函数包括以下步骤:

  1. 将模型设为训练模式 model.train()

  2. 遍历训练数据

  3. 前向传播 → 计算损失 → 反向传播 → 更新参数

  4. 每 100 个 batch 输出一次训练信息

    def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    复制代码
         if batch_idx % 100 == 0:
             print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                   f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

5. 测试函数

测试阶段不更新模型参数,只计算平均损失和准确率:

复制代码
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')

说明

  • 使用 torch.no_grad() 禁止梯度计算,提高测试效率并减少显存占用。

  • argmax 获取预测类别。


6. 执行训练与测试

设定训练轮数 num_epochs,依次调用训练和测试函数:

复制代码
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)

训练与测试结果示例

复制代码
Train Epoch: 1 [0/60000 (0%)]	Loss: 0.384921
...
Test set: Average loss: 0.0543, Accuracy: 9812/10000 (98.12%)

7. 总结

本文使用 PyTorch 实现了一个标准 CNN 模型完成 MNIST 手写数字分类任务,模型结构简单、易于理解,适合作为入门和实验。经过 5 个训练周期,模型可以达到 约 98% 的测试准确率,显示出卷积神经网络在图像分类中的有效性。

后续可以进一步优化:

  • 增加卷积层或残差结构

  • 使用数据增强提升泛化能力

  • 尝试学习率调度器(LR scheduler)或其他优化器

相关推荐
BHXDML2 小时前
基于卷积神经网络通用手写体识别应用实验
人工智能·神经网络·cnn
WJSKad12354 小时前
火腿切片表面缺陷检测与分类_YOLOv26模型实现与优化详解
yolo·分类·数据挖掘
爱学习的张大4 小时前
transform基础练习(从细节里面理解)
人工智能·pytorch·深度学习
轴测君4 小时前
卷积神经网络的开端:LeNet−5
人工智能·神经网络·cnn
AI街潜水的八角4 小时前
医学图像算法之基于MK_UNet的肾小球分割系统1:数据集说明(含下载链接)
pytorch·深度学习
果粒蹬i5 小时前
你的第一个神经网络:用PyTorch/Keras实现手写数字识别
pytorch·神经网络·keras
AI即插即用5 小时前
即插即用系列 | AAAI 2025 Mesorch:CNN与Transformer的双剑合璧:基于频域增强与自适应剪枝的篡改定位
人工智能·深度学习·神经网络·计算机视觉·cnn·transformer·剪枝
Faker66363aaa6 小时前
YOLOv26樱桃缺陷检测与分类算法实现含Python源码_计算机视觉
python·yolo·分类
奔袭的算法工程师17 小时前
CRN源码详细解析(4)-- 图像骨干网络之DepthNet和ViewAggregation
人工智能·pytorch·深度学习·目标检测·自动驾驶