Day 44 预训练模型与迁移学习

在深度学习领域,从零开始训练一个高性能模型通常需要海量数据(如 ImageNet 的 120 万张图片)和昂贵的计算资源。对于大多数实际应用场景,我们更倾向于使用迁移学习 (Transfer Learning)

本篇笔记将结合 Day 44 的代码,深入剖析如何利用预训练的 ResNet18 模型,在 CIFAR-10 数据集上实现 86%+ 的高准确率。我们将重点拆解代码实现的每一个细节。


一、 数据准备:为模型提供高质量"燃料"

数据增强是提升模型泛化能力的关键。在迁移学习中,由于模型参数量较大(ResNet18 约 1100 万参数),而在小数据集(CIFAR-10 仅 5 万张训练图)上容易过拟合,因此强力的数据增强尤为重要。

1. 训练集增强策略 (代码详解)

复制代码
train_transform = transforms.Compose([
    # 1. 随机裁剪 (RandomCrop)
    # 先在图像四周填充 4 个像素的 0 (padding=4),图像变大 (40x40)
    # 然后随机裁剪出 32x32 的区域。
    # 作用:让模型学习到物体在不同位置的特征,模拟物体平移。
    transforms.RandomCrop(32, padding=4),
    
    # 2. 随机水平翻转 (RandomHorizontalFlip)
    # 以 50% 的概率水平翻转图像。
    # 作用:模拟物体朝向的变化(如车头朝左或朝右),增加数据多样性。
    transforms.RandomHorizontalFlip(),
    
    # 3. 颜色抖动 (ColorJitter)
    # 随机调整亮度(brightness)、对比度(contrast)、饱和度(saturation) 和色相(hue)。
    # 作用:模拟不同光照条件下的物体,让模型对颜色变化不敏感。
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    
    # 4. 随机旋转 (RandomRotation)
    # 在 -15度 到 +15度 之间随机旋转。
    # 作用:模拟拍摄角度的微小偏差。
    transforms.RandomRotation(15),
    
    # 5. 转为 Tensor
    # 将 PIL Image (0-255) 转换为 Tensor (0.0-1.0),并调整维度顺序 (HWC -> CHW)。
    transforms.ToTensor(),
    
    # 6. 标准化 (Normalize)
    # 使用 CIFAR-10 数据集的均值和标准差进行归一化:(x - mean) / std
    # 作用:加速收敛,使数据分布更符合模型假设。
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

2. 测试集处理

测试集只需进行必要的格式转换和标准化,严禁使用随机增强操作,以确保评估结果的稳定性与真实性。

复制代码
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

二、 模型构建:ResNet18 的"换头"手术

我们使用 torchvision.models 提供的 ResNet18。由于预训练模型是在 ImageNet(1000 类)上训练的,我们需要修改其输出层(Head)以适配 CIFAR-10(10 类)。

代码实现与解析

复制代码
from torchvision.models import resnet18
import torch.nn as nn

def create_resnet18(pretrained=True, num_classes=10):
    # 1. 加载预训练模型
    # pretrained=True: 自动下载并加载在 ImageNet 上训练好的权重。
    # 这些权重包含了提取通用视觉特征(边缘、纹理、形状)的能力。
    model = resnet18(pretrained=pretrained)
    
    # 2. 修改全连接层 (Head)
    # model.fc 是 ResNet 的最后一层全连接层。
    # in_features: 获取原全连接层的输入维度(ResNet18 为 512)。
    in_features = model.fc.in_features
    
    # 3. 替换为新的全连接层
    # 新层初始化时权重是随机的,输出维度设为 num_classes (10)。
    # 注意:这一层没有预训练权重,需要从头训练。
    model.fc = nn.Linear(in_features, num_classes)
    
    # 4. 转移到 GPU
    return model.to(device)

三、 训练策略:冻结与解冻 (Freeze & Unfreeze)

这是迁移学习中最核心的技巧。为了防止新初始化的全连接层(随机权重)产生的巨大梯度破坏预训练好的骨干网络(Backbone),我们通常采用分阶段训练

1. 冻结控制函数

复制代码
def freeze_model(model, freeze=True):
    """
    控制模型参数的冻结与解冻
    freeze=True: 冻结卷积层,只训练全连接层
    freeze=False: 解冻所有层,进行全局微调
    """
    # 遍历模型的所有参数(权重和偏置)
    for name, param in model.named_parameters():
        # 我们始终要训练全连接层 (fc),所以只冻结非 fc 层
        if 'fc' not in name:
            # requires_grad=False 表示该参数不计算梯度,也不会被优化器更新
            param.requires_grad = not freeze
            
    # (可选) 打印当前冻结状态,方便调试
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"当前冻结参数量: {frozen_params}")
    return model

2. 分阶段训练逻辑

我们在 train_with_freeze_schedule 函数中实现了这一逻辑:

  • 阶段一 (Epoch 0 ~ freeze_epochs-1)
    • 调用 freeze_model(model, freeze=True)
    • 此时,只有全连接层在更新。骨干网络充当一个固定的特征提取器。
    • 目的:让全连接层的权重快速收敛到合理范围。
  • 阶段二 (Epoch >= freeze_epochs)
    • 调用 freeze_model(model, freeze=False)

    • 解冻所有参数

    • 降低学习率:通常将学习率降低 10 倍(如从 1e-3 降到 1e-4),以免破坏预训练的特征。

    • 目的:让整个网络针对 CIFAR-10 的特征进行微调 (Fine-tuning),进一步提升性能。

      伪代码演示阶段切换逻辑

      if epoch == freeze_epochs:
      print(">>> 解冻所有层,开始全局微调!")
      model = freeze_model(model, freeze=False)
      # 降低学习率,精细调整
      optimizer.param_groups[0]['lr'] = 1e-4


四、 完整训练流程详解

以下是整合了上述所有模块的训练循环核心代码,每一行都有详细注释:

复制代码
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train() # 切换到训练模式(启用 Dropout 和 BatchNorm 更新)
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(loader):
        # 1. 数据迁移到 GPU
        data, target = data.to(device), target.to(device)
        
        # 2. 梯度清零 (标准步骤)
        optimizer.zero_grad()
        
        # 3. 前向传播
        output = model(data)
        
        # 4. 计算损失
        loss = criterion(output, target)
        
        # 5. 反向传播
        # 如果是冻结阶段,只有 fc 层的参数会有梯度
        loss.backward()
        
        # 6. 参数更新
        optimizer.step()
        
        # --- 统计指标 ---
        running_loss += loss.item()
        _, predicted = output.max(1) # 获取预测类别
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        
    return running_loss / len(loader), 100. * correct / total

五、 实验现象与经验总结

在运行 Day 44 的代码时,你会观察到几个有趣的现象:

  1. 起步即巅峰
    • 即使在冻结阶段(前 5 个 Epoch),准确率也能迅速达到 70% 左右。这归功于 ResNet 强大的特征提取能力。
  2. 解冻后的飞跃
    • 第 6 个 Epoch(解冻瞬间),准确率通常会有一个明显的提升,因为卷积层开始适应新数据集的特征分布(如 CIFAR-10 的低分辨率)。
  3. 训练集 vs 测试集准确率倒挂
    • 现象:训练准确率 (Training Acc) 往往低于测试准确率 (Test Acc)。
    • 原因:训练集使用了强力数据增强(裁剪、旋转、变色),模型看到的是"变态"难度的图片;而测试集是"标准"图片。模型就像是在负重训练(训练集),考试(测试集)时自然觉得轻松。
  4. 最终性能
    • 经过 40 个 Epoch 的训练,ResNet18 在 CIFAR-10 上通常能达到 86% - 90% 的准确率,远超普通 CNN 的表现。
相关推荐
NAGNIP9 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab10 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab10 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP13 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年14 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼14 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS14 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区15 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈15 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang16 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx