文章目录
前言
本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNIST数据集中的手写数字进行分类。MNIST数据集是一个广泛使用的计算机视觉数据集,包含了大量的手写数字图像,适合用来训练和测试深度学习模型。
代码的关键特点包括:
-
数据加载与预处理 :使用
torchvision
库加载MNIST数据集,并应用了标准化变换以准备数据输入模型。 -
BiRNN模型定义 :模型使用
nn.LSTM
模块构建双向LSTM层,能够处理序列数据,并通过nn.Linear
层进行最终的分类。 -
设备无关性 :通过
torch.device
自动选择GPU或CPU,提高了代码的通用性。 -
训练与测试:实现了模型的训练循环和测试循环,包括损失计算、反向传播和参数更新。
-
可视化工具 :集成了数据可视化和模型架构可视化功能,使用
matplotlib
库展示数据样本和训练进度。 -
模型保存 :训练完成后,使用
torch.save
保存模型参数,方便后续的加载和使用。 -
超参数设置:提供了灵活的超参数设置,包括隐藏层大小、层数、批次大小、训练轮数和学习率。
代码结构清晰,易于理解和修改,适合作为深度学习入门和实践的参考。通过本代码,用户可以了解如何使用PyTorch构建和训练一个BiRNN模型,并对MNIST数据集进行分类任务。
说明
- 确保安装了PyTorch、torchvision和matplotlib。
- 调整超参数以适应不同的训练需求。
- 运行代码,观察训练过程和测试结果。
- 使用可视化工具了解数据和模型架构。
完整代码
python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003
# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transforms.ToTensor(),
download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data/',
train=False,
transform=transforms.ToTensor())
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(BiRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
def forward(self, x):
# Set initial states
h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection
c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
# Forward propagate LSTM
out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size*2)
# Decode the hidden state of the last time step
out = self.fc(out[:, -1, :])
return out
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
代码解析
1.导入库
python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
这部分代码导入了编写神经网络所需的PyTorch库及其子模块。
2.设备配置
python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
根据是否有可用的GPU,设置计算设备,优先使用GPU以加速训练。
3.超参数设置
python
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003
设置了模型训练所需的超参数,包括时间序列的长度、输入数据的尺寸、隐藏层的尺寸、LSTM层数、类别数、批次大小、训练轮数和学习率。
4.数据集加载
python
train_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor())
加载MNIST数据集的训练集和测试集,并使用transforms.ToTensor()
将图像数据转换为张量。
5.数据加载器
python
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
创建了两个数据加载器,分别用于训练和测试数据的批量加载。
6.定义BiRNN模型
python
class BiRNN(nn.Module):
# 定义双向循环神经网络模型
创建了一个双向LSTM的模型,包含初始化方法和前向传播方法。
7.实例化模型并移动到设备
python
model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)
实例化BiRNN模型,并将模型移动到之前设置的计算设备上。
8.损失函数和优化器
python
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
定义了交叉熵损失函数和Adam优化器。
9.训练模型
python
for epoch in range(num_epochs):
# 训练循环
在每个epoch中,遍历训练数据的每个批次,执行前向传播、计算损失、反向传播和参数更新。
10.测试模型
python
with torch.no_grad():
# 测试循环
在测试阶段,关闭梯度计算,遍历测试数据的每个批次,计算模型的预测准确率。
11.保存模型
python
torch.save(model.state_dict(), 'model.ckpt')
保存模型的参数到文件,以便于后续的加载和使用。
这段代码实现了一个完整的训练和测试流程,适合用于分类任务,特别是涉及序列数据的任务。对于MNIST数据集,尽管它不是序列数据,但通过将图像的每一行视为序列的一部分,可以使用RNN进行处理。
常用函数
-
torch.device
-
格式:
torch.device(device_str)
-
参数:
device_str
------ 指定设备类型(如'cuda'
或'cpu'
)的字符串。 -
样式:属性访问器。
-
示例:
pythondevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
torchvision.datasets.MNIST
-
格式:
torchvision.datasets.MNIST(root, train, transform, download)
-
参数:
root
------ 数据集存放的根目录。train
------ 是否加载训练集。transform
------ 对图像进行的变换操作。download
------ 是否下载数据集。
-
样式:类方法调用。
-
示例:
pythontrain_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transforms.ToTensor(), download=True)
-
-
torchvision.transforms.Compose
-
格式:
torchvision.transforms.Compose(transforms_list)
-
参数:
transforms_list
------ 包含多个变换操作的列表。 -
样式:类方法调用。
-
示例:
pythontransform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])
-
-
torch.utils.data.DataLoader
-
格式:
torch.utils.data.DataLoader(dataset, batch_size, shuffle)
-
参数:
dataset
------ 加载的数据集。batch_size
------ 每个批次的样本数。shuffle
------ 是否在每个epoch开始时打乱数据。
-
样式:类方法调用。
-
示例:
pythontrain_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
-
-
nn.Module
-
格式:
class YourModelClass(nn.Module)
-
参数:继承自
nn.Module
的类定义。 -
样式:类继承。
-
示例:
pythonclass BiRNN(nn.Module): def __init__(self, ...): super(BiRNN, self).__init__() ...
-
-
nn.LSTM
-
格式:
nn.LSTM(input_size, hidden_size, num_layers, batch_first, bidirectional)
-
参数:
input_size
------ 输入特征的维度。hidden_size
------ 隐藏层的维度。num_layers
------ LSTM层的数量。batch_first
------ 输入和输出张量的第一个维度是否为批次大小。bidirectional
------ 是否使用双向LSTM。
-
样式:类方法调用。
-
示例:
pythonself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
-
-
nn.Linear
-
格式:
nn.Linear(in_features, out_features)
-
参数:
in_features
------ 输入特征的数量。out_features
------ 输出特征的数量。
-
样式:类方法调用。
-
示例:
pythonself.fc = nn.Linear(hidden_size * 2, num_classes)
-
-
nn.CrossEntropyLoss
-
格式:
nn.CrossEntropyLoss()
-
参数:无默认参数。
-
样式:类方法调用。
-
示例:
pythoncriterion = nn.CrossEntropyLoss()
-
-
torch.optim.Adam
-
格式:
torch.optim.Adam(params, lr)
-
参数:
params
------ 模型参数。lr
------ 学习率。
-
样式:类方法调用。
-
示例:
pythonoptimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-
-
.to(device)
-
格式:
.to(device)
-
参数:
device
------ 指定的计算设备。 -
样式:方法调用。
-
示例:
pythonimages = images.to(device)
-
-
.reshape
-
格式:
.reshape(shape)
-
参数:
shape
------ 要重塑成的新形状。 -
样式:方法调用。
-
示例:
pythonimages = images.reshape(-1, sequence_length, input_size)
-
-
torch.zeros
-
格式:
torch.zeros(size, device)
-
参数:
size
------ 张量的形状。device
------ 张量所在的设备。
-
样式:函数调用。
-
示例:
pythonh0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
-
-
torch.max
-
格式:
torch.max(input, dim, keepdim)
-
参数:
input
------ 输入张量。dim
------ 要计算最大值的维度。keepdim
------ 是否保留计算维度。
-
样式:函数调用。
-
示例:
python_, predicted = torch.max(outputs.data, 1)
-
-
torch.no_grad()
-
格式:
torch.no_grad()
-
参数:无参数。
-
样式:上下文管理器。
-
示例:
pythonwith torch.no_grad(): ...
-
-
torch.save
-
格式:
torch.save(object, filename)
-
参数:
object
------ 要保存的对象。filename
------ 文件名。
-
样式:函数调用。
-
示例:
pythontorch.save(model.state_dict(), 'model.ckpt')
-
-
plt.imshow
-
格式:
plt.imshow(X, cmap)
-
参数:
X
------ 要显示的图像数据。cmap
------ 颜色映射。
-
样式:函数调用。
-
示例:
pythonplt.imshow(images[j].squeeze().cpu(), cmap='gray')
-
-
plt.show
-
格式:
plt.show()
-
参数:无参数。
-
样式:函数调用。
-
示例:
pythonplt.show()
-
-
plt.figure
-
格式:
plt.figure(figsize)
-
参数:
figsize
------ 图形的尺寸。 -
样式:函数调用。
-
示例:
pythonplt.figure(figsize=(20, 4))
-
-
plt.subplot
-
格式:
plt.subplot(nrows, ncols, index)
-
参数:
nrows
------ 子图的行数。ncols
------ 子图的列数。index
------ 当前子图的索引。
-
样式:函数调用。
-
示例:
pythonplt.subplot(1, num_samples, j+1)
-
这些函数覆盖了从数据预处理、模型构建、训练、测试到结果可视化的整个流程。