PyTorch迁移学习实战:用ResNet18实现20类食物图像分类(附代码详解)

一、迁移学习(Transfer Learning)详解

1. 什么是迁移学习?

迁移学习是一种机器学习方法,其核心思想是将从一个任务(源任务)中学到的知识,应用到另一个相关但不同的任务(目标任务)中。它模仿了人类的学习方式------我们学习骑自行车后,再学骑摩托车会容易很多;学会识别猫狗后,识别虎豹也会更轻松。

在深度学习中,迁移学习通常指:在大规模数据集(如ImageNet)上预训练一个深度神经网络,然后将该网络的权重作为初始值,在目标任务的小规模数据集上进行微调或特征提取。

2. 为什么需要迁移学习?

  • 数据不足:目标任务往往缺乏海量标注数据,而从头训练大模型容易过拟合。

  • 计算资源有限:从头训练ResNet等大型网络需要昂贵的GPU集群和时间成本。

  • 收敛更快:预训练模型已经学会了通用的特征(边缘、纹理、形状等),只需在特定任务上微调即可快速达到良好性能。

3. 迁移学习的常见策略

策略 做法 适用场景
特征提取 冻结预训练模型的所有层(不更新参数),仅将输出特征送入新的分类器(如全连接层)进行训练。 目标任务与源任务差异不大,且目标数据量非常小。
微调(Fine-tuning) 解冻部分或全部预训练层,允许参数在目标任务上继续更新。通常先用较小的学习率微调高层(靠近输出的层),再逐步微调低层。 目标任务与源任务有一定差异,且数据量中等(几百到几千张)。
渐进式微调 先冻结所有层训练新分类头,再逐步解冻更多层联合训练。 需要平衡训练速度和泛化能力。

4. 关键注意事项

  • 数据分布差异:若源任务(ImageNet)与目标任务(如医学影像)差异巨大,迁移收益可能有限,需考虑是否从零训练或使用域自适应技术。

  • 学习率选择:微调时,预训练层的学习率通常设为新层的1/10甚至更小,避免破坏学到的通用特征。

  • 过拟合风险:尽管迁移学习能缓解过拟合,但若目标数据集仍然极小,应使用更强的正则化(Dropout、权重衰减、数据增强等)。


二、完整代码

复制代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

# 1. 加载预训练模型并冻结所有层
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():
    param.requires_grad = False

# 2. 替换全连接层(输出20类)
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features=in_features, out_features=20)

# 3. 收集需要更新的参数(仅新fc层)
params_to_updata = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_updata.append(param)

# 4. 定义数据预处理
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomRotation(10),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# 5. 自定义Dataset类
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        image = Image.open(self.imgs[index])
        if self.transform:
            image = self.transform(image)
        label = self.labels[index]
        label = torch.tensor(int(label), dtype=torch.long)
        return image, label

# 6. 加载数据集
train_dataset = food_dataset(file_path='./train.txt', transform=data_transforms['train'])
test_dataset = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])

train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=24, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

model = resnet_model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_updata, lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# 7. 训练函数
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)   # 推荐使用 model(x) 而不是 model.forward(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 8. 测试函数
best_acc = 0
acc_s = []
loss_s = []

def test(dataloader, model, loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()
    avg_loss = test_loss / len(dataloader)
    accuracy = correct / size
    print(f'Test accuracy: {100*accuracy:.2f}%, avg loss: {avg_loss:.4f}')
    
    acc_s.append(accuracy)
    loss_s.append(avg_loss)
    
    if accuracy > best_acc:
        best_acc = accuracy
        torch.save(model.state_dict(), 'best_model.pth')  # 保存最佳模型

# 9. 开始训练
epochs = 10
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train(train_loader, model, loss_fn, optimizer)
    scheduler.step()
    test(test_loader, model, loss_fn)

print(f'最优训练结果(准确率): {best_acc*100:.2f}%')

三、核心知识点逐段解析

3.1 冻结预训练模型

复制代码
for param in resnet_model.parameters():
    param.requires_grad = False
  • 将ResNet18的所有参数梯度设为False,反向传播时不会计算这些参数的梯度,也就不会更新它们。

  • 这样我们只训练新加的全连接层,极大减少计算量。

3.2 替换最后一层

复制代码
resnet_model.fc = nn.Linear(in_features=512, out_features=20)
  • ResNet18的fc层原是输出1000类(ImageNet),我们改成20类(自己的食物类别)。

  • in_features可以通过resnet_model.fc.in_features动态获取。

3.3 只收集可训练参数传给优化器

复制代码
params_to_updata = [param for param in resnet_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(params_to_updata, lr=0.001)
  • 优化器只更新新fc层的参数,提高效率。

3.4 数据预处理与归一化

复制代码
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  • 均值和标准差来自ImageNet数据集,必须使用这些值才能与预训练模型匹配。

3.5 自定义Dataset类

  • __init__:读取train.txt(每行格式:图片路径 标签),将路径和标签分别存入列表。

  • __getitem__:根据索引加载图片、应用transform、返回(image, label)元组。

3.6 训练/测试模式切换

复制代码
model.train()   # 训练模式(Dropout启用,BN用batch统计量)
model.eval()    # 评估模式(Dropout关闭,BN用全局统计量)

3.7 学习率调度器

复制代码
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
  • 每5个epoch将学习率乘以0.5,有助于模型收敛更稳定。

四、代码进一步优化建议

当前代码已经能够正确运行并完成迁移学习任务,以下优化建议可以让代码更健壮、更专业,适合实际项目或进一步分享。

4.1 保存最佳模型权重

在测试函数中,当发现更高的准确率时,建议保存模型参数,便于后续直接加载推理,无需重新训练。

复制代码
if accuracy > best_acc:
    best_acc = accuracy
    torch.save(model.state_dict(), 'best_model.pth')

4.2 移除冗余的 CenterCrop

训练数据增强中的 CenterCrop(224)Resize(224,224) 之后没有实际作用,可以删除或替换为 RandomResizedCrop(224) 以获得更强的随机裁剪效果。

复制代码
transforms.Compose([
    transforms.RandomResizedCrop(224),  # 随机裁剪并缩放到224
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

4.3 统一使用 model(x) 代替 model.forward(x)

model(x) 是 PyTorch 推荐的调用方式,它会自动调用 forward 并处理钩子等内部逻辑,更规范。

4.4 为数据集类增加 convert('RGB')

某些图片可能是灰度图(L模式)或RGBA模式,为避免后续通道不匹配,建议在打开图片后统一转为RGB:

复制代码
image = Image.open(self.imgs[index]).convert('RGB')

4.5 设置随机种子以保证结果可复现

在脚本开头添加:

复制代码
torch.manual_seed(42)
np.random.seed(42)

4.6 添加数据路径的鲁棒性检查

__init__ 中可以检查 train.txt 文件是否存在,以及每行格式是否正确,给出友好的错误提示。

4.7 考虑使用验证集进行早停

当前代码只记录了测试集的最佳准确率,如果数据量允许,建议将训练集再划分为训练+验证集,根据验证集的表现来调整超参数或提前停止训练,避免在测试集上过拟合。


五、实验结果与总结

在20类食物数据集(假设每类约100-200张)上运行上述代码,通常能获得**75%~85%**的测试准确率,而从头训练相同模型可能只有40%~50%。这充分体现了迁移学习的威力。

关键收获

  1. 迁移学习是处理小数据集的利器,可以大幅提升性能并节省时间。

  2. 特征提取模式只需训练新分类头,适合数据量很少的情况。

  3. 数据预处理必须匹配预训练模型(尺寸224×224,ImageNet标准化)。

  4. 注意PyTorch中model.train()model.eval()的正确使用,否则会导致结果不稳定。

  5. 细心处理标签转换和准确率计算,避免低级bug。

下一步学习方向

  • 尝试微调:解冻最后几层卷积层,用更低的学习率(如1e-5)继续训练。

  • 使用更先进的预训练模型:ResNet50、EfficientNet、ViT等。

  • 学习学习率调整策略 (ReduceLROnPlateau、CosineAnnealing)和早停技巧。

希望这篇文章能帮助你彻底理解迁移学习的PyTorch实现。如果你有任何问题,欢迎在评论区留言交流!


附:完整代码已整合为可直接运行的脚本,注意修改train.txttest.txt的路径为你自己的数据集索引文件,并修正路径中的逗号笔误。

相关推荐
ForDreamMusk6 小时前
PyTorch编程基础
人工智能·pytorch
郝学胜-神的一滴7 小时前
神经网络参数初始化:从梯度失控到模型收敛的核心密码
人工智能·pytorch·深度学习·神经网络·机器学习·软件构建·软件设计
机器学习之心7 小时前
一键替换数据集!基于PSO多目标优化与SHAP可解释分析的回归预测神器来了PyTorch构建
pytorch·回归·可解释分析·pso多目标优化
深念Y8 小时前
感知机 ≈ 可学习的逻辑门?聊聊激活函数与二元分类的本质
人工智能·学习·分类·感知机·激活函数·逻辑门·二元分类
配奇10 小时前
PyTorch 核心使用
人工智能·pytorch·python
roman_日积跬步-终至千里10 小时前
【深度学习】国科大:CIFAR-100 图像分类项目
人工智能·深度学习·分类
墨心@13 小时前
pytorch 与资源核算
pytorch·语言模型·大语言模型·datawhale·组队学习
李昊哲小课13 小时前
WSL Ubuntu 24.04 GPU 加速环境完整安装指南
c++·pytorch·深度学习·ubuntu·cuda·tensorflow2
渡我白衣14 小时前
触类旁通——迁移学习、多任务学习与元学习
人工智能·深度学习·神经网络·学习·机器学习·迁移学习·caffe