PyTorch 实现 MNIST 手写数字识别全流程

一、引言

MNIST 数据集是机器学习领域的 "Hello World",包含大量手写数字图片及对应标签。本文将使用 PyTorch 框架,从数据准备、模型构建到训练与可视化,完整实现 MNIST 手写数字识别任务,帮助初学者快速上手深度学习图像分类。

二、环境准备与库导入

首先导入所需的库,包括 NumPy 用于数值计算,PyTorch 相关模块用于构建模型、处理数据,以及 Matplotlib 用于可视化。

python

运行

复制代码
import numpy as np
import torch
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline

三、数据准备

(一)超参数定义

设置批次大小、学习率和训练轮数等超参数,这些参数会影响模型的训练过程和结果。

python

运行

复制代码
train_batch_size = 64
test_batch_size = 128
learning_rate = 0.01
num_epoches = 20

(二)数据预处理

使用 transforms 对数据进行预处理,将图像转为张量并标准化,使模型训练更稳定。然后通过 DataLoader 加载数据集,实现数据的批量读取和打乱。

python

运行

复制代码
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

train_dataset = mnist.MNIST('../data/', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('../data/', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

(三)数据可视化

为了直观了解数据,从测试集中取出部分数据进行可视化展示,查看手写数字的真实样子和标签。

python

运行

复制代码
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])

四、模型构建

定义一个基于 nn.Module 的神经网络类 Net,使用 Sequential 组合网络层,包括展平层、带批量归一化的线性层和激活函数等,最后通过 softmax 输出分类概率。

python

运行

复制代码
class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Net, self).__init__()
        self.flatten = nn.Flatten()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))
        self.out = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.softmax(self.out(x), dim=1)
        return x

五、模型训练与评估

(一)实例化模型与设置优化器

选择运行设备(GPU 或 CPU),实例化模型并移至对应设备,定义损失函数和优化器,这里使用交叉熵损失和 SGD 优化器。

python

运行

复制代码
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net(28 * 28, 300, 100, 10)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

(二)训练循环

在多个 epoch 中训练模型,每个 epoch 包括训练阶段和测试阶段。训练时计算损失并反向传播更新参数,测试时评估模型在测试集上的性能,同时记录损失和准确率用于后续可视化。

python

运行

复制代码
losses = []
acces = []
eval_losses = []
eval_acces = []
writer = SummaryWriter(log_dir='logs', comment='train-loss')

for epoch in range(num_epoches):
    train_loss = 0
    train_acc = 0
    model.train()
    if epoch % 5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.9
    print('学习率:{:.6f}'.format(optimizer.param_groups[0]['lr']))
    for img, label in train_loader:
        img = img.to(device)
        label = label.to(device)
        out = model(img)
        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        writer.add_scalar('Train', train_loss / len(train_loader), epoch)
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc

    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))
    eval_loss = 0
    eval_acc = 0
    model.eval()
    for img, label in test_loader:
        img = img.to(device)
        label = label.to(device)
        img = img.view(img.size(0), -1)
        out = model(img)
        loss = criterion(out, label)
        eval_loss += loss.item()
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        eval_acc += acc

    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
          .format(epoch, train_loss / len(train_loader), train_acc / len(train_loader),
                  eval_loss / len(test_loader), eval_acc / len(test_loader)))

(三)损失可视化

训练完成后,绘制训练损失曲线,直观观察模型训练过程中损失的变化情况。

python

运行

复制代码
plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')

六、总结

本文详细介绍了使用 PyTorch 实现 MNIST 手写数字识别的全流程,包括数据准备、模型构建、训练评估与可视化。通过这个经典任务,能帮助初学者熟悉深度学习图像分类的基本步骤和 PyTorch 的使用方法。在实际应用中,还可进一步优化模型结构、调整超参数或使用数据增强等方法提升模型性能。

相关推荐
忘却的旋律dw1 小时前
使用LLM模型的tokenizer报错AttributeError: ‘dict‘ object has no attribute ‘model_type‘
人工智能·pytorch·python
studytosky3 小时前
深度学习理论与实战:MNIST 手写数字分类实战
人工智能·pytorch·python·深度学习·机器学习·分类·matplotlib
哥布林学者3 小时前
吴恩达深度学习课程三: 结构化机器学习项目 第一周:机器学习策略(二)数据集设置
深度学习·ai
【建模先锋】5 小时前
精品数据分享 | 锂电池数据集(四)PINN+锂离子电池退化稳定性建模和预测
深度学习·预测模型·pinn·锂电池剩余寿命预测·锂电池数据集·剩余寿命
九年义务漏网鲨鱼5 小时前
【大模型学习】现代大模型架构(二):旋转位置编码和SwiGLU
深度学习·学习·大模型·智能体
CoovallyAIHub5 小时前
破局红外小目标检测:异常感知Anomaly-Aware YOLO以“俭”驭“繁”
深度学习·算法·计算机视觉
云雾J视界5 小时前
AI芯片设计实战:用Verilog高级综合技术优化神经网络加速器功耗与性能
深度学习·神经网络·verilog·nvidia·ai芯片·卷积加速器
噜~噜~噜~14 小时前
最大熵原理(Principle of Maximum Entropy,MaxEnt)的个人理解
深度学习·最大熵原理
小女孩真可爱15 小时前
大模型学习记录(五)-------调用大模型API接口
pytorch·深度学习·学习
水月wwww19 小时前
深度学习——神经网络
人工智能·深度学习·神经网络