在计算机视觉的图像分类任务中,从头搭建并训练深度神经网络不仅需要大量的标注数据,还需要耗费大量的计算资源。迁移学习作为一种高效的建模方法,能够借助预训练模型的特征提取能力,快速适配新的分类任务。本文将以 20 类食物分类为例,详细讲解如何基于 PyTorch 框架,利用 ResNet18 预训练模型实现迁移学习,从模型改造、数据处理到训练测试。
一、迁移学习逻辑
迁移学习的核心实现逻辑分为两步:
-
冻结预训练模型的卷积层:保留其已学习的特征提取能力,避免训练过程中破坏原有权重;
-
替换全连接层:ResNet18 原全连接层输出为 1000 类(适配 ImageNet),本次任务为 20 类食物分类,因此将原全连接层替换为输入特征数不变、输出为 20 的新全连接层,仅训练该层的参数,大幅降低训练成本。
同时,为了提升模型泛化能力,对训练集进行数据增强,验证集仅做标准化处理;训练过程中使用 Adam 优化器更新参数,结合学习率调度器动态调整学习率,最终实现高效的模型训练。
二、预训练模型改造与参数冻结
这一步是迁移学习的核心,需要完成加载预训练 ResNet18、冻结卷积层参数、替换全连接层三个关键操作,同时筛选出需要训练的参数。
2.1 加载预训练 ResNet18 模型
PyTorch 的torchvision.models模块提供了封装好的 ResNet18,通过指定weights=models.ResNet18_Weights.DEFAULT可以直接加载在 ImageNet 上预训练的权重,无需手动下载权重文件:
python
import torch
import torchvision.models as models
from torch import nn
# 加载ResNet18预训练模型
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
2.2 冻结卷积层所有参数
将模型所有参数的requires_grad属性设置为False,表示这些参数在训练过程中不计算梯度,也不会被优化器更新,即实现参数冻结:
python
# 冻结所有预训练参数(卷积层为主)
for param in resnet_model.parameters():
param.requires_grad = False
2.3 替换全连接层并筛选训练参数
ResNet18 的全连接层(fc)是模型的分类头,需要根据任务需求替换。首先获取原全连接层的输入特征数(由卷积层的输出特征决定,固定为 512),然后替换为输出 20 的新全连接层;最后筛选出requires_grad=True的参数(仅新全连接层的参数),作为后续优化器的更新对象:
python
# 获取原全连接层的输入特征数
in_features = resnet_model.fc.in_features
# 替换为输出20的新全连接层(适配20类食物分类)
resnet_model.fc = nn.Linear(in_features, 20)
# 筛选需要训练的参数(仅全连接层参数)
params_to_update = []
for param in resnet_model.parameters():
if param.requires_grad == True:
params_to_update.append(param)
三、图像数据处理:变换、加载与封装
食物分类的原始数据为图像文件和对应的标签文件(trainda.txt/testda.txt),标签文件中每行格式为图像路径 类别标签。需要通过数据变换、自定义 Dataset、DataLoader 封装三步,将原始数据转换为 PyTorch 可训练的格式。
3.1 图像变换策略:训练集增强,验证集标准化
图像变换的核心原则是:训练集做数据增强提升泛化能力,验证 / 测试集仅做与训练集一致的标准化处理,避免引入额外噪声。本次使用torchvision.transforms实现变换
python
from torchvision import transforms
# 定义训练集和验证集的图像变换
data_transforms = {
"train": transforms.Compose([
transforms.Resize([300, 300]), # 缩放为300*300,为后续裁剪预留空间
transforms.RandomRotation(45), # 随机旋转-45~45度,增强旋转鲁棒性
transforms.CenterCrop(224), # 中心裁剪为224*224,匹配ResNet输入要求
transforms.ToTensor(), # 转换为Tensor,格式[C, H, W],值归一化到0~1
# 标准化:使用ImageNet的均值和标准差,与预训练模型保持一致
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
"valid": transforms.Compose([
transforms.Resize([256, 256]), # 缩放为256*256
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
3.2 自定义 Dataset:解析标签文件并加载图像
PyTorch 的Dataset是抽象类,需要自定义子类实现__init__、__len__、__getitem__三个核心方法,实现标签文件解析、图像读取、格式转换的功能,将图像和标签封装为可索引的数据集
python
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
class food_dataset(Dataset):
def __init__(self, file_path, transform=None):
"""
初始化:解析标签文件,保存图像路径和标签
:param file_path: 标签文件路径(trainda.txt/testda.txt)
:param transform: 图像变换策略
"""
self.file_path = file_path
self.imgs = [] # 保存所有图像路径
self.labels = [] # 保存所有图像对应的标签
self.transform = transform
# 解析标签文件
with open(self.file_path) as f:
# 每行按空格分割,得到[图像路径, 标签]
samples = [x.strip().split(' ') for x in f.readlines()]
for img_path, label in samples:
self.imgs.append(img_path)
self.labels.append(label)
def __len__(self):
"""返回数据集的样本总数"""
return len(self.imgs)
def __getitem__(self, idx):
"""根据索引获取单个样本(图像+标签),并完成格式转换"""
# 读取PIL图像
image = Image.open(self.imgs[idx])
# 应用图像变换(转换为Tensor并标准化/增强)
if self.transform:
image = self.transform(image)
# 标签转换为Tensor(int64类型,适配交叉熵损失)
label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
return image, label
3.3 DataLoader 封装:批量加载与数据打乱
通过torch.utils.data.DataLoader将自定义的 Dataset 封装为可迭代的批量数据加载器,实现批量读取、数据打乱、多进程加载(默认)等功能,是 PyTorch 训练的标准数据输入方式
python
from torch.utils.data import DataLoader
# 实例化训练集和测试集
training_data = food_dataset(file_path=r'.\trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'.\testda.txt', transform=data_transforms['valid'])
# 封装为DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) # 训练集打乱
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True) # 测试集可选择是否打乱
参数说明:
-
batch_size=64:每次读取 64 个样本为一个批次,批次大小可根据显存调整(显存不足则减小); -
shuffle=True:每个 epoch 训练前打乱数据顺序,避免模型学习到数据的顺序规律,提升泛化能力。
四、设备配置与模型部署
python
# 自动检测可用设备
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
# 将模型部署到指定设备
model = resnet_model.to(device)
五、训练与测试函数定义
5.1 训练函数:实现模型的一次 epoch 训练
训练函数的核心是将模型切换为训练模式(model.train()),遍历训练集的每个批次,完成前向传播预测、交叉熵损失计算、梯度清零、反向传播求梯度、优化器更新参数的流程,并打印批次损失值.
python
def train(dataloader, model, loss_fn, optimizer):
model.train() # 切换为训练模式:启用Dropout、BatchNorm的训练模式
batch_size_num = 1 # 统计训练的批次数量
for x, y in dataloader:
# 将数据部署到指定设备
x, y = x.to(device), y.to(device)
# 前向传播:模型预测(.forward可省略,直接model(x)即可)
pred = model(x)
# 计算损失:交叉熵损失(适配分类任务)
loss = loss_fn(pred, y)
# 反向传播与参数更新
optimizer.zero_grad() # 梯度清零:避免上一批次的梯度累积
loss.backward() # 反向传播:计算所有可训练参数的梯度
optimizer.step() # 优化器更新:根据梯度调整参数
# 提取损失值并按批次打印
loss_value = loss.item()
if batch_size_num % 2 == 0:
print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")
batch_size_num += 1
-
model.train():必须在训练前调用,启用模型中训练相关的层(如 BatchNorm、Dropout),确保训练过程的正确性; -
optimizer.zero_grad():每次批次训练前必须清零梯度,否则梯度会在批次间累积,导致参数更新错误; -
前向传播时
model(x)等价于model.forward(x),PyTorch 已对__call__方法做了封装,推荐直接使用model(x)。
5.2 测试函数:实现模型的精度与损失评估
测试函数的核心是将模型切换为评估模式(model.eval()),关闭梯度计(torch.no_grad()),遍历测试集的所有批次,统计总损失和正确预测数,最终计算平均损失和准确率。
python
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset) # 测试集总样本数
num_batches = len(dataloader) # 测试集总批次数
model.eval() # 切换为评估模式:关闭Dropout、固定BatchNorm参数
test_loss, correct = 0, 0 # 初始化总损失和正确预测数
# 关闭梯度计算:测试过程无需求梯度,节省显存和计算资源
with torch.no_grad():
for x, y in dataloader:
x, y = x.to(device), y.to(device)
pred = model(x) # 前向传播预测
# 累加批次损失
test_loss += loss_fn(pred, y).item()
# 累加正确预测数:pred.argmax(1)获取每行最大值的索引(预测类别),与真实标签比较
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
# 计算平均损失和准确率
test_loss /= num_batches
correct /= size
print(f"Test result: Accuracy:{(100*correct):>0.1f}%,Avg loss: {test_loss:>8f}")
-
model.eval():必须在测试前调用,关闭 Dropout 层、固定 BatchNorm 的均值和方差,避免评估过程中模型参数变化,确保评估结果的稳定性; -
torch.no_grad():上下文管理器,关闭梯度计算,大幅节省显存和计算资源,测试过程无需反向传播,因此无需计算梯度; -
pred.argmax(1):对模型输出的预测值(形状 [batch_size, 20])按维度 1 取最大值索引,即模型预测的类别标签(20 类中概率最大的类)。
六、优化器、损失函数与学习率
6.1 损失函数:交叉熵损失
python
loss_fn = nn.CrossEntropyLoss()
6.2 优化器:Adam 优化器
选择 Adam 优化器更新模型参数,Adam 结合了动量法和自适应学习率的优点,收敛速度快、调参简单,本次仅更新筛选出的全连接层参数(params_to_update),学习率设置为 0.001
python
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
# 若全模型微调,改为:optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
6.3 学习率调度器:StepLR
为了避免模型训练后期陷入局部最优,使用StepLR学习率调度器,每经过指定的 epoch 数,将学习率乘以衰减系数,实现动态学习率调整
python
# 每5个epoch,学习率乘以0.5(衰减为原来的一半)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
七、训练与测试
python
# 设置训练的epoch数
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n----------")
train(train_dataloader, model, loss_fn, optimizer) # 训练一个epoch
scheduler.step() # 更新学习率
test(test_dataloader, model, loss_fn) # 测试模型性能
print("Training Done!")
scheduler.step():每个 epoch 训练完成后调用,更新学习率
结果显示
