MNIST 数据集作为机器学习领域的经典数据集,包含了大量手写数字图像,非常适合用于入门深度学习模型的训练与测试。今天,我们就来一步步使用 PyTorch 实现 MNIST 手写数字识别。
一、准备工作
首先,我们需要导入必要的库。numpy
用于数值计算,torch
及其相关模块用于构建和训练神经网络,matplotlib
用于数据可视化。
python运行
import numpy as np
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline
接着定义一些超参数,比如批次大小、学习率和训练轮数等,这些参数会影响模型的训练过程和结果。
python运行
train_batch_size = 64
test_batch_size = 128
learning_rate = 0.01
num_epochs = 20
二、数据准备与预处理
我们使用torchvision
的transforms
来对数据进行预处理,将图像转换为张量并进行归一化,使得数据更适合神经网络训练。然后利用MNIST
类下载数据集,并通过DataLoader
创建数据迭代器,方便按批次获取数据。
python运行
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_dataset = MNIST('../data/', train=True, transform=transform, download=True)
test_dataset = MNIST('../data/', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
为了直观了解数据,我们可以取出一批测试数据,查看其形状,还可以将图像可视化出来。
python运行
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape)
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title(f'Ground Truth: {example_targets[i]}')
plt.xticks([])
plt.yticks([])
三、构建神经网络模型
我们定义一个Net
类来构建神经网络,使用两个隐藏层,隐藏层使用 ReLU 激活函数,输出层使用 softmax 激活函数,这样可以将输出转换为概率分布,方便进行分类。
python运行
class Net(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Net, self).__init__()
self.flatten = nn.Flatten()
self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))
self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))
self.out = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
def forward(self, x):
x = self.flatten(x)
x = F.relu(self.layer1(x))
x = F.relu(self.layer2(x))
x = F.softmax(self.out(x), dim=1)
return x
四、实例化模型与定义优化器
根据可用设备(GPU 或 CPU)实例化模型,然后定义损失函数(交叉熵损失)和优化器(SGD),用于模型的训练更新。
python运行
lr = 0.01
momentum = 0.9
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net(28 * 28, 300, 100, 10)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
五、训练模型
在训练过程中,我们会动态调整学习率,每一轮遍历训练数据进行正向传播、反向传播和参数更新,同时记录训练损失和准确率。然后在测试集上验证模型效果,记录测试损失和准确率。
python运行
losses = []
acces = []
eval_losses = []
eval_acces = []
writer = SummaryWriter(log_dir='logs', comment='train-loss')
for epoch in range(num_epochs):
train_loss = 0
train_acc = 0
model.train()
if epoch % 5 == 0:
optimizer.param_groups[0]['lr'] *= 0.9
print(f'学习率:{optimizer.param_groups[0]["lr"]}')
for img, label in train_loader:
img = img.to(device)
label = label.to(device)
out = model(img)
loss = criterion(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
writer.add_scalar('Train', train_loss / len(train_loader), epoch)
_, pred = out.max(1)
num_correct = (pred == label).sum().item()
acc = num_correct / img.shape[0]
train_acc += acc
losses.append(train_loss / len(train_loader))
acces.append(train_acc / len(train_loader))
eval_loss = 0
eval_acc = 0
model.eval()
for img, label in test_loader:
img = img.to(device)
label = label.to(device)
img = img.view(img.size(0), -1)
out = model(img)
loss = criterion(out, label)
eval_loss += loss.item()
_, pred = out.max(1)
num_correct = (pred == label).sum().item()
acc = num_correct / img.shape[0]
eval_acc += acc
eval_losses.append(eval_loss / len(test_loader))
eval_acces.append(eval_acc / len(test_loader))
print(f'epoch: {epoch}, Train Loss: {train_loss / len(train_loader):.4f}, Train Acc: {train_acc / len(train_loader):.4f}, Test Loss: {eval_loss / len(test_loader):.4f}, Test Acc: {eval_acc / len(test_loader):.4f}')
六、可视化训练结果
最后,我们可以将训练损失的变化过程可视化出来,直观地看到模型的训练效果。
python运行
plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')
plt.show()
通过以上步骤,我们成功使用 PyTorch 实现了 MNIST 手写数字识别,从数据准备、模型构建到训练测试,完整地走通了深度学习的流程,希望能帮助大家更好地理解和入门深度学习。