MNIST手写数字识别是深度学习领域的经典入门案例,本文将基于PyTorch框架,从零搭建一个卷积神经网络(CNN)来完成这一任务,详细讲解数据加载、网络构建、模型训练与测试的全流程。
一、项目背景与技术栈
MNIST数据集包含70000张28×28像素的手写数字灰度图片,其中60000张作为训练集,10000张作为测试集,任务目标是将图片分类为0-9这10个数字。
本次实战使用的核心技术栈:
• Python:基础编程语言
• PyTorch:主流深度学习框架,提供便捷的神经网络构建、训练接口
• torchvision:PyTorch官方视觉工具库,用于加载MNIST数据集和数据预处理
二、代码实现全解析
1. 环境与库导入
首先导入所需的核心库,包括PyTorch核心模块、数据加载器、数据集和数据转换工具:
python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
2. 加载MNIST数据集
使用torchvision.datasets下载并加载MNIST数据集,通过ToTensor()将图片从PIL格式转换为张量(Tensor),同时归一化到0-1区间:
python
# 下载训练数据集
training_data = datasets.MNIST(
root='data', # 数据集保存路径
train=True, # 标记为训练集
download=True, # 本地无数据时自动下载
transform=ToTensor(), # 数据转换:PIL→Tensor
)
# 下载测试数据集
test_data = datasets.MNIST(
root='data',
train=False, # 标记为测试集
download=True,
transform=ToTensor(),
)
3. 构建数据加载器(DataLoader)
DataLoader将数据集按批次(batch)划分,既减少内存占用,又能提升训练效率。这里设置批次大小为64:
python
training_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
# 验证数据形状
for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}") # 输出:[64, 1, 28, 28](批次量、通道数、高、宽)
print(f"Shape of y: {y.shape} {y.dtype}") # 输出:[64] torch.int64(标签形状、类型)
break
4. 设备配置(CPU/GPU自动适配)
PyTorch支持自动检测并使用GPU(包括NVIDIA CUDA和苹果M系列芯片的MPS),提升训练速度:
python
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"using {device} device")
5. 构建卷积神经网络(CNN)
设计一个多层卷积神经网络,包含卷积层、激活层、池化层和全连接层,核心思路是通过卷积提取图像特征,最终通过全连接层完成分类:
python
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 第一层卷积+激活+池化
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
nn.ReLU(), # 非线性激活,增加模型表达能力
nn.MaxPool2d(kernel_size=2), # 池化层,降维并保留关键特征
)
# 第二层卷积(双层卷积)+激活+池化
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.Conv2d(32, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
# 第三层卷积+激活
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 5, 1, 2),
nn.ReLU(),
)
# 全连接层,映射到10个分类(0-9)
self.out = nn.Linear(64*7*7, 10)
def forward(self, x):
# 前向传播:依次通过各层
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x.view(x.size(0), -1) # 展平:将多维特征图转为一维向量
output = self.out(x)
return output
# 初始化模型并移至指定设备(CPU/GPU)
model = CNN().to(device)
print(model)
6. 定义训练与测试函数
(1)训练函数
训练过程的核心是:前向传播计算预测值→计算损失→反向传播更新参数:
python
def train(dataloader, model, loss_fn, optimizer):
model.train() # 切换为训练模式(启用梯度更新)
batch_size_num = 1
for X, y in dataloader:
X, y = X.to(device), y.to(device)
# 前向传播
pred = model(X)
loss = loss_fn(pred, y)
# 反向传播
optimizer.zero_grad() # 梯度清零(避免累积)
loss.backward() # 计算梯度
optimizer.step() # 更新参数
# 打印训练进度
loss = loss.item()
if batch_size_num % 100 == 0:
print(f"loss: {loss:>7f} [number:{batch_size_num}]")
batch_size_num += 1
(2)测试函数
测试过程关闭梯度计算,仅评估模型精度和平均损失:
python
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval() # 切换为测试模式(禁用梯度更新)
test_loss, correct = 0, 0
with torch.no_grad(): # 关闭梯度计算,节省内存
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model.forward(X)
test_loss += loss_fn(pred, y).item()
# 计算正确预测数:取预测概率最大的类别与真实标签对比
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
# 计算平均损失和准确率
test_loss /= num_batches
correct /= size
print(f'TEST RESULT: \t Accuracy: {(100*correct)}%, Avg loss: {test_loss}')
7. 模型训练与评估
设置损失函数、优化器和训练轮数,开始训练并验证:
python
# 定义损失函数(交叉熵损失,适用于多分类)
loss_fn = nn.CrossEntropyLoss()
# 定义优化器(Adam优化器,自适应学习率)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练轮数
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-----------------")
train(training_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
三、关键知识点解读
-
训练/测试模式切换:model.train()启用Dropout、BatchNorm等层的训练行为,model.eval()则固定这些层的参数,避免测试时参数变化。
-
梯度清零:optimizer.zero_grad()必须在反向传播前执行,否则梯度会累积,导致参数更新错误。
-
展平操作:x.view(x.size(0), -1)将卷积层输出的多维特征图转为一维向量,适配全连接层的输入格式。
-
Adam优化器:相比传统SGD,Adam自适应调整学习率,收敛速度更快,是深度学习中常用的优化器。
四、运行结果与优化方向
- 预期结果
训练10轮后,模型在测试集上的准确率可达98%以上,平均损失逐步降低,说明模型有效学习了手写数字的特征。

- 优化方向
• 数据增强:添加随机旋转、平移、缩放等操作,提升模型泛化能力;
• 调整网络结构:增加卷积层/全连接层、调整卷积核数量/大小;
• 超参数调优:调整学习率、批次大小、训练轮数,或尝试不同优化器(如SGD、RMSprop);
• 正则化:添加Dropout层、L2正则化,防止过拟合。
五、总结
本文通过PyTorch实现了基于CNN的MNIST手写数字识别,完整覆盖了数据加载、网络构建、训练测试的全流程。对于深度学习入门者而言,这个案例能帮助理解卷积神经网络的核心原理和PyTorch的基本使用方式。在此基础上,可进一步探索更复杂的数据集(如CIFAR-10)和网络结构(如ResNet、VGG),深化对深度学习的理解。