文章目录
前言
本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNIST数据集中的手写数字进行分类。MNIST数据集是一个广泛使用的计算机视觉数据集,包含了大量的手写数字图像,适合用来训练和测试深度学习模型。
代码的关键特点包括:
-
数据加载与预处理 :使用
torchvision库加载MNIST数据集,并应用了标准化变换以准备数据输入模型。 -
BiRNN模型定义 :模型使用
nn.LSTM模块构建双向LSTM层,能够处理序列数据,并通过nn.Linear层进行最终的分类。 -
设备无关性 :通过
torch.device自动选择GPU或CPU,提高了代码的通用性。 -
训练与测试:实现了模型的训练循环和测试循环,包括损失计算、反向传播和参数更新。
-
可视化工具 :集成了数据可视化和模型架构可视化功能,使用
matplotlib库展示数据样本和训练进度。 -
模型保存 :训练完成后,使用
torch.save保存模型参数,方便后续的加载和使用。 -
超参数设置:提供了灵活的超参数设置,包括隐藏层大小、层数、批次大小、训练轮数和学习率。
代码结构清晰,易于理解和修改,适合作为深度学习入门和实践的参考。通过本代码,用户可以了解如何使用PyTorch构建和训练一个BiRNN模型,并对MNIST数据集进行分类任务。
说明
- 确保安装了PyTorch、torchvision和matplotlib。
- 调整超参数以适应不同的训练需求。
- 运行代码,观察训练过程和测试结果。
- 使用可视化工具了解数据和模型架构。
完整代码
python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data/',
train=False,
transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(BiRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
def forward(self, x):
# Set initial states
h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection
c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
# Forward propagate LSTM
out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size*2)
# Decode the hidden state of the last time step
out = self.fc(out[:, -1, :])
return out
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
代码解析
1.导入库
python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
这部分代码导入了编写神经网络所需的PyTorch库及其子模块。
2.设备配置
python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
根据是否有可用的GPU,设置计算设备,优先使用GPU以加速训练。
3.超参数设置
python
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003
设置了模型训练所需的超参数,包括时间序列的长度、输入数据的尺寸、隐藏层的尺寸、LSTM层数、类别数、批次大小、训练轮数和学习率。
4.数据集加载
python
train_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor())
加载MNIST数据集的训练集和测试集,并使用transforms.ToTensor()将图像数据转换为张量。
5.数据加载器
python
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
创建了两个数据加载器,分别用于训练和测试数据的批量加载。
6.定义BiRNN模型
python
class BiRNN(nn.Module):
# 定义双向循环神经网络模型
创建了一个双向LSTM的模型,包含初始化方法和前向传播方法。
7.实例化模型并移动到设备
python
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)
实例化BiRNN模型,并将模型移动到之前设置的计算设备上。
8.损失函数和优化器
python
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
定义了交叉熵损失函数和Adam优化器。
9.训练模型
python
for epoch in range(num_epochs):
# 训练循环
在每个epoch中,遍历训练数据的每个批次,执行前向传播、计算损失、反向传播和参数更新。
10.测试模型
python
with torch.no_grad():
# 测试循环
在测试阶段,关闭梯度计算,遍历测试数据的每个批次,计算模型的预测准确率。
11.保存模型
python
torch.save(model.state_dict(), 'model.ckpt')
保存模型的参数到文件,以便于后续的加载和使用。
这段代码实现了一个完整的训练和测试流程,适合用于分类任务,特别是涉及序列数据的任务。对于MNIST数据集,尽管它不是序列数据,但通过将图像的每一行视为序列的一部分,可以使用RNN进行处理。
常用函数
-
torch.device-
格式:
torch.device(device_str) -
参数:
device_str------ 指定设备类型(如'cuda'或'cpu')的字符串。 -
样式:属性访问器。
-
示例:
pythondevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
torchvision.datasets.MNIST-
格式:
torchvision.datasets.MNIST(root, train, transform, download) -
参数:
root------ 数据集存放的根目录。train------ 是否加载训练集。transform------ 对图像进行的变换操作。download------ 是否下载数据集。
-
样式:类方法调用。
-
示例:
pythontrain_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transforms.ToTensor(), download=True)
-
-
torchvision.transforms.Compose-
格式:
torchvision.transforms.Compose(transforms_list) -
参数:
transforms_list------ 包含多个变换操作的列表。 -
样式:类方法调用。
-
示例:
pythontransform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])
-
-
torch.utils.data.DataLoader-
格式:
torch.utils.data.DataLoader(dataset, batch_size, shuffle) -
参数:
dataset------ 加载的数据集。batch_size------ 每个批次的样本数。shuffle------ 是否在每个epoch开始时打乱数据。
-
样式:类方法调用。
-
示例:
pythontrain_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
-
-
nn.Module-
格式:
class YourModelClass(nn.Module) -
参数:继承自
nn.Module的类定义。 -
样式:类继承。
-
示例:
pythonclass BiRNN(nn.Module): def __init__(self, ...): super(BiRNN, self).__init__() ...
-
-
nn.LSTM-
格式:
nn.LSTM(input_size, hidden_size, num_layers, batch_first, bidirectional) -
参数:
input_size------ 输入特征的维度。hidden_size------ 隐藏层的维度。num_layers------ LSTM层的数量。batch_first------ 输入和输出张量的第一个维度是否为批次大小。bidirectional------ 是否使用双向LSTM。
-
样式:类方法调用。
-
示例:
pythonself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
-
-
nn.Linear-
格式:
nn.Linear(in_features, out_features) -
参数:
in_features------ 输入特征的数量。out_features------ 输出特征的数量。
-
样式:类方法调用。
-
示例:
pythonself.fc = nn.Linear(hidden_size * 2, num_classes)
-
-
nn.CrossEntropyLoss-
格式:
nn.CrossEntropyLoss() -
参数:无默认参数。
-
样式:类方法调用。
-
示例:
pythoncriterion = nn.CrossEntropyLoss()
-
-
torch.optim.Adam-
格式:
torch.optim.Adam(params, lr) -
参数:
params------ 模型参数。lr------ 学习率。
-
样式:类方法调用。
-
示例:
pythonoptimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-
-
.to(device)-
格式:
.to(device) -
参数:
device------ 指定的计算设备。 -
样式:方法调用。
-
示例:
pythonimages = images.to(device)
-
-
.reshape-
格式:
.reshape(shape) -
参数:
shape------ 要重塑成的新形状。 -
样式:方法调用。
-
示例:
pythonimages = images.reshape(-1, sequence_length, input_size)
-
-
torch.zeros-
格式:
torch.zeros(size, device) -
参数:
size------ 张量的形状。device------ 张量所在的设备。
-
样式:函数调用。
-
示例:
pythonh0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
-
-
torch.max-
格式:
torch.max(input, dim, keepdim) -
参数:
input------ 输入张量。dim------ 要计算最大值的维度。keepdim------ 是否保留计算维度。
-
样式:函数调用。
-
示例:
python_, predicted = torch.max(outputs.data, 1)
-
-
torch.no_grad()-
格式:
torch.no_grad() -
参数:无参数。
-
样式:上下文管理器。
-
示例:
pythonwith torch.no_grad(): ...
-
-
torch.save-
格式:
torch.save(object, filename) -
参数:
object------ 要保存的对象。filename------ 文件名。
-
样式:函数调用。
-
示例:
pythontorch.save(model.state_dict(), 'model.ckpt')
-
-
plt.imshow-
格式:
plt.imshow(X, cmap) -
参数:
X------ 要显示的图像数据。cmap------ 颜色映射。
-
样式:函数调用。
-
示例:
pythonplt.imshow(images[j].squeeze().cpu(), cmap='gray')
-
-
plt.show-
格式:
plt.show() -
参数:无参数。
-
样式:函数调用。
-
示例:
pythonplt.show()
-
-
plt.figure-
格式:
plt.figure(figsize) -
参数:
figsize------ 图形的尺寸。 -
样式:函数调用。
-
示例:
pythonplt.figure(figsize=(20, 4))
-
-
plt.subplot-
格式:
plt.subplot(nrows, ncols, index) -
参数:
nrows------ 子图的行数。ncols------ 子图的列数。index------ 当前子图的索引。
-
样式:函数调用。
-
示例:
pythonplt.subplot(1, num_samples, j+1)
-
这些函数覆盖了从数据预处理、模型构建、训练、测试到结果可视化的整个流程。