PyTorch--双向长短期记忆网络(BiRNN)在MNIST数据集上的实现与分析

文章目录

前言

本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNIST数据集中的手写数字进行分类。MNIST数据集是一个广泛使用的计算机视觉数据集,包含了大量的手写数字图像,适合用来训练和测试深度学习模型。

代码的关键特点包括:

  1. 数据加载与预处理 :使用torchvision库加载MNIST数据集,并应用了标准化变换以准备数据输入模型。

  2. BiRNN模型定义 :模型使用nn.LSTM模块构建双向LSTM层,能够处理序列数据,并通过nn.Linear层进行最终的分类。

  3. 设备无关性 :通过torch.device自动选择GPU或CPU,提高了代码的通用性。

  4. 训练与测试:实现了模型的训练循环和测试循环,包括损失计算、反向传播和参数更新。

  5. 可视化工具 :集成了数据可视化和模型架构可视化功能,使用matplotlib库展示数据样本和训练进度。

  6. 模型保存 :训练完成后,使用torch.save保存模型参数,方便后续的加载和使用。

  7. 超参数设置:提供了灵活的超参数设置,包括隐藏层大小、层数、批次大小、训练轮数和学习率。

代码结构清晰,易于理解和修改,适合作为深度学习入门和实践的参考。通过本代码,用户可以了解如何使用PyTorch构建和训练一个BiRNN模型,并对MNIST数据集进行分类任务。

说明

  • 确保安装了PyTorch、torchvision和matplotlib。
  • 调整超参数以适应不同的训练需求。
  • 运行代码,观察训练过程和测试结果。
  • 使用可视化工具了解数据和模型架构。

完整代码

python 复制代码
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection
    
    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

代码解析

1.导入库

python 复制代码
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

这部分代码导入了编写神经网络所需的PyTorch库及其子模块。

2.设备配置

python 复制代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

根据是否有可用的GPU,设置计算设备,优先使用GPU以加速训练。

3.超参数设置

python 复制代码
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

设置了模型训练所需的超参数,包括时间序列的长度、输入数据的尺寸、隐藏层的尺寸、LSTM层数、类别数、批次大小、训练轮数和学习率。

4.数据集加载

python 复制代码
train_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor())

加载MNIST数据集的训练集和测试集,并使用transforms.ToTensor()将图像数据转换为张量。

5.数据加载器

python 复制代码
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

创建了两个数据加载器,分别用于训练和测试数据的批量加载。

6.定义BiRNN模型

python 复制代码
class BiRNN(nn.Module):
    # 定义双向循环神经网络模型

创建了一个双向LSTM的模型,包含初始化方法和前向传播方法。

7.实例化模型并移动到设备

python 复制代码
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

实例化BiRNN模型,并将模型移动到之前设置的计算设备上。

8.损失函数和优化器

python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

定义了交叉熵损失函数和Adam优化器。

9.训练模型

python 复制代码
for epoch in range(num_epochs):
    # 训练循环

在每个epoch中,遍历训练数据的每个批次,执行前向传播、计算损失、反向传播和参数更新。

10.测试模型

python 复制代码
with torch.no_grad():
    # 测试循环

在测试阶段,关闭梯度计算,遍历测试数据的每个批次,计算模型的预测准确率。

11.保存模型

python 复制代码
torch.save(model.state_dict(), 'model.ckpt')

保存模型的参数到文件,以便于后续的加载和使用。

这段代码实现了一个完整的训练和测试流程,适合用于分类任务,特别是涉及序列数据的任务。对于MNIST数据集,尽管它不是序列数据,但通过将图像的每一行视为序列的一部分,可以使用RNN进行处理。

常用函数

  1. torch.device

    • 格式:torch.device(device_str)

    • 参数:device_str ------ 指定设备类型(如'cuda''cpu')的字符串。

    • 样式:属性访问器。

    • 示例:

      python 复制代码
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  2. torchvision.datasets.MNIST

    • 格式:torchvision.datasets.MNIST(root, train, transform, download)

    • 参数:

      • root ------ 数据集存放的根目录。
      • train ------ 是否加载训练集。
      • transform ------ 对图像进行的变换操作。
      • download ------ 是否下载数据集。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      train_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transforms.ToTensor(), download=True)
  3. torchvision.transforms.Compose

    • 格式:torchvision.transforms.Compose(transforms_list)

    • 参数:transforms_list ------ 包含多个变换操作的列表。

    • 样式:类方法调用。

    • 示例:

      python 复制代码
      transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
      ])
  4. torch.utils.data.DataLoader

    • 格式:torch.utils.data.DataLoader(dataset, batch_size, shuffle)

    • 参数:

      • dataset ------ 加载的数据集。
      • batch_size ------ 每个批次的样本数。
      • shuffle ------ 是否在每个epoch开始时打乱数据。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
  5. nn.Module

    • 格式:class YourModelClass(nn.Module)

    • 参数:继承自nn.Module的类定义。

    • 样式:类继承。

    • 示例:

      python 复制代码
      class BiRNN(nn.Module):
          def __init__(self, ...):
              super(BiRNN, self).__init__()
              ...
  6. nn.LSTM

    • 格式:nn.LSTM(input_size, hidden_size, num_layers, batch_first, bidirectional)

    • 参数:

      • input_size ------ 输入特征的维度。
      • hidden_size ------ 隐藏层的维度。
      • num_layers ------ LSTM层的数量。
      • batch_first ------ 输入和输出张量的第一个维度是否为批次大小。
      • bidirectional ------ 是否使用双向LSTM。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
  7. nn.Linear

    • 格式:nn.Linear(in_features, out_features)

    • 参数:

      • in_features ------ 输入特征的数量。
      • out_features ------ 输出特征的数量。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      self.fc = nn.Linear(hidden_size * 2, num_classes)
  8. nn.CrossEntropyLoss

    • 格式:nn.CrossEntropyLoss()

    • 参数:无默认参数。

    • 样式:类方法调用。

    • 示例:

      python 复制代码
      criterion = nn.CrossEntropyLoss()
  9. torch.optim.Adam

    • 格式:torch.optim.Adam(params, lr)

    • 参数:

      • params ------ 模型参数。
      • lr ------ 学习率。
    • 样式:类方法调用。

    • 示例:

      python 复制代码
      optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  10. .to(device)

    • 格式:.to(device)

    • 参数:device ------ 指定的计算设备。

    • 样式:方法调用。

    • 示例:

      python 复制代码
      images = images.to(device)
  11. .reshape

    • 格式:.reshape(shape)

    • 参数:shape ------ 要重塑成的新形状。

    • 样式:方法调用。

    • 示例:

      python 复制代码
      images = images.reshape(-1, sequence_length, input_size)
  12. torch.zeros

    • 格式:torch.zeros(size, device)

    • 参数:

      • size ------ 张量的形状。
      • device ------ 张量所在的设备。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
  13. torch.max

    • 格式:torch.max(input, dim, keepdim)

    • 参数:

      • input ------ 输入张量。
      • dim ------ 要计算最大值的维度。
      • keepdim ------ 是否保留计算维度。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      _, predicted = torch.max(outputs.data, 1)
  14. torch.no_grad()

    • 格式:torch.no_grad()

    • 参数:无参数。

    • 样式:上下文管理器。

    • 示例:

      python 复制代码
      with torch.no_grad():
          ...
  15. torch.save

    • 格式:torch.save(object, filename)

    • 参数:

      • object ------ 要保存的对象。
      • filename ------ 文件名。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      torch.save(model.state_dict(), 'model.ckpt')
  16. plt.imshow

    • 格式:plt.imshow(X, cmap)

    • 参数:

      • X ------ 要显示的图像数据。
      • cmap ------ 颜色映射。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.imshow(images[j].squeeze().cpu(), cmap='gray')
  17. plt.show

    • 格式:plt.show()

    • 参数:无参数。

    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.show()
  18. plt.figure

    • 格式:plt.figure(figsize)

    • 参数:figsize ------ 图形的尺寸。

    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.figure(figsize=(20, 4))
  19. plt.subplot

    • 格式:plt.subplot(nrows, ncols, index)

    • 参数:

      • nrows ------ 子图的行数。
      • ncols ------ 子图的列数。
      • index ------ 当前子图的索引。
    • 样式:函数调用。

    • 示例:

      python 复制代码
      plt.subplot(1, num_samples, j+1)

这些函数覆盖了从数据预处理、模型构建、训练、测试到结果可视化的整个流程。

相关推荐
Eric.Lee20212 分钟前
moviepy将图片序列制作成视频并加载字幕 - python 实现
开发语言·python·音视频·moviepy·字幕视频合成·图像制作为视频
Dontla6 分钟前
vscode怎么设置anaconda python解释器(anaconda解释器、vscode解释器)
ide·vscode·python
深圳南柯电子15 分钟前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ16 分钟前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
biter008821 分钟前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
吃个糖糖27 分钟前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉
qq_529025291 小时前
Torch.gather
python·深度学习·机器学习
数据小爬虫@1 小时前
如何高效利用Python爬虫按关键字搜索苏宁商品
开发语言·爬虫·python
Cachel wood1 小时前
python round四舍五入和decimal库精确四舍五入
java·linux·前端·数据库·vue.js·python·前端框架
IT古董1 小时前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习