Handwritten Digit Recognition with a PyTorch RNN

Introduction

Recurrent neural networks (RNNs) are usually seen as the tool of choice for sequential data, such as natural language processing or time-series forecasting. But have you ever considered that an RNN can also handle image classification? This article shows how to build an LSTM-based model in PyTorch for the classic MNIST handwritten-digit recognition task, inspired by the official TensorFlow Keras RNN example.

Why use an RNN for image classification?

Images in the MNIST dataset are 28×28 grayscale. Conventionally we would use a CNN to extract spatial features. But consider another angle: if each row (or column) of pixels is treated as one time step, the whole image becomes a sequence of 28 time steps with 28 features each. This viewpoint gives the RNN a role to play, and also demonstrates its ability to handle non-traditional sequence data.
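To make the row-as-time-step view concrete, here is a minimal sketch (using a random dummy batch as a stand-in for real MNIST images) of the reshape the model performs:

```python
import torch

# A dummy batch of MNIST-shaped images: (batch, channels, height, width)
images = torch.randn(8, 1, 28, 28)

# Drop the channel dimension: each image becomes a sequence of
# 28 time steps (rows), each with 28 features (the pixels in that row).
sequences = images.reshape(-1, 28, 28)
print(sequences.shape)  # torch.Size([8, 28, 28])
```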

Model architecture

The core idea is to have the LSTM "read" the image row by row, much like reading text:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.rnn = nn.LSTM(input_size=28, hidden_size=64, batch_first=True)
        self.batchnorm = nn.BatchNorm1d(64)
        # Inputs here are 2-D (batch, features), so use nn.Dropout;
        # nn.Dropout2d on 2-D input is deprecated.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, input):
        # Reshape (batch, 1, 28, 28) to (batch, 28, 28) for the RNN
        input = input.reshape(-1, 28, 28)
        output, hidden = self.rnn(input)
        
        # Take the output of the last time step
        output = output[:, -1, :]
        output = self.batchnorm(output)
        output = self.dropout1(output)
        output = self.fc1(output)
        output = F.relu(output)
        output = self.dropout2(output)
        output = self.fc2(output)
        return F.log_softmax(output, dim=1)

Key design points:

  • LSTM layer: input size 28 (pixels per row), hidden size 64; it learns the sequential dependencies between rows

  • Last time step: output[:, -1, :] takes the hidden state after the final row, which aggregates information from the whole image

  • Regularization: BatchNorm + Dropout to curb overfitting

  • Classification head: two fully connected layers map the 64-dimensional features to the 10 digit classes
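A quick way to see what the LSTM actually returns (a sketch with a dummy batch, not part of the original article):

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=28, hidden_size=64, batch_first=True)
x = torch.randn(4, 28, 28)      # (batch, rows as time steps, pixels per row)
output, (h_n, c_n) = rnn(x)

print(output.shape)              # torch.Size([4, 28, 64]): hidden state at every row
print(output[:, -1, :].shape)    # torch.Size([4, 64]): last row only
# For a single-layer, unidirectional LSTM, the last time step of `output`
# equals the final hidden state h_n:
print(torch.allclose(output[:, -1, :], h_n[0]))  # True
```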

Data preprocessing

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

This uses the standard MNIST normalization constants (mean 0.1307, standard deviation 0.3081 of the training pixels) to scale pixel values into a range that suits the network.

Training configuration and tips

Optimization strategy:

  • Adadelta instead of SGD: an adaptive learning rate reduces the tuning burden

  • A StepLR scheduler decays the learning rate every epoch to help convergence
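The decay schedule can be checked in isolation. With the script's defaults (lr=0.1, gamma=0.7, step_size=1), the learning rate after 13 decays is 0.1 × 0.7^13 ≈ 0.000969, which matches the lr shown in the final progress bar of the training log. A sketch with a stand-in model:

```python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(28, 10)  # stand-in model, just to have parameters
optimizer = optim.Adadelta(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)  # multiply lr by 0.7 each epoch

lrs = []
for epoch in range(14):
    lrs.append(scheduler.get_last_lr()[0])
    # ... one epoch of training would go here ...
    scheduler.step()

print(lrs[0], lrs[1], lrs[-1])  # 0.1, then 0.07, down to ~0.000969
```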

How to run

# Install dependencies (tqdm is needed for the progress bars)
pip install torch torchvision tqdm

# Train on CPU
python main.py

# Train on an accelerator (GPU etc.)
python main.py --accel

# Quick sanity check (runs a single batch)
python main.py --dry-run

# Save the trained model
python main.py --save-model

Compared with the Keras version, the PyTorch implementation is more flexible:

  • Explicit control: the forward pass, loss computation, and backward pass are all clearly visible

  • Dynamic graph: easy to debug; you can set breakpoints inside forward and inspect tensor shapes

  • Device management: you control the device object by hand, making CPU/GPU switching more transparent
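The script selects its device via the --accel flag and the torch.accelerator API; the same idea in the traditional hand-rolled pattern (a sketch, not the article's code) looks like this:

```python
import torch

# Pick the device once; the rest of the code stays device-agnostic.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(28, 10).to(device)  # move parameters to the device
x = torch.randn(4, 28, device=device)       # allocate each batch on the device
out = model(x)
print(out.shape)  # torch.Size([4, 10])
```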

Summary and extensions

Small as it is, this example shows the flexibility of RNNs. Things to try:

  • A bidirectional LSTM (bidirectional=True) to capture information from the rows above and below at once

  • An attention mechanism instead of simply taking the last time step

  • A GRU instead of the LSTM to reduce the parameter count

  • Applying the model to other serialized image data (e.g. scanned documents)
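As one example, the bidirectional variant is a small change. A hypothetical sketch (note that with bidirectional=True the feature size doubles, and output[:, -1, :] holds the forward direction's last step concatenated with the backward direction's first step):

```python
import torch
import torch.nn as nn

class BiNet(nn.Module):
    """Hypothetical bidirectional variant of Net (not the article's model)."""
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=28, hidden_size=64,
                           batch_first=True, bidirectional=True)
        # The last time step now concatenates both directions: 2 * 64 features.
        self.fc = nn.Linear(2 * 64, 10)

    def forward(self, x):
        x = x.reshape(-1, 28, 28)
        output, _ = self.rnn(x)
        return self.fc(output[:, -1, :])

logits = BiNet()(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])
```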

Full code:

from __future__ import print_function

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from tqdm import tqdm  # progress bars


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.rnn = nn.LSTM(input_size=28, hidden_size=64, batch_first=True)
        self.batchnorm = nn.BatchNorm1d(64)
        # Inputs here are 2-D (batch, features), so use nn.Dropout;
        # nn.Dropout2d on 2-D input is deprecated.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, input):
        # Shape of input is (batch_size,1, 28, 28)
        # converting shape of input to (batch_size, 28, 28)
        # as required by RNN when batch_first is set True
        input = input.reshape(-1, 28, 28)
        output, hidden = self.rnn(input)

        # With batch_first=True, output shape is (batch, seq_len, hidden_size);
        # take the hidden state of the last time step
        output = output[:, -1, :]
        output = self.batchnorm(output)
        output = self.dropout1(output)
        output = self.fc1(output)
        output = F.relu(output)
        output = self.dropout2(output)
        output = self.fc2(output)
        output = F.log_softmax(output, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    # Wrap train_loader with tqdm for a live progress display
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), 
                desc=f'Epoch {epoch}/{args.epochs}', ncols=100)
    
    for batch_idx, (data, target) in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        # Show the current loss and batch index in the bar's postfix
        pbar.set_postfix({
            'loss': f'{loss.item():.6f}',
            'batch': f'{batch_idx}/{len(train_loader)}'
        })
        
        if args.dry_run:
            break
    
    pbar.close()


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    
    # Wrap test_loader with tqdm
    pbar = tqdm(test_loader, desc='Testing', ncols=100, leave=False)
    
    with torch.no_grad():
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            # Running accuracy: correct so far over the full test-set size
            current_acc = 100. * correct / len(test_loader.dataset)
            pbar.set_postfix({'acc': f'{current_acc:.2f}%'})
            
            if args.dry_run:
                break
    
    pbar.close()
    
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n')


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example using RNN')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='learning rate step gamma (default: 0.7)')
    parser.add_argument('--accel', action='store_true',
                        help='enables accelerator')
    parser.add_argument('--dry-run', action='store_true',
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true',
                        help='for Saving the current Model')
    args = parser.parse_args()

    if args.accel:
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")

    torch.manual_seed(args.seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.accel else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    
    # An outer progress bar tracking epoch progress
    epoch_pbar = tqdm(range(1, args.epochs + 1), desc='Overall Progress', 
                      position=0, ncols=100)
    
    for epoch in epoch_pbar:
        epoch_pbar.set_postfix({'lr': f'{scheduler.get_last_lr()[0]:.6f}'})
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()
    
    epoch_pbar.close()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_rnn.pt")
        tqdm.write("Model saved to mnist_rnn.pt")


if __name__ == '__main__':
    main()

Sample output:

Overall Progress:   0%|                                         | 0/14 [00:00<?, ?it/s, lr=0.100000]
Epoch 1/14: 100%|███████████████████| 938/938 [00:12<00:00, 75.10it/s, loss=1.013126, batch=937/938]
Test set: Average loss: 0.7654, Accuracy: 7548/10000 (75%)

Epoch 2/14: 100%|███████████████████| 938/938 [00:11<00:00, 84.62it/s, loss=0.579384, batch=937/938]
Test set: Average loss: 0.4426, Accuracy: 8544/10000 (85%)

Epoch 3/14: 100%|███████████████████| 938/938 [00:10<00:00, 89.23it/s, loss=0.406930, batch=937/938]
Test set: Average loss: 0.3288, Accuracy: 9025/10000 (90%)

Epoch 4/14: 100%|███████████████████| 938/938 [00:10<00:00, 88.78it/s, loss=0.716000, batch=937/938]
Test set: Average loss: 0.2768, Accuracy: 9193/10000 (92%)

Epoch 5/14: 100%|███████████████████| 938/938 [00:10<00:00, 88.77it/s, loss=0.157785, batch=937/938]
Test set: Average loss: 0.2519, Accuracy: 9251/10000 (93%)

Epoch 6/14: 100%|███████████████████| 938/938 [00:10<00:00, 89.81it/s, loss=0.524307, batch=937/938]
Test set: Average loss: 0.2388, Accuracy: 9282/10000 (93%)

Epoch 7/14: 100%|███████████████████| 938/938 [00:10<00:00, 86.33it/s, loss=0.347808, batch=937/938]
Test set: Average loss: 0.2286, Accuracy: 9318/10000 (93%)

Epoch 8/14: 100%|███████████████████| 938/938 [00:10<00:00, 87.81it/s, loss=0.483327, batch=937/938]
Test set: Average loss: 0.2213, Accuracy: 9335/10000 (93%)

Epoch 9/14: 100%|███████████████████| 938/938 [00:10<00:00, 86.23it/s, loss=0.582588, batch=937/938]
Test set: Average loss: 0.2159, Accuracy: 9344/10000 (93%)

Epoch 10/14: 100%|██████████████████| 938/938 [00:10<00:00, 86.33it/s, loss=0.278955, batch=937/938]
Test set: Average loss: 0.2131, Accuracy: 9359/10000 (94%)

Epoch 11/14: 100%|██████████████████| 938/938 [00:10<00:00, 86.26it/s, loss=0.404683, batch=937/938]
Test set: Average loss: 0.2110, Accuracy: 9363/10000 (94%)

Epoch 12/14: 100%|██████████████████| 938/938 [00:11<00:00, 79.80it/s, loss=0.250124, batch=937/938]
Test set: Average loss: 0.2106, Accuracy: 9359/10000 (94%)

Epoch 13/14: 100%|██████████████████| 938/938 [00:11<00:00, 85.16it/s, loss=0.278060, batch=937/938]
Test set: Average loss: 0.2091, Accuracy: 9354/10000 (94%)

Epoch 14/14: 100%|██████████████████| 938/938 [00:10<00:00, 91.87it/s, loss=0.448233, batch=937/938]
Test set: Average loss: 0.2083, Accuracy: 9368/10000 (94%)

Overall Progress: 100%|████████████████████████████████| 14/14 [02:42<00:00, 11.63s/it, lr=0.000969]