项目任务:运用残差网络模型来识别手写数字
代码实现:
python
import torch
print(torch.__version__)
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor())
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor())
train_dataloader = DataLoader(training_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)
for X,y in test_dataloader:
print(f"Shape of X[N,C,H,W]:{X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")
import torch
import torch.nn as nn
import torch.nn.functional as F
# 残差块定义
class ResBlock(nn.Module):
def __init__(self, channels_in):
super().__init__()
self.conv1 = torch.nn.Conv2d(channels_in,30, kernel_size=5, padding=2)
self.conv2 = torch.nn.Conv2d(30, channels_in, kernel_size=3, padding=1)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
return F.relu(out + x) # 残差连接(out + 输入x)
# ResNet网络定义
class ResNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1,20,5)
self.conv2 = torch.nn.Conv2d(20,15,3)
self.maxpool = torch.nn.MaxPool2d(2)
self.resblock1 = ResBlock(channels_in=20) # 第一个残差块
self.resblock2 = ResBlock(channels_in=15) # 第二个残差块
self.full_c = torch.nn.Linear(375, 10) # 全连接层(输出维度10,对应10分类)
def forward(self, x):
size = x.shape[0] # 获取批次大小
# 第一段卷积+池化+残差块
x = F.relu(self.maxpool(self.conv1(x)))
x = self.resblock1(x)
# 第二段卷积+池化+残差块
x = F.relu(self.maxpool(self.conv2(x)))
x = self.resblock2(x)
# 展平后送入全连接层
x = x.view(size, -1)
x = self.full_c(x)
return x
model = ResNet().to(device)
def train(dataloader,model,loss_fn,optimizer):
model.train()
batch_size_num = 1
for X ,y in dataloader:
X,y = X.to(device),y.to(device)
pred = model(X)
loss = loss_fn(pred,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_value = loss.item()
if batch_size_num % 100 ==0:
print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")
batch_size_num +=1
best_acc=0
def test(dataloader,model,loss_fn):
global best_acc
size = len(dataloader.dataset)
num_batches= len(dataloader)
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for X ,y in dataloader:
X,y = X.to(device),y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_pj_loss = test_loss / num_batches
test_acy = correct / size * 100
print(f"Avg loss: {test_pj_loss:>7f} \n Accuray: {test_acy:>5.2f}%")
if correct > best_acc:
best_acc = correct
# 保存模型的状态字典,而非整个模型
torch.save(model.state_dict(), 'best.pth') # 重点修改这里
print(f"保存最佳模型,准确率: {test_acy:>5.2f}%")
else:
print(f"保存最佳模型,准确率: {test_acy:>5.2f}%")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)
i=10
for j in range(i):
print(f"Epoch {j+1}\n----------")
train(train_dataloader, model,loss_fn,optimizer)
test(test_dataloader,model,loss_fn)
这段代码是一个基于 PyTorch 实现的残差网络(ResNet),用于训练和测试 MNIST 手写数字识别任务。下面对代码进行解析:
1. 库导入与数据集准备
import torch
print(torch.__version__) # 打印PyTorch版本
from torch import nn # 神经网络模块
from torch.utils.data import DataLoader # 数据加载工具
from torchvision import datasets # 计算机视觉数据集
from torchvision.transforms import ToTensor # 图像转张量的转换工具
导入了 PyTorch 核心库、神经网络模块、数据加载工具,以及处理 MNIST 数据集的相关工具。
ToTensor()
用于将图像(PIL 格式)转换为 PyTorch 张量,并自动将像素值归一化到[0, 1]
范围。
# 加载MNIST训练集和测试集
training_data = datasets.MNIST(
root='data', # 数据保存路径
train=True, # 训练集
download=True, # 若本地无数据则自动下载
transform=ToTensor() # 应用转换
)
test_data = datasets.MNIST(
root='data',
train=False, # 测试集
download=True,
transform=ToTensor()
)
MNIST 是经典的手写数字数据集,包含 60000 张训练图和 10000 张测试图,每张图是 28x28 的灰度图(单通道),标签为 0-9 的数字。
# 数据加载器(按批次加载数据,方便批量训练)
train_dataloader = DataLoader(training_data, batch_size=64) # 训练集批次大小64
test_dataloader = DataLoader(test_data, batch_size=64) # 测试集批次大小64
DataLoader
将数据集按batch_size
分批,支持自动打乱数据、多线程加载等功能,是训练中高效读取数据的工具。
# 打印数据形状,验证数据格式
for X, y in test_dataloader:
print(f"Shape of X[N,C,H,W]: {X.shape}") # 输入图像形状
print(f"Shape of y: {y.shape} {y.dtype}") # 标签形状和类型
break
输出示例:X[N,C,H,W]
中,N=64
(批次大小)、C=1
(单通道灰度图)、H=28
、W=28
(图像尺寸);y
是长度为 64 的标签(类型为long
,适合分类任务)。
2. 设备配置
# 自动选择训练设备(优先GPU,其次苹果芯片,最后CPU)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")
深度学习训练通常需要 GPU 加速,这里自动检测并选择可用的加速设备,最大化训练效率。
3. 残差网络(ResNet)定义
残差网络的核心是残差块(Residual Block),通过 "跳跃连接"(将输入直接加到输出)缓解深层网络的梯度消失问题。
3.1 残差块(ResBlock)
class ResBlock(nn.Module):
def __init__(self, channels_in):
super().__init__()
# 第一个卷积层:输入通道数→30,5x5卷积核,padding=2(保持尺寸)
self.conv1 = torch.nn.Conv2d(channels_in, 30, kernel_size=5, padding=2)
# 第二个卷积层:30→输入通道数,3x3卷积核,padding=1(保持尺寸)
self.conv2 = torch.nn.Conv2d(30, channels_in, kernel_size=3, padding=1)
def forward(self, x):
out = self.conv1(x) # 第一次卷积
out = self.conv2(out) # 第二次卷积
return F.relu(out + x) # 残差连接(输出+输入)+ ReLU激活
残差块的关键是out + x
:将输入x
直接加到卷积后的输出out
上,实现 "跳跃连接",确保梯度能有效回传。
卷积层的padding
设置保证了输入和输出的尺寸一致,才能进行加法操作。
3.2 完整 ResNet 网络
class ResNet(nn.Module):
def __init__(self):
super().__init__()
# 第一层卷积:输入1通道(灰度图)→20通道,5x5卷积核
self.conv1 = torch.nn.Conv2d(1, 20, 5)
# 第二层卷积:20通道→15通道,3x3卷积核
self.conv2 = torch.nn.Conv2d(20, 15, 3)
self.maxpool = torch.nn.MaxPool2d(2) # 2x2最大池化(尺寸减半)
self.resblock1 = ResBlock(channels_in=20) # 第一个残差块(输入20通道)
self.resblock2 = ResBlock(channels_in=15) # 第二个残差块(输入15通道)
self.full_c = torch.nn.Linear(375, 10) # 全连接层(输出10类,对应0-9)
def forward(self, x):
size = x.shape[0] # 获取批次大小(用于后续展平操作)
# 第一段:卷积→池化→激活→残差块
x = F.relu(self.maxpool(self.conv1(x))) # conv1→池化(尺寸减半)→ReLU
x = self.resblock1(x) # 经过第一个残差块
# 第二段:卷积→池化→激活→残差块
x = F.relu(self.maxpool(self.conv2(x))) # conv2→池化(尺寸减半)→ReLU
x = self.resblock2(x) # 经过第二个残差块
# 展平特征图→全连接层输出
x = x.view(size, -1) # 展平为(batch_size, 特征数),这里特征数为375
x = self.full_c(x) # 输出10类的预测概率(未经过softmax)
return x
网络整体流程:输入图像→卷积层提取特征→池化层降维→残差块增强特征→全连接层输出分类结果。
x.view(size, -1)
将卷积后的三维特征图(batch, channel, height, width)展平为二维张量(batch, 特征数),才能输入全连接层。
全连接层输入维度375
是根据前面的特征图尺寸计算的(具体为:经过多次卷积和池化后,特征图尺寸为 5x5,通道数 15,5×5×15=375)。
4. 模型初始化
model = ResNet().to(device) # 实例化模型,并移动到之前选择的设备(GPU/CPU)
5. 训练与测试函数
5.1 训练函数(train)
def train(dataloader, model, loss_fn, optimizer):
model.train() # 设置模型为训练模式(启用dropout、批归一化更新等)
batch_size_num = 1 # 批次计数器
for X, y in dataloader:
X, y = X.to(device), y.to(device) # 数据移到设备
# 前向传播:计算预测值
pred = model(X)
# 计算损失(预测值与真实标签的差异)
loss = loss_fn(pred, y)
# 反向传播与参数更新
optimizer.zero_grad() # 清空上一轮梯度
loss.backward() # 计算梯度
optimizer.step() # 更新参数
# 每100个批次打印一次损失
loss_value = loss.item()
if batch_size_num % 100 == 0:
print(f"loss: {loss_value:>7f} [number: {batch_size_num}]")
batch_size_num += 1
核心流程:前向传播计算预测→计算损失→反向传播求梯度→优化器更新参数。
model.train()
:启用训练模式(例如,若有 dropout 层会随机丢弃神经元)。
5.2 测试函数(test)
best_acc = 0 # 记录最佳准确率
def test(dataloader, model, loss_fn):
global best_acc # 引用全局变量
size = len(dataloader.dataset) # 测试集总样本数
num_batches = len(dataloader) # 测试集批次数
model.eval() # 设置模型为评估模式(关闭dropout等)
test_loss = 0 # 总测试损失
correct = 0 # 正确分类的样本数
with torch.no_grad(): # 关闭梯度计算(节省内存,加速计算)
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X) # 预测
test_loss += loss_fn(pred, y).item() # 累加损失
# 计算正确数:预测最大值的索引(类别)与真实标签一致
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
# 计算平均损失和准确率
test_pj_loss = test_loss / num_batches
test_acy = correct / size * 100
print(f"Avg loss: {test_pj_loss:>7f} \n Accuracy: {test_acy:>5.2f}%")
# 保存准确率最高的模型
if correct > best_acc:
best_acc = correct
torch.save(model.state_dict(), 'best.pth') # 保存模型参数(而非整个模型)
print(f"保存最佳模型,准确率: {test_acy:>5.2f}%")
else:
print(f"当前准确率未超过最佳,最佳准确率: {best_acc/size*100:>5.2f}%")
核心作用:评估模型在测试集上的性能(损失和准确率),并保存表现最好的模型。
model.eval()
:切换到评估模式(例如,关闭 dropout,固定批归一化参数)。
with torch.no_grad()
:关闭梯度计算,减少内存占用,加速测试过程。
模型保存用model.state_dict()
:仅保存参数(权重和偏置),而非整个模型结构,更轻量且灵活。
6. 训练配置与执行
loss_fn = nn.CrossEntropyLoss() # 交叉熵损失(适合分类任务,内置softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam优化器,学习率0.001
# 训练10个epoch
epochs = 10
for j in range(epochs):
print(f"Epoch {j+1}\n----------")
train(train_dataloader, model, loss_fn, optimizer) # 训练一轮
test(test_dataloader, model, loss_fn) # 测试一轮
CrossEntropyLoss
:适用于多分类任务,自动对输出进行 softmax 处理,并计算与标签的交叉熵。
Adam
:一种常用的优化器,结合了动量和自适应学习率,收敛速度快且稳定。
epoch
:完整遍历一次训练集的次数,这里设置为 10 次,每次训练后测试模型性能。
总结
这段代码实现了一个简化版的残差网络,用于 MNIST 手写数字识别。核心亮点包括:
使用残差块解决深层网络梯度消失问题;
完整的训练 - 测试流程(含设备自动选择、损失计算、参数更新、模型保存);
符合 PyTorch 最佳实践(如train()
/eval()
模式切换、torch.no_grad()
关闭梯度等)。
通过训练,模型通常能达到 98% 以上的准确率,残差结构相比普通卷积网络能更高效地学习特征。