用PyTorch搭建卷积神经网络实现MNIST手写数字识别

用PyTorch搭建卷积神经网络实现MNIST手写数字识别

在深度学习领域,卷积神经网络(Convolutional Neural Network,简称CNN)是处理图像数据的强大工具。它通过卷积层、池化层和全连接层等组件,自动提取图像特征,在图像分类、目标检测等任务中表现卓越。本文将使用PyTorch框架,搭建一个CNN模型来实现MNIST手写数字识别,并详细解析每一步代码。

一、MNIST数据集介绍

MNIST数据集是深度学习领域经典的入门数据集,包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28x28像素,并且已经做了居中处理,这在一定程度上减少了预处理的工作量,能够加快模型的训练和运行速度。

二、环境准备与数据加载

2.1 导入必要的库

python 复制代码
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

上述代码导入了PyTorch的核心库、神经网络模块、数据加载工具以及用于图像数据处理和数据集管理的库。

2.2 下载并加载数据集

python 复制代码
training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)

通过datasets.MNIST函数分别下载训练集和测试集。root参数指定数据下载的路径;train=True表示下载训练集数据,train=False则表示下载测试集数据;download=True确保如果数据尚未下载,会自动进行下载;transform=ToTensor()将图像数据转换为PyTorch能够处理的张量格式。

2.3 数据可视化

python 复制代码
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):
    img, label = training_data[i + 59000]
    figure.add_subplot(3, 3, i + 1)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

这段代码使用matplotlib库展示了训练数据集中的部分手写数字图像,通过plt.imshow函数将张量格式的图像数据可视化,直观感受MNIST数据集的内容。

2.4 创建数据加载器

python 复制代码
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

DataLoader用于将数据集打包成批次,batch_size参数指定每个批次包含的数据样本数量。将数据集分成批次进行训练,能够有效减少内存使用,并提高训练速度。

三、设备配置

python 复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这段代码检测当前设备是否支持GPU(CUDA)或苹果M系列芯片的GPU(MPS),如果都不支持,则使用CPU进行计算。后续模型和数据都会被移动到选定的设备上运行,以充分利用硬件资源加速训练。

四、定义卷积神经网络模型

python 复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU()
        )
        self.out = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output

在这个自定义的CNN类中,继承自nn.Module__init__方法中定义了网络的结构:

  • 卷积层(nn.Conv2d :用于提取图像特征,通过设置in_channels(输入通道数)、out_channels(输出通道数,即卷积核个数)、kernel_size(卷积核大小)、stride(步长)和padding(填充)等参数,控制卷积操作。
  • 激活函数层(nn.ReLU:引入非线性,增强网络的表达能力。
  • 池化层(nn.MaxPool2d:对特征图进行下采样,减少数据量和计算量,同时保留主要特征。
  • 全连接层(nn.Linear:将卷积层和池化层提取的特征映射到输出类别(MNIST数据集中有10个数字类别)。

forward方法定义了数据在网络中的前向传播路径,确保数据按照网络结构依次经过各层处理,最终输出预测结果。

五、训练与测试模型

5.1 定义损失函数和优化器

python 复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

nn.CrossEntropyLoss是适用于多分类任务的交叉熵损失函数,用于计算模型预测结果与真实标签之间的差距。torch.optim.Adam是一种常用的优化器,通过调整模型的参数(model.parameters())来最小化损失函数,lr参数设置学习率,控制参数更新的步长。

5.2 训练函数

python 复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if batch_size_num % 100 == 0:
            print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')
        batch_size_num += 1

在训练函数中:

  • model.train()将模型设置为训练模式,此时模型中的一些层(如Dropout层)会按照训练规则工作。
  • 遍历数据加载器中的每一个批次数据,将数据和标签移动到指定设备上。
  • 通过模型进行预测,计算损失值。
  • 使用optimizer.zero_grad()清零梯度,loss.backward()进行反向传播计算梯度,optimizer.step()根据梯度更新模型参数。
  • 每隔100个批次,打印当前的损失值,以便观察训练过程中的损失变化。

5.3 测试函数

python 复制代码
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

测试函数中:

  • model.eval()将模型设置为测试模式,关闭一些在训练过程中起作用但在测试时不需要的操作(如Dropout)。
  • 使用with torch.no_grad()上下文管理器,关闭梯度计算,因为在测试阶段不需要更新模型参数,这样可以节省计算资源。
  • 遍历测试数据,计算每个批次的损失值并累加,同时统计预测正确的样本数量。
  • 最后计算并打印测试集上的平均损失和准确率,评估模型的性能。

5.4 执行训练和测试

python 复制代码
epoch = 9
for i in range(epoch):
    print(i + 1)
    train(train_dataloader, model, loss_fn, optimizer)

test(test_dataloader, model, loss_fn)

通过设置训练轮数(epoch),循环调用训练函数进行模型训练,每一轮训练结束后,调用测试函数评估模型在测试集上的性能。

六、总结

本文通过详细的代码解析,展示了如何使用PyTorch搭建一个简单的卷积神经网络来实现MNIST手写数字识别任务。从数据加载、模型定义,到训练和测试,每一个步骤都体现了CNN在图像分类任务中的核心思想和实现方法。通过不断调整模型结构、超参数等,还可以进一步提升模型的性能。卷积神经网络在图像领域的应用远不止于此,它在更复杂的图像任务和其他领域也有着广泛的应用前景,希望本文能为大家深入学习深度学习提供一个良好的开端。

相关推荐
Echo``2 分钟前
2:点云处理—3D相机开发
人工智能·笔记·数码相机·算法·计算机视觉·3d·视觉检测
create1720 分钟前
使用 AI 如何高效解析视频内容?生成思维导图或分时段概括总结
人工智能·aigc·语音识别·ai写作
TIANE-Kimmy30 分钟前
FID和IS的区别
人工智能·深度学习·计算机视觉
知舟不叙1 小时前
使用OpenCV 和 Dlib 实现人脸融合技术
人工智能·opencv·计算机视觉·人脸融合
碣石潇湘无限路1 小时前
【AI】基于生活案例的LLM强化学习(入门帖)
人工智能·经验分享·笔记·生活·openai·强化学习
HumbleSwage2 小时前
深度学习中常用的符号表达式
人工智能·深度学习
缘友一世2 小时前
深度学习系统学习系列【3】之血液检测项目
人工智能·深度学习·学习
bullnfresh2 小时前
神经网络语言模型(NNLM)的原理与实现
人工智能·神经网络·语言模型
归去_来兮2 小时前
基于主成分分析(PCA)的数据降维
人工智能·数据分析·数据降维·主成分分析(pca)
大势智慧2 小时前
世界无人机大会将至,大势智慧以“AI+实景三维”赋能低空经济
人工智能·无人机