用 PyTorch 轻松实现 MNIST 手写数字识别

用 PyTorch 轻松实现 MNIST 手写数字识别

引言

在深度学习领域,MNIST 数据集就像是 "Hello World" 级别的经典入门项目。它包含大量手写数字图像及对应标签,非常适合新手学习如何搭建和训练神经网络模型。本文将基于 PyTorch 框架,详细拆解如何完成 MNIST 手写数字识别任务,让你轻松入门深度学习实践。

1. 数据加载与预处理

首先,我们利用torchvision库中的datasets.MNIST函数来加载 MNIST 数据集。代码如下:

复制代码
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

在这段代码中,root="data"指定了数据集的存储路径;train=True表示加载训练集,train=False则用于加载测试集;download=True确保如果本地没有数据集,会自动从网络下载;transform=ToTensor()将图像数据转换为 PyTorch 能够处理的张量格式,同时将像素值从 0-255 归一化到 0-1 区间 。

为了直观感受数据集,我们还可以绘制几张图像:

python

复制代码
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):
    img, label = training_data[i + 59000]
    figure.add_subplot(3, 3, i + 1)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
    a = img.squeeze()
plt.show()

上述代码从训练集中选取了 9 张图像,绘制出图像及其对应的标签,方便我们对数据有更直观的认识。

接下来,使用DataLoader对数据集进行封装,以方便后续按批次训练和测试:

复制代码
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

batch_size=64表示每次训练或测试时,模型会同时处理 64 个样本,这有助于提高计算效率和训练稳定性。

2. 模型构建

我们定义一个简单的全连接神经网络类NeuralNetwork

复制代码
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.hidden1 = nn.Linear(28 * 28, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.relu(x)
        x = self.hidden2(x)
        x = torch.relu(x)
        x = self.out(x)
        return x

__init__函数中,nn.Flatten()用于将输入的二维图像张量展平为一维向量;nn.Linear()是全连接层,我们依次构建了两个隐藏层和一个输出层,输出层有 10 个神经元,对应 0-9 这 10 个数字类别。在forward函数中,定义了数据的前向传播过程,包括线性变换和激活函数torch.relu()的应用,激活函数能为模型引入非线性,使其能够学习更复杂的模式。

然后将模型移动到合适的设备(GPU 或 CPU)上:

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = NeuralNetwork().to(device)
print(model)

3. 训练与测试

3.1 训练函数

复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model.forward(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        if batch_size_num % 100 == 0:
            print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1

在训练函数中,首先通过model.train()将模型设置为训练模式,然后遍历数据加载器中的每一批数据。对于每一批数据,将数据和标签移动到指定设备上,进行前向传播计算预测值,通过损失函数nn.CrossEntropyLoss()计算预测值与真实标签之间的损失。接着使用optimizer.zero_grad()清空梯度,loss.backward()进行反向传播计算梯度,最后optimizer.step()根据计算得到的梯度更新模型参数。每训练 100 个批次,打印当前的损失值。

3.2 测试函数

复制代码
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct

测试函数中,先将模型设置为评估模式model.eval(),关闭一些在训练过程中使用的操作(如 Dropout)。在测试过程中,不需要计算梯度,因此使用with torch.no_grad()。通过遍历测试数据加载器,计算模型预测结果与真实标签之间的损失,并统计正确预测的样本数量,最后计算平均损失和准确率并打印输出。

3.3 执行训练与测试

复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n--------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

我们选择交叉熵损失函数nn.CrossEntropyLoss()作为损失计算方式,Adam 优化器torch.optim.Adam()来更新模型参数,学习率设置为 0.01。通过循环 10 个训练周期,不断训练模型,训练完成后进行测试,得到模型在测试集上的准确率和平均损失。

4. 总结

通过上述步骤,我们基于 PyTorch 完成了 MNIST 手写数字识别任务。从数据加载、模型构建,到训练和测试,每个环节都紧密相连。这个项目不仅让我们熟悉了 PyTorch 的基本使用流程,也对神经网络的工作原理有了更直观的认识。后续我们可以通过调整模型结构、超参数等方式进一步优化模型性能,探索更多深度学习的奥秘。

相关推荐
龙虎榜小红牛系统12 分钟前
Python项目源码57:数据格式转换工具1.0(csv+json+excel+sqlite3)
python·json·excel
新加坡内哥谈技术17 分钟前
谷歌最新推出的Gemini 2.5 Flash人工智能模型因其安全性能相较前代产品出现下滑
人工智能
搏博22 分钟前
神经网络在专家系统中的应用:从符号逻辑到连接主义的融合创新
人工智能·深度学习·神经网络·算法·机器学习
regret~25 分钟前
【论文笔记】SOTR: Segmenting Objects with Transformers
论文阅读·python·深度学习
Eric.Lee202130 分钟前
数据集-目标检测系列- 印度人脸 检测数据集 indian face >> DataBall
人工智能·算法·目标检测·计算机视觉·yolo检测·印度人脸检测
CHNMSCS36 分钟前
PyTorch_点积运算
人工智能·pytorch·python
leeseean891 小时前
使用AI 将文本转成视频 工具 介绍
人工智能·音视频
缘友一世1 小时前
深度学习系统学习系列【1】之基本知识
人工智能·深度学习·学习
feng995201 小时前
从巴别塔到通天塔:Manus AI 如何重构多语言手写识别的智能版图
大数据·人工智能·机器学习
Echo``1 小时前
19:常见的Halcon数据格式
java·linux·图像处理·人工智能·windows·机器学习·视觉检测