PyTorch:基于MNIST的手写数字识别

文章目录

    • [1. 深度学习与PyTorch简介](#1. 深度学习与PyTorch简介)
    • [2. 环境配置与数据准备](#2. 环境配置与数据准备)
      • [2.1 环境检查](#2.1 环境检查)
      • [2.2 数据加载与预处理](#2.2 数据加载与预处理)
      • [2.3 数据可视化](#2.3 数据可视化)
      • [2.4 数据批量加载](#2.4 数据批量加载)
    • [3. 神经网络模型设计](#3. 神经网络模型设计)
      • [3.1 设备选择](#3.1 设备选择)
      • [3.2 神经网络架构](#3.2 神经网络架构)
      • [3.3 模型实例化](#3.3 模型实例化)
    • [4. 训练与评估流程](#4. 训练与评估流程)
      • [4.1 训练函数](#4.1 训练函数)
      • [4.2 测试函数](#4.2 测试函数)
    • [5. 损失函数配置](#5. 损失函数配置)
    • [6. 模型训练与评估](#6. 模型训练与评估)
      • [6.1 优化器配置](#6.1 优化器配置)
      • [6.2 单次训练与测试](#6.2 单次训练与测试)
      • [6.3 多轮训练(可选)](#6.3 多轮训练(可选))
    • [7. 提高准确率的优化方式](#7. 提高准确率的优化方式)

1. 深度学习与PyTorch简介

深度学习作为机器学习的重要分支,已在计算机视觉、自然语言处理等领域取得了显著成果。PyTorch是由Facebook开源的深度学习框架,以其动态计算图和直观的API设计而广受欢迎。本文以经典的MNIST手写数字数据集为例,展示如何利用PyTorch框架构建并训练深度学习模型。

2. 环境配置与数据准备

2.1 环境检查

首先检查PyTorch及相关库的版本,确保环境配置正确:

python 复制代码
import torch
import torchvision
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt

print(torch.__version__)
print(torchaudio.__version__)
print(torchvision.__version__)

2.2 数据加载与预处理

MNIST数据集包含60,000个训练样本和10,000个测试样本,每个样本为28×28像素的灰度手写数字图像。

python 复制代码
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

参数

  • root:数据存储路径
  • train:是否为训练集
  • download:是否自动下载
  • transform:数据预处理转换,ToTensor()将PIL图像转换为张量并归一化到[0,1]

2.3 数据可视化

我们可以查看数据集的样本分布:

python 复制代码
print(len(training_data))

figure = plt.figure()
for i in range(9):
    img, label = training_data[i + 59000]
    figure.add_subplot(3, 3, i + 1)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2.4 数据批量加载

使用DataLoader实现数据的批量加载和随机打乱:

python 复制代码
# 增加批次大小
train_dataloader = DataLoader(training_data, batch_size=128)  # 增大batch size
test_dataloader = DataLoader(test_data, batch_size=128)

for X, y in test_dataloader:
    print(f"Shape of X[N,C,H,W]:{X.shape}")
    print(f"Shape of y:{y.shape} {y.dtype}")
    break

3. 神经网络模型设计

3.1 设备选择

根据可用硬件选择计算设备:

python 复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

3.2 神经网络架构

设计一个包含多个全连接层的深度神经网络:

python 复制代码
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 10
        self.flatten = nn.Flatten()
        原始架构
        self.hidden1 = nn.Linear(28 * 28, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)
        
    
    def forward(self, x):
        # 原始前向传播
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.sigmoid(x)
        x = self.hidden2(x)
        x = torch.sigmoid(x)
        return x

3.3 模型实例化

python 复制代码
model = NeuralNetwork().to(device)
print(model)

4. 训练与评估流程

4.1 训练函数

python 复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model.forward(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if batch_size_num % 100 == 0:
            print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1

训练步骤

  1. model.train():设置为训练模式(启用Dropout)
  2. 前向传播计算预测值
  3. 计算损失函数值
  4. optimizer.zero_grad():清空梯度
  5. loss.backward():反向传播计算梯度
  6. optimizer.step():更新模型参数

4.2 测试函数

python 复制代码
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss = loss_fn(pred, y)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size

    print(f"Test result:\n Accuracy:{(100 * correct):.2f}%, Avg loss: {test_loss}")

测试要点

  • model.eval():设置为评估模式(禁用Dropout)
  • torch.no_grad():禁用梯度计算,节省内存
  • pred.argmax(1):获取预测类别

5. 损失函数配置

python 复制代码
loss_fn = nn.CrossEntropyLoss()

损失函数说明

  • 使用CrossEntropyLoss,适用于多分类问题
  • 结合了LogSoftmax和NLLLoss,直接输出分类概率

6. 模型训练与评估

6.1 优化器配置

python 复制代码
# 原始优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

6.2 单次训练与测试

python 复制代码
train(train_dataloader, model, loss_fn, optimizer)
test(train_dataloader, model, loss_fn)

6.3 多轮训练(可选)

python 复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n----------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

7. 提高准确率的优化方式

  1. 层数增加:从2层隐藏层增加到3层,增强模型表达能力
  2. 神经元增加:第一层从128个神经元增加到512个
  3. 激活函数:用ReLU替代sigmoid,缓解梯度消失问题
  4. 正则化:添加Dropout层(0.2丢弃率),防止过拟合
  5. 改进优化器:降低学习率
python 复制代码
        # 改进架构
        self.hidden1 = nn.Linear(28 * 28, 512)  # 增加神经元
        self.dropout1 = nn.Dropout(0.2)  # 添加Dropout
        self.hidden2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(0.2)  # 添加Dropout
        self.hidden3 = nn.Linear(256, 128)  # 增加一层
        self.out = nn.Linear(128, 10)
python 复制代码
        # 改进的前向传播
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.relu(x)  # 使用ReLU替代sigmoid
        x = self.dropout1(x)  # 训练时随机丢弃
        x = self.hidden2(x)
        x = torch.relu(x)  # 使用ReLU替代sigmoid
        x = self.dropout2(x)  # 训练时随机丢弃
        x = self.hidden3(x)
        x = torch.relu(x)
        x = self.out(x)
python 复制代码
# 改进优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 降低学习率
相关推荐
深蓝电商API18 小时前
Scrapy爬虫限速与并发控制最佳实践
爬虫·python·scrapy
Derrick__118 小时前
淘宝MD5爬虫
爬虫·python
薛定谔的猫198218 小时前
llama-index Embedding 落地到 RAG 系统
开发语言·人工智能·python·llama-index
翱翔的苍鹰19 小时前
神经网络中损失函数(Loss Function)介绍
人工智能·深度学习·神经网络
元智启19 小时前
企业AI应用面临“敏捷响应”难题:快速变化的业务与相对滞后的智能如何同步?
人工智能·深度学习·机器学习
nimadan1220 小时前
**手机小说扫榜工具2025推荐,精准追踪榜单动态与题材风向
python·智能手机
编程武士20 小时前
Python 各版本主要变化速览
开发语言·python
Hcoco_me20 小时前
大模型面试题63:介绍一下RLHF
人工智能·深度学习·机器学习·chatgpt·机器人
傻啦嘿哟20 小时前
Python中的@property:优雅控制类成员访问的魔法
前端·数据库·python
向量引擎小橙20 小时前
“2026数据枯竭”警报拉响:合成数据如何成为驱动AI进化的“新石油”?
大数据·人工智能·深度学习·集成学习