Day 44 预训练模型与迁移学习

在深度学习领域,从零开始训练一个高性能模型通常需要海量数据(如 ImageNet 的 120 万张图片)和昂贵的计算资源。对于大多数实际应用场景,我们更倾向于使用迁移学习 (Transfer Learning)

本篇笔记将结合 Day 44 的代码,深入剖析如何利用预训练的 ResNet18 模型,在 CIFAR-10 数据集上实现 86%+ 的高准确率。我们将重点拆解代码实现的每一个细节。


一、 数据准备:为模型提供高质量"燃料"

数据增强是提升模型泛化能力的关键。在迁移学习中,由于模型参数量较大(ResNet18 约 1100 万参数),而在小数据集(CIFAR-10 仅 5 万张训练图)上容易过拟合,因此强力的数据增强尤为重要。

1. 训练集增强策略 (代码详解)

复制代码
train_transform = transforms.Compose([
    # 1. 随机裁剪 (RandomCrop)
    # 先在图像四周填充 4 个像素的 0 (padding=4),图像变大 (40x40)
    # 然后随机裁剪出 32x32 的区域。
    # 作用:让模型学习到物体在不同位置的特征,模拟物体平移。
    transforms.RandomCrop(32, padding=4),
    
    # 2. 随机水平翻转 (RandomHorizontalFlip)
    # 以 50% 的概率水平翻转图像。
    # 作用:模拟物体朝向的变化(如车头朝左或朝右),增加数据多样性。
    transforms.RandomHorizontalFlip(),
    
    # 3. 颜色抖动 (ColorJitter)
    # 随机调整亮度(brightness)、对比度(contrast)、饱和度(saturation) 和色相(hue)。
    # 作用:模拟不同光照条件下的物体,让模型对颜色变化不敏感。
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    
    # 4. 随机旋转 (RandomRotation)
    # 在 -15度 到 +15度 之间随机旋转。
    # 作用:模拟拍摄角度的微小偏差。
    transforms.RandomRotation(15),
    
    # 5. 转为 Tensor
    # 将 PIL Image (0-255) 转换为 Tensor (0.0-1.0),并调整维度顺序 (HWC -> CHW)。
    transforms.ToTensor(),
    
    # 6. 标准化 (Normalize)
    # 使用 CIFAR-10 数据集的均值和标准差进行归一化:(x - mean) / std
    # 作用:加速收敛,使数据分布更符合模型假设。
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

2. 测试集处理

测试集只需进行必要的格式转换和标准化,严禁使用随机增强操作,以确保评估结果的稳定性与真实性。

复制代码
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

二、 模型构建:ResNet18 的"换头"手术

我们使用 torchvision.models 提供的 ResNet18。由于预训练模型是在 ImageNet(1000 类)上训练的,我们需要修改其输出层(Head)以适配 CIFAR-10(10 类)。

代码实现与解析

复制代码
from torchvision.models import resnet18
import torch.nn as nn

def create_resnet18(pretrained=True, num_classes=10):
    # 1. 加载预训练模型
    # pretrained=True: 自动下载并加载在 ImageNet 上训练好的权重。
    # 这些权重包含了提取通用视觉特征(边缘、纹理、形状)的能力。
    model = resnet18(pretrained=pretrained)
    
    # 2. 修改全连接层 (Head)
    # model.fc 是 ResNet 的最后一层全连接层。
    # in_features: 获取原全连接层的输入维度(ResNet18 为 512)。
    in_features = model.fc.in_features
    
    # 3. 替换为新的全连接层
    # 新层初始化时权重是随机的,输出维度设为 num_classes (10)。
    # 注意:这一层没有预训练权重,需要从头训练。
    model.fc = nn.Linear(in_features, num_classes)
    
    # 4. 转移到 GPU
    return model.to(device)

三、 训练策略:冻结与解冻 (Freeze & Unfreeze)

这是迁移学习中最核心的技巧。为了防止新初始化的全连接层(随机权重)产生的巨大梯度破坏预训练好的骨干网络(Backbone),我们通常采用分阶段训练

1. 冻结控制函数

复制代码
def freeze_model(model, freeze=True):
    """
    控制模型参数的冻结与解冻
    freeze=True: 冻结卷积层,只训练全连接层
    freeze=False: 解冻所有层,进行全局微调
    """
    # 遍历模型的所有参数(权重和偏置)
    for name, param in model.named_parameters():
        # 我们始终要训练全连接层 (fc),所以只冻结非 fc 层
        if 'fc' not in name:
            # requires_grad=False 表示该参数不计算梯度,也不会被优化器更新
            param.requires_grad = not freeze
            
    # (可选) 打印当前冻结状态,方便调试
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"当前冻结参数量: {frozen_params}")
    return model

2. 分阶段训练逻辑

我们在 train_with_freeze_schedule 函数中实现了这一逻辑:

  • 阶段一 (Epoch 0 ~ freeze_epochs-1)
    • 调用 freeze_model(model, freeze=True)
    • 此时,只有全连接层在更新。骨干网络充当一个固定的特征提取器。
    • 目的:让全连接层的权重快速收敛到合理范围。
  • 阶段二 (Epoch >= freeze_epochs)
    • 调用 freeze_model(model, freeze=False)

    • 解冻所有参数

    • 降低学习率:通常将学习率降低 10 倍(如从 1e-3 降到 1e-4),以免破坏预训练的特征。

    • 目的:让整个网络针对 CIFAR-10 的特征进行微调 (Fine-tuning),进一步提升性能。

      伪代码演示阶段切换逻辑

      if epoch == freeze_epochs:
      print(">>> 解冻所有层,开始全局微调!")
      model = freeze_model(model, freeze=False)
      # 降低学习率,精细调整
      optimizer.param_groups[0]['lr'] = 1e-4


四、 完整训练流程详解

以下是整合了上述所有模块的训练循环核心代码,每一行都有详细注释:

复制代码
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train() # 切换到训练模式(启用 Dropout 和 BatchNorm 更新)
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(loader):
        # 1. 数据迁移到 GPU
        data, target = data.to(device), target.to(device)
        
        # 2. 梯度清零 (标准步骤)
        optimizer.zero_grad()
        
        # 3. 前向传播
        output = model(data)
        
        # 4. 计算损失
        loss = criterion(output, target)
        
        # 5. 反向传播
        # 如果是冻结阶段,只有 fc 层的参数会有梯度
        loss.backward()
        
        # 6. 参数更新
        optimizer.step()
        
        # --- 统计指标 ---
        running_loss += loss.item()
        _, predicted = output.max(1) # 获取预测类别
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        
    return running_loss / len(loader), 100. * correct / total

五、 实验现象与经验总结

在运行 Day 44 的代码时,你会观察到几个有趣的现象:

  1. 起步即巅峰
    • 即使在冻结阶段(前 5 个 Epoch),准确率也能迅速达到 70% 左右。这归功于 ResNet 强大的特征提取能力。
  2. 解冻后的飞跃
    • 第 6 个 Epoch(解冻瞬间),准确率通常会有一个明显的提升,因为卷积层开始适应新数据集的特征分布(如 CIFAR-10 的低分辨率)。
  3. 训练集 vs 测试集准确率倒挂
    • 现象:训练准确率 (Training Acc) 往往低于测试准确率 (Test Acc)。
    • 原因:训练集使用了强力数据增强(裁剪、旋转、变色),模型看到的是"变态"难度的图片;而测试集是"标准"图片。模型就像是在负重训练(训练集),考试(测试集)时自然觉得轻松。
  4. 最终性能
    • 经过 40 个 Epoch 的训练,ResNet18 在 CIFAR-10 上通常能达到 86% - 90% 的准确率,远超普通 CNN 的表现。
相关推荐
AI产品测评官2 小时前
2025年深度观察:技术招聘的“数据孤岛”效应与AI智能体的破局之道
人工智能
Deepoch2 小时前
面向AI算力瓶颈的光电混合计算路径探析
人工智能·光电·deepoc
m0_462605222 小时前
第N9周:seq2seq翻译实战-Pytorch复现-小白版
人工智能·pytorch·python
百***24372 小时前
GPT5.1 vs Gemini 3.0 Pro 全维度对比及快速接入实战
大数据·人工智能·gpt
乾元2 小时前
基于时序数据的异常预测——短期容量与拥塞的提前感知
运维·开发语言·网络·人工智能·python·自动化·运维开发
Elastic 中国社区官方博客2 小时前
Elasticsearch:构建一个 AI 驱动的电子邮件钓鱼检测
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
IT_陈寒2 小时前
Vite 5大优化技巧:让你的构建速度飙升50%,开发者都在偷偷用!
前端·人工智能·后端
l1t2 小时前
利用DeepSeek计算abcde五人排成一队,要使c在ab 之间,有几种排法
人工智能·组合数学·deepseek
阿拉斯攀登2 小时前
电子签名:笔迹特征比对核心算法详解
人工智能·算法·机器学习·电子签名·汉王