在深度学习领域,从零开始训练一个高性能模型通常需要海量数据(如 ImageNet 的 120 万张图片)和昂贵的计算资源。对于大多数实际应用场景,我们更倾向于使用迁移学习 (Transfer Learning)。
本篇笔记将结合 Day 44 的代码,深入剖析如何利用预训练的 ResNet18 模型,在 CIFAR-10 数据集上实现 86%+ 的高准确率。我们将重点拆解代码实现的每一个细节。
一、 数据准备:为模型提供高质量"燃料"
数据增强是提升模型泛化能力的关键。在迁移学习中,由于模型参数量较大(ResNet18 约 1100 万参数),而在小数据集(CIFAR-10 仅 5 万张训练图)上容易过拟合,因此强力的数据增强尤为重要。
1. 训练集增强策略 (代码详解)
train_transform = transforms.Compose([
# 1. 随机裁剪 (RandomCrop)
# 先在图像四周填充 4 个像素的 0 (padding=4),图像变大 (40x40)
# 然后随机裁剪出 32x32 的区域。
# 作用:让模型学习到物体在不同位置的特征,模拟物体平移。
transforms.RandomCrop(32, padding=4),
# 2. 随机水平翻转 (RandomHorizontalFlip)
# 以 50% 的概率水平翻转图像。
# 作用:模拟物体朝向的变化(如车头朝左或朝右),增加数据多样性。
transforms.RandomHorizontalFlip(),
# 3. 颜色抖动 (ColorJitter)
# 随机调整亮度(brightness)、对比度(contrast)、饱和度(saturation) 和色相(hue)。
# 作用:模拟不同光照条件下的物体,让模型对颜色变化不敏感。
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
# 4. 随机旋转 (RandomRotation)
# 在 -15度 到 +15度 之间随机旋转。
# 作用:模拟拍摄角度的微小偏差。
transforms.RandomRotation(15),
# 5. 转为 Tensor
# 将 PIL Image (0-255) 转换为 Tensor (0.0-1.0),并调整维度顺序 (HWC -> CHW)。
transforms.ToTensor(),
# 6. 标准化 (Normalize)
# 使用 CIFAR-10 数据集的均值和标准差进行归一化:(x - mean) / std
# 作用:加速收敛,使数据分布更符合模型假设。
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
2. 测试集处理
测试集只需进行必要的格式转换和标准化,严禁使用随机增强操作,以确保评估结果的稳定性与真实性。
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
二、 模型构建:ResNet18 的"换头"手术
我们使用 torchvision.models 提供的 ResNet18。由于预训练模型是在 ImageNet(1000 类)上训练的,我们需要修改其输出层(Head)以适配 CIFAR-10(10 类)。
代码实现与解析
from torchvision.models import resnet18
import torch.nn as nn
def create_resnet18(pretrained=True, num_classes=10):
# 1. 加载预训练模型
# pretrained=True: 自动下载并加载在 ImageNet 上训练好的权重。
# 这些权重包含了提取通用视觉特征(边缘、纹理、形状)的能力。
model = resnet18(pretrained=pretrained)
# 2. 修改全连接层 (Head)
# model.fc 是 ResNet 的最后一层全连接层。
# in_features: 获取原全连接层的输入维度(ResNet18 为 512)。
in_features = model.fc.in_features
# 3. 替换为新的全连接层
# 新层初始化时权重是随机的,输出维度设为 num_classes (10)。
# 注意:这一层没有预训练权重,需要从头训练。
model.fc = nn.Linear(in_features, num_classes)
# 4. 转移到 GPU
return model.to(device)
三、 训练策略:冻结与解冻 (Freeze & Unfreeze)
这是迁移学习中最核心的技巧。为了防止新初始化的全连接层(随机权重)产生的巨大梯度破坏预训练好的骨干网络(Backbone),我们通常采用分阶段训练。
1. 冻结控制函数
def freeze_model(model, freeze=True):
"""
控制模型参数的冻结与解冻
freeze=True: 冻结卷积层,只训练全连接层
freeze=False: 解冻所有层,进行全局微调
"""
# 遍历模型的所有参数(权重和偏置)
for name, param in model.named_parameters():
# 我们始终要训练全连接层 (fc),所以只冻结非 fc 层
if 'fc' not in name:
# requires_grad=False 表示该参数不计算梯度,也不会被优化器更新
param.requires_grad = not freeze
# (可选) 打印当前冻结状态,方便调试
frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"当前冻结参数量: {frozen_params}")
return model
2. 分阶段训练逻辑
我们在 train_with_freeze_schedule 函数中实现了这一逻辑:
- 阶段一 (Epoch 0 ~ freeze_epochs-1) :
- 调用
freeze_model(model, freeze=True)。 - 此时,只有全连接层在更新。骨干网络充当一个固定的特征提取器。
- 目的:让全连接层的权重快速收敛到合理范围。
- 调用
- 阶段二 (Epoch >= freeze_epochs) :
-
调用
freeze_model(model, freeze=False)。 -
解冻所有参数。
-
降低学习率:通常将学习率降低 10 倍(如从 1e-3 降到 1e-4),以免破坏预训练的特征。
-
目的:让整个网络针对 CIFAR-10 的特征进行微调 (Fine-tuning),进一步提升性能。
伪代码演示阶段切换逻辑
if epoch == freeze_epochs:
print(">>> 解冻所有层,开始全局微调!")
model = freeze_model(model, freeze=False)
# 降低学习率,精细调整
optimizer.param_groups[0]['lr'] = 1e-4
-
四、 完整训练流程详解
以下是整合了上述所有模块的训练循环核心代码,每一行都有详细注释:
def train_one_epoch(model, loader, criterion, optimizer, device):
model.train() # 切换到训练模式(启用 Dropout 和 BatchNorm 更新)
running_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(loader):
# 1. 数据迁移到 GPU
data, target = data.to(device), target.to(device)
# 2. 梯度清零 (标准步骤)
optimizer.zero_grad()
# 3. 前向传播
output = model(data)
# 4. 计算损失
loss = criterion(output, target)
# 5. 反向传播
# 如果是冻结阶段,只有 fc 层的参数会有梯度
loss.backward()
# 6. 参数更新
optimizer.step()
# --- 统计指标 ---
running_loss += loss.item()
_, predicted = output.max(1) # 获取预测类别
total += target.size(0)
correct += predicted.eq(target).sum().item()
return running_loss / len(loader), 100. * correct / total
五、 实验现象与经验总结
在运行 Day 44 的代码时,你会观察到几个有趣的现象:
- 起步即巅峰 :
- 即使在冻结阶段(前 5 个 Epoch),准确率也能迅速达到 70% 左右。这归功于 ResNet 强大的特征提取能力。
- 解冻后的飞跃 :
- 第 6 个 Epoch(解冻瞬间),准确率通常会有一个明显的提升,因为卷积层开始适应新数据集的特征分布(如 CIFAR-10 的低分辨率)。
- 训练集 vs 测试集准确率倒挂 :
- 现象:训练准确率 (Training Acc) 往往低于测试准确率 (Test Acc)。
- 原因:训练集使用了强力数据增强(裁剪、旋转、变色),模型看到的是"变态"难度的图片;而测试集是"标准"图片。模型就像是在负重训练(训练集),考试(测试集)时自然觉得轻松。
- 最终性能 :
- 经过 40 个 Epoch 的训练,ResNet18 在 CIFAR-10 上通常能达到 86% - 90% 的准确率,远超普通 CNN 的表现。