一、迁移学习(Transfer Learning)详解
1. 什么是迁移学习?
迁移学习是一种机器学习方法,其核心思想是将从一个任务(源任务)中学到的知识,应用到另一个相关但不同的任务(目标任务)中。它模仿了人类的学习方式------我们学习骑自行车后,再学骑摩托车会容易很多;学会识别猫狗后,识别虎豹也会更轻松。
在深度学习中,迁移学习通常指:在大规模数据集(如ImageNet)上预训练一个深度神经网络,然后将该网络的权重作为初始值,在目标任务的小规模数据集上进行微调或特征提取。
2. 为什么需要迁移学习?
-
数据不足:目标任务往往缺乏海量标注数据,而从头训练大模型容易过拟合。
-
计算资源有限:从头训练ResNet等大型网络需要昂贵的GPU集群和时间成本。
-
收敛更快:预训练模型已经学会了通用的特征(边缘、纹理、形状等),只需在特定任务上微调即可快速达到良好性能。
3. 迁移学习的常见策略
| 策略 | 做法 | 适用场景 |
|---|---|---|
| 特征提取 | 冻结预训练模型的所有层(不更新参数),仅将输出特征送入新的分类器(如全连接层)进行训练。 | 目标任务与源任务差异不大,且目标数据量非常小。 |
| 微调(Fine-tuning) | 解冻部分或全部预训练层,允许参数在目标任务上继续更新。通常先用较小的学习率微调高层(靠近输出的层),再逐步微调低层。 | 目标任务与源任务有一定差异,且数据量中等(几百到几千张)。 |
| 渐进式微调 | 先冻结所有层训练新分类头,再逐步解冻更多层联合训练。 | 需要平衡训练速度和泛化能力。 |
4. 关键注意事项
-
数据分布差异:若源任务(ImageNet)与目标任务(如医学影像)差异巨大,迁移收益可能有限,需考虑是否从零训练或使用域自适应技术。
-
学习率选择:微调时,预训练层的学习率通常设为新层的1/10甚至更小,避免破坏学到的通用特征。
-
过拟合风险:尽管迁移学习能缓解过拟合,但若目标数据集仍然极小,应使用更强的正则化(Dropout、权重衰减、数据增强等)。
二、完整代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
# 1. 加载预训练模型并冻结所有层
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():
param.requires_grad = False
# 2. 替换全连接层(输出20类)
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features=in_features, out_features=20)
# 3. 收集需要更新的参数(仅新fc层)
params_to_updata = []
for param in resnet_model.parameters():
if param.requires_grad == True:
params_to_updata.append(param)
# 4. 定义数据预处理
data_transforms = {
'train': transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomRotation(10),
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomGrayscale(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'valid': transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
# 5. 自定义Dataset类
class food_dataset(Dataset):
def __init__(self, file_path, transform=None):
self.file_path = file_path
self.imgs = []
self.labels = []
self.transform = transform
with open(self.file_path) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for img_path, label in samples:
self.imgs.append(img_path)
self.labels.append(label)
def __len__(self):
return len(self.imgs)
def __getitem__(self, index):
image = Image.open(self.imgs[index])
if self.transform:
image = self.transform(image)
label = self.labels[index]
label = torch.tensor(int(label), dtype=torch.long)
return image, label
# 6. 加载数据集
train_dataset = food_dataset(file_path='./train.txt', transform=data_transforms['train'])
test_dataset = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])
train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=24, shuffle=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
model = resnet_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_updata, lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# 7. 训练函数
def train(dataloader, model, loss_fn, optimizer):
model.train()
for x, y in dataloader:
x, y = x.to(device), y.to(device)
pred = model(x) # 推荐使用 model(x) 而不是 model.forward(x)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 8. 测试函数
best_acc = 0
acc_s = []
loss_s = []
def test(dataloader, model, loss_fn):
global best_acc
size = len(dataloader.dataset)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for x, y in dataloader:
x, y = x.to(device), y.to(device)
pred = model(x)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).sum().item()
avg_loss = test_loss / len(dataloader)
accuracy = correct / size
print(f'Test accuracy: {100*accuracy:.2f}%, avg loss: {avg_loss:.4f}')
acc_s.append(accuracy)
loss_s.append(avg_loss)
if accuracy > best_acc:
best_acc = accuracy
torch.save(model.state_dict(), 'best_model.pth') # 保存最佳模型
# 9. 开始训练
epochs = 10
for epoch in range(epochs):
print(f"Epoch {epoch+1}/{epochs}")
train(train_loader, model, loss_fn, optimizer)
scheduler.step()
test(test_loader, model, loss_fn)
print(f'最优训练结果(准确率): {best_acc*100:.2f}%')
三、核心知识点逐段解析
3.1 冻结预训练模型
for param in resnet_model.parameters():
param.requires_grad = False
-
将ResNet18的所有参数梯度设为
False,反向传播时不会计算这些参数的梯度,也就不会更新它们。 -
这样我们只训练新加的全连接层,极大减少计算量。
3.2 替换最后一层
resnet_model.fc = nn.Linear(in_features=512, out_features=20)
-
ResNet18的
fc层原是输出1000类(ImageNet),我们改成20类(自己的食物类别)。 -
in_features可以通过resnet_model.fc.in_features动态获取。
3.3 只收集可训练参数传给优化器
params_to_updata = [param for param in resnet_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(params_to_updata, lr=0.001)
- 优化器只更新新fc层的参数,提高效率。
3.4 数据预处理与归一化
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- 均值和标准差来自ImageNet数据集,必须使用这些值才能与预训练模型匹配。
3.5 自定义Dataset类
-
__init__:读取train.txt(每行格式:图片路径 标签),将路径和标签分别存入列表。 -
__getitem__:根据索引加载图片、应用transform、返回(image, label)元组。
3.6 训练/测试模式切换
model.train() # 训练模式(Dropout启用,BN用batch统计量)
model.eval() # 评估模式(Dropout关闭,BN用全局统计量)
3.7 学习率调度器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
- 每5个epoch将学习率乘以0.5,有助于模型收敛更稳定。
四、代码进一步优化建议
当前代码已经能够正确运行并完成迁移学习任务,以下优化建议可以让代码更健壮、更专业,适合实际项目或进一步分享。
4.1 保存最佳模型权重
在测试函数中,当发现更高的准确率时,建议保存模型参数,便于后续直接加载推理,无需重新训练。
if accuracy > best_acc:
best_acc = accuracy
torch.save(model.state_dict(), 'best_model.pth')
4.2 移除冗余的 CenterCrop
训练数据增强中的 CenterCrop(224) 在 Resize(224,224) 之后没有实际作用,可以删除或替换为 RandomResizedCrop(224) 以获得更强的随机裁剪效果。
transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪并缩放到224
transforms.RandomRotation(10),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomGrayscale(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
4.3 统一使用 model(x) 代替 model.forward(x)
model(x) 是 PyTorch 推荐的调用方式,它会自动调用 forward 并处理钩子等内部逻辑,更规范。
4.4 为数据集类增加 convert('RGB')
某些图片可能是灰度图(L模式)或RGBA模式,为避免后续通道不匹配,建议在打开图片后统一转为RGB:
image = Image.open(self.imgs[index]).convert('RGB')
4.5 设置随机种子以保证结果可复现
在脚本开头添加:
torch.manual_seed(42)
np.random.seed(42)
4.6 添加数据路径的鲁棒性检查
在 __init__ 中可以检查 train.txt 文件是否存在,以及每行格式是否正确,给出友好的错误提示。
4.7 考虑使用验证集进行早停
当前代码只记录了测试集的最佳准确率,如果数据量允许,建议将训练集再划分为训练+验证集,根据验证集的表现来调整超参数或提前停止训练,避免在测试集上过拟合。
五、实验结果与总结
在20类食物数据集(假设每类约100-200张)上运行上述代码,通常能获得**75%~85%**的测试准确率,而从头训练相同模型可能只有40%~50%。这充分体现了迁移学习的威力。
关键收获
-
迁移学习是处理小数据集的利器,可以大幅提升性能并节省时间。
-
特征提取模式只需训练新分类头,适合数据量很少的情况。
-
数据预处理必须匹配预训练模型(尺寸224×224,ImageNet标准化)。
-
注意PyTorch中
model.train()和model.eval()的正确使用,否则会导致结果不稳定。 -
细心处理标签转换和准确率计算,避免低级bug。
下一步学习方向
-
尝试微调:解冻最后几层卷积层,用更低的学习率(如1e-5)继续训练。
-
使用更先进的预训练模型:ResNet50、EfficientNet、ViT等。
-
学习学习率调整策略 (ReduceLROnPlateau、CosineAnnealing)和早停技巧。
希望这篇文章能帮助你彻底理解迁移学习的PyTorch实现。如果你有任何问题,欢迎在评论区留言交流!
附:完整代码已整合为可直接运行的脚本,注意修改train.txt和test.txt的路径为你自己的数据集索引文件,并修正路径中的逗号笔误。