PyTorch--双向长短期记忆网络(BiRNN)在MNIST数据集上的实现与分析

文章目录

前言

本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNIST数据集中的手写数字进行分类。MNIST数据集是一个广泛使用的计算机视觉数据集,包含了大量的手写数字图像,适合用来训练和测试深度学习模型。

代码的关键特点包括:

  1. 数据加载与预处理 :使用torchvision库加载MNIST数据集,并应用了标准化变换以准备数据输入模型。

  2. BiRNN模型定义 :模型使用nn.LSTM模块构建双向LSTM层,能够处理序列数据,并通过nn.Linear层进行最终的分类。

  3. 设备无关性 :通过torch.device自动选择GPU或CPU,提高了代码的通用性。

  4. 训练与测试:实现了模型的训练循环和测试循环,包括损失计算、反向传播和参数更新。

  5. 可视化工具 :集成了数据可视化和模型架构可视化功能,使用matplotlib库展示数据样本和训练进度。

  6. 模型保存 :训练完成后,使用torch.save保存模型参数,方便后续的加载和使用。

  7. 超参数设置:提供了灵活的超参数设置,包括隐藏层大小、层数、批次大小、训练轮数和学习率。

代码结构清晰,易于理解和修改,适合作为深度学习入门和实践的参考。通过本代码,用户可以了解如何使用PyTorch构建和训练一个BiRNN模型,并对MNIST数据集进行分类任务。

说明

  • 确保安装了PyTorch、torchvision和matplotlib。
  • 调整超参数以适应不同的训练需求。
  • 运行代码,观察训练过程和测试结果。
  • 使用可视化工具了解数据和模型架构。

完整代码

python 复制代码
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection
    
    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

代码解析

1.导入库

python 复制代码
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

这部分代码导入了编写神经网络所需的PyTorch库及其子模块。

2.设备配置

python 复制代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

根据是否有可用的GPU,设置计算设备,优先使用GPU以加速训练。

3.超参数设置

python 复制代码
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

设置了模型训练所需的超参数,包括时间序列的长度、输入数据的尺寸、隐藏层的尺寸、LSTM层数、类别数、批次大小、训练轮数和学习率。

4.数据集加载

python 复制代码
train_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor())

加载MNIST数据集的训练集和测试集,并使用transforms.ToTensor()将图像数据转换为张量。

5.数据加载器

python 复制代码
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

创建了两个数据加载器,分别用于训练和测试数据的批量加载。

6.定义BiRNN模型

python 复制代码
class BiRNN(nn.Module):
    # 定义双向循环神经网络模型

创建了一个双向LSTM的模型,包含初始化方法和前向传播方法。

7.实例化模型并移动到设备

python 复制代码
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

实例化BiRNN模型,并将模型移动到之前设置的计算设备上。

8.损失函数和优化器

python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

定义了交叉熵损失函数和Adam优化器。

9.训练模型

python 复制代码
for epoch in range(num_epochs):
    # 训练循环

在每个epoch中,遍历训练数据的每个批次,执行前向传播、计算损失、反向传播和参数更新。

10.测试模型

python 复制代码
with torch.no_grad():
    # 测试循环

在测试阶段,关闭梯度计算,遍历测试数据的每个批次,计算模型的预测准确率。

11.保存模型

python 复制代码
torch.save(model.state_dict(), 'model.ckpt')

保存模型的参数到文件,以便于后续的加载和使用。

这段代码实现了一个完整的训练和测试流程,适合用于分类任务,特别是涉及序列数据的任务。对于MNIST数据集,尽管它不是序列数据,但通过将图像的每一行视为序列的一部分,可以使用RNN进行处理。

常用函数

  1. torch.device

    • 格式:torch.device(device_str)

    • 参数:device_str ------ 指定设备类型(如'cuda''cpu')的字符串。

    • 样式:属性访问器。

    • 示例:

      python 复制代码
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  2. torchvision.datasets.MNIST

    • 格式:torchvision.datasets.MNIST(root, train, transform, download)

    • 参数:

      • root ------ 数据集存放的根目录。
      • train ------ 是否加载训练集。
      • transform ------ 对图像进行的变换操作。
      • download ------ 是否下载数据集。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      train_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transforms.ToTensor(), download=True)
  3. torchvision.transforms.Compose

    • 格式:torchvision.transforms.Compose(transforms_list)

    • 参数:transforms_list ------ 包含多个变换操作的列表。

    • 样式:类方法调用。

    • 示例:

      python 复制代码
      transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
      ])
  4. torch.utils.data.DataLoader

    • 格式:torch.utils.data.DataLoader(dataset, batch_size, shuffle)

    • 参数:

      • dataset ------ 加载的数据集。
      • batch_size ------ 每个批次的样本数。
      • shuffle ------ 是否在每个epoch开始时打乱数据。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
  5. nn.Module

    • 格式:class YourModelClass(nn.Module)

    • 参数:继承自nn.Module的类定义。

    • 样式:类继承。

    • 示例:

      python 复制代码
      class BiRNN(nn.Module):
          def __init__(self, ...):
              super(BiRNN, self).__init__()
              ...
  6. nn.LSTM

    • 格式:nn.LSTM(input_size, hidden_size, num_layers, batch_first, bidirectional)

    • 参数:

      • input_size ------ 输入特征的维度。
      • hidden_size ------ 隐藏层的维度。
      • num_layers ------ LSTM层的数量。
      • batch_first ------ 输入和输出张量的第一个维度是否为批次大小。
      • bidirectional ------ 是否使用双向LSTM。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
  7. nn.Linear

    • 格式:nn.Linear(in_features, out_features)

    • 参数:

      • in_features ------ 输入特征的数量。
      • out_features ------ 输出特征的数量。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      self.fc = nn.Linear(hidden_size * 2, num_classes)
  8. nn.CrossEntropyLoss

    • 格式:nn.CrossEntropyLoss()

    • 参数:无默认参数。

    • 样式:类方法调用。

    • 示例:

      python 复制代码
      criterion = nn.CrossEntropyLoss()
  9. torch.optim.Adam

    • 格式:torch.optim.Adam(params, lr)

    • 参数:

      • params ------ 模型参数。
      • lr ------ 学习率。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  10. .to(device)

    • 格式:.to(device)

    • 参数:device ------ 指定的计算设备。

    • 样式:方法调用。

    • 示例:

      python 复制代码
      images = images.to(device)
  11. .reshape

    • 格式:.reshape(shape)

    • 参数:shape ------ 要重塑成的新形状。

    • 样式:方法调用。

    • 示例:

      python 复制代码
      images = images.reshape(-1, sequence_length, input_size)
  12. torch.zeros

    • 格式:torch.zeros(size, device)

    • 参数:

      • size ------ 张量的形状。
      • device ------ 张量所在的设备。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
  13. torch.max

    • 格式:torch.max(input, dim, keepdim)

    • 参数:

      • input ------ 输入张量。
      • dim ------ 要计算最大值的维度。
      • keepdim ------ 是否保留计算维度。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      _, predicted = torch.max(outputs.data, 1)
  14. torch.no_grad()

    • 格式:torch.no_grad()

    • 参数:无参数。

    • 样式:上下文管理器。

    • 示例:

      python 复制代码
      with torch.no_grad():
          ...
  15. torch.save

    • 格式:torch.save(object, filename)

    • 参数:

      • object ------ 要保存的对象。
      • filename ------ 文件名。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      torch.save(model.state_dict(), 'model.ckpt')
  16. plt.imshow

    • 格式:plt.imshow(X, cmap)

    • 参数:

      • X ------ 要显示的图像数据。
      • cmap ------ 颜色映射。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.imshow(images[j].squeeze().cpu(), cmap='gray')
  17. plt.show

    • 格式:plt.show()

    • 参数:无参数。

    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.show()
  18. plt.figure

    • 格式:plt.figure(figsize)

    • 参数:figsize ------ 图形的尺寸。

    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.figure(figsize=(20, 4))
  19. plt.subplot

    • 格式:plt.subplot(nrows, ncols, index)

    • 参数:

      • nrows ------ 子图的行数。
      • ncols ------ 子图的列数。
      • index ------ 当前子图的索引。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.subplot(1, num_samples, j+1)

这些函数覆盖了从数据预处理、模型构建、训练、测试到结果可视化的整个流程。

相关推荐
kakaZhui3 分钟前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20251 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥1 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
Mason Lin1 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客2 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
云空2 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析
AIGC大时代2 小时前
对比DeepSeek、ChatGPT和Kimi的学术写作关键词提取能力
论文阅读·人工智能·chatgpt·数据分析·prompt
山晨啊83 小时前
2025年美赛B题-结合Logistic阻滞增长模型和SIR传染病模型研究旅游可持续性-成品论文
人工智能·机器学习
RZer3 小时前
Hypium+python鸿蒙原生自动化安装配置
python·自动化·harmonyos
一水鉴天4 小时前
为AI聊天工具添加一个知识系统 之77 详细设计之18 正则表达式 之5
人工智能·正则表达式