【Python学习打卡-Day44】站在巨人的肩膀上:玩转PyTorch预训练模型与迁移学习

📋 前言

各位伙伴们,大家好!如果说之前我们从零开始搭建 CNN 是在"手动锻造兵器",那么今天,我们将学习如何 wield (挥舞) 一把已经由无数大师千锤百炼过的"神器"------预训练模型 。Day 44 的主题是迁移学习 (Transfer Learning),这是一个能让我们以极低的成本,在自己的数据集上达到惊人效果的强大范式。

这不仅仅是一种技术,更是一种思想上的飞跃。它告诉我们,不必事事从零开始,善于利用前人的成果,才能更快地解决新问题。准备好,让我们一起站上巨人的肩膀,看得更远!


一、思想的飞跃:为什么要用预训练模型?

在深入代码之前,我们必须先理解其背后的深刻思想。为什么我们不总是从一个随机初始化的模型开始训练呢?

想象一下:从零训练一个模型,就像教一个刚出生的婴儿认识世界。你需要海量的数据和漫长的时间,他才能慢慢学会识别物体的边缘、纹理、形状,最后才能区分"猫"和"狗"。

而使用预训练模型,则像是请来一位在 ImageNet (一个包含上百万张图片、上千个分类的超大数据集) 上身经百战的"视觉专家"。这位专家已经具备了强大的、通用的视觉特征提取能力(比如识别纹理、轮廓、甚至是动物的眼睛和皮毛)。

我们的任务(比如在 CIFAR-10 上分类)对他来说,就像是给他一个新的、更具体的任务。我们不再需要从零教他看世界,只需要在他已有的知识基础上,进行微调 (Fine-tuning),让他适应我们的新任务即可。

这就是迁移学习的核心:

  • 起点更高:模型初始参数不是随机的,而是包含了大量通用知识,这使得训练起点非常高。
  • 收敛更快:因为起点高,模型能更快地找到解决新问题的方向,大大缩短了训练时间。
  • 性能更好:尤其是在我们自己的数据集不够大的情况下,预训练模型带来的通用知识可以有效防止过拟合,达到更高的准确率。

二、群星闪耀时:认识经典的预训练模型

深度学习的历史长河中,诞生了许多里程碑式的CNN架构。它们不仅在当年的 ImageNet 竞赛中大放异彩,也成为了后来无数研究的基础。

模型 年份 关键创新点 ImageNet Top-5错误率 特点
AlexNet 2012 ReLU激活函数、Dropout、GPU训练 15.3% 深度学习复兴的开创者
VGGNet 2014 统一3×3卷积核,结构简洁优美 7.3% "大力出奇迹"的典范,至今仍是很好的基线
GoogLeNet 2014 Inception模块,网络中的网络 6.7% 在保证精度的同时,极大降低了参数量
ResNet 2015 残差连接,解决了超深网络梯度消失问题 3.57% 影响至今的革命性架构,开启百层网络时代
MobileNet 2017 深度可分离卷积,极致轻量化 7.4% 移动端部署的首选
EfficientNet 2019 复合缩放,自动搜索最佳网络配置 2.6% 精度与效率的极致平衡

这些模型就像一个个武林门派,各有绝学。而我们今天要实战的 ResNet,无疑是其中名声最响、应用最广的"名门正派"。


三、实战演练:用ResNet18微调CIFAR-10分类任务

理论讲完,开练!我们将使用在 ImageNet 上预训练好的 ResNet18 模型,来解决我们的 CIFAR-10 分类问题。

3.1 核心策略:"先冻结,后解冻"

微调预训练模型有一个非常经典的策略,就像驯服一匹烈马:

  1. 第一阶段:冻结主干,只训"头"

    • 冻结 (Freeze) :我们将预训练模型中负责提取特征的卷积层(称为 Backbone 或主干网络)的参数全部冻结,使其在训练中不更新。
    • 训练 (Train) :我们只训练我们自己新加上去的分类层(称为 Head 或分类头)。
    • 目的:这是为了让新的分类头先快速适应主干网络输出的特征,而不会因为随机初始化的分类头产生巨大的梯度,从而破坏掉宝贵的预训练权重。
  2. 第二阶段:解冻全身,整体微调

    • 解冻 (Unfreeze):在分类头训练几轮稳定后,我们将整个网络的参数全部解冻。
    • 微调 (Fine-tune) :使用一个非常小的学习率,对整个网络进行训练。
    • 目的:让整个网络,包括主干,都对我们的新数据进行微小的调整,使其更加"专精"于我们的任务。

3.2 代码实现与分析

下面是本次实战的核心代码,它完美地实现了上述"两阶段"训练策略。

python 复制代码
# 【我的代码】
# Day 44 作业: 使用预训练的ResNet18在CIFAR-10上进行微调

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# --- 配置 ---
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# --- 1. 数据加载与增强 ---
# (与之前代码相同,此处省略)
train_transform = transforms.Compose([...])
test_transform = transforms.Compose([...])
train_dataset = datasets.CIFAR10(...)
test_dataset = datasets.CIFAR10(...)
train_loader = DataLoader(...)
test_loader = DataLoader(...)

# --- 2. 模型定义与修改 ---
def create_resnet18(pretrained=True, num_classes=10):
    # a. 加载预训练的ResNet18模型
    model = models.resnet18(pretrained=pretrained)
    
    # b. 替换最后一层全连接层以适应我们的10分类任务
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    
    return model.to(device)

# --- 3. 冻结/解冻层的工具函数 ---
def freeze_model(model, freeze=True):
    # 冻结/解冻除fc层外的所有参数
    for name, param in model.named_parameters():
        if 'fc' not in name:
            param.requires_grad = not freeze
    
    # 打印状态,方便观察
    if freeze:
        print("已冻结模型卷积层参数 (只训练分类头)")
    else:
        print("已解冻模型所有参数 (进行整体微调)")
    return model

# --- 4. 包含"两阶段"逻辑的训练函数 ---
def train_with_freeze_schedule(model, ..., epochs, freeze_epochs=5):
    # 初始冻结卷积层
    if freeze_epochs > 0:
        model = freeze_model(model, freeze=True)
    
    for epoch in range(epochs):
        # 在指定轮次后解冻所有层
        if epoch == freeze_epochs:
            model = freeze_model(model, freeze=False)
            # 解冻后通常需要降低学习率
            optimizer.param_groups[0]['lr'] = 1e-4  
        
        # ... (标准的训练和验证循环) ...
    # ... (省略具体循环代码,与笔记中一致)
    
# --- 5. 主函数 ---
def main():
    epochs = 40
    freeze_epochs = 5 # 前5轮冻结
    
    model = create_resnet18(pretrained=True, num_classes=10)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(...)
    
    train_with_freeze_schedule(model, ..., epochs=epochs, freeze_epochs=freeze_epochs)

if __name__ == "__main__":
    main()

四、见证奇迹:结果分析与洞察

从训练日志中,我们可以清晰地看到迁移学习的巨大威力:

  1. 火箭般的启动速度 :观察日志,在第5个epoch结束、刚刚解冻所有层 时,训练在第6个epoch的测试准确率瞬间从34%暴涨到67.23%!这正是预训练权重强大通用能力的体现。相比之下,我们从零训练的CNN需要20多个epoch才能达到类似水平。

  2. "谦虚"的训练过程 :在前期,我们可能会观察到训练集准确率低于测试集的现象。这是因为我们的训练集应用了大量的数据增强(随机裁剪、翻转、颜色抖动等),相当于给模型出了"难题";而测试集是标准的"考卷"。这证明了数据增强的有效性,让模型学习得更鲁棒。

  3. 突破性能瓶颈 :最终,我们的模型在测试集上达到了 86.30% 的准确率!这比我们之前从零训练的CNN(约80%)高出了一大截。这就是"巨人肩膀"的力量,让我们轻松地突破了自己模型的性能上限。


五、作业探索:百尺竿头,更进一步

今天的作业是对我们学习成果的最好检验。

作业一:尝试其他预训练模型(如MobileNetV2)

ResNet虽好,但如果我们的应用场景是手机等移动设备,可能就显得有些"重"了。这时,轻量级网络 MobileNetV2 就派上了用场。替换模型非常简单:

python 复制代码
def create_mobilenet_v2(pretrained=True, num_classes=10):
    model = models.mobilenet_v2(pretrained=pretrained)
    
    # 注意:MobileNetV2的分类器层名叫 'classifier'
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    
    return model.to(device)

# 在main函数中调用即可
# model = create_mobilenet_v2(pretrained=True, num_classes=10)

通过尝试不同的模型,我们可以直观地感受到不同架构在收敛速度、最终精度和参数量上的权衡(trade-off),这对于未来做技术选型至关重要。

作业二:深入ResNet内部,理解残差连接

ResNet 的精髓在于残差连接 (Residual Connection),也被称为"快捷连接 (Shortcut Connection)"。

  • 它解决了什么问题?

    在 ResNet 诞生之前,人们发现网络越深,性能反而会下降("退化问题"),且梯度容易消失,导致无法训练。

  • 它如何解决?

    残差连接允许信息"跳过"一层或多层直接流向后方。输入 x 不仅经过了卷积层的变换 F(x),还通过一条"捷径"直接与 F(x) 相加,得到最终输出 H(x) = F(x) + x

    这相当于模型在学习"残差"F(x),即与原始输入的差异。如果某个层学不到任何有用信息,F(x)可以趋近于0,信息x也能无损地传递下去,保证了网络不会因为加深而变差。

  • 如何观察?

    在 VSCode 或 PyCharm 中,我们可以按住 Ctrl 点击 resnet18 进入其源码,可以看到它是由多个 BasicBlock 组成的。而 BasicBlockforward 方法中,清晰地定义了 out = self.conv2(out)out += identity 这两个步骤,这就是残差连接的实现!

🌟 心得与总结

Day 44 是认知上的一次巨大刷新。我深刻地体会到:

  1. 不要重复造轮子:社区已经为我们提供了大量强大、可靠的预训练模型,善用它们是AI工程师的基本素养。
  2. 策略重于蛮力:"先冻结、后微调"的策略看似简单,却蕴含着深刻的训练智慧,它让我们能够平稳、高效地利用预训练知识。
  3. 代码背后是思想 :无论是 ResNet 的残差连接,还是迁移学习的整体范式,都闪耀着计算机科学家们解决问题的智慧之光。理解这些思想,比单纯会调用API重要得多。

今天,我们不仅学会了一项强大的技术,更学会了一种高效解决问题的思维方式。


再次感谢 @浙大疏锦行 老师的精彩课程,带领我们从"炼丹新手"向"高级炼丹师"迈出了坚实的一步!

相关推荐
星河天欲瞩2 小时前
【深度学习Day1】环境配置(CUDA、PyTorch)
人工智能·pytorch·python·深度学习·学习·机器学习·conda
木木木一2 小时前
Rust学习记录--C12 实例:写一个命令行程序
学习·算法·rust
Irene.ll2 小时前
DAY32 官方文档的阅读
python
Pyeako2 小时前
Opencv计算机视觉--轮廓检测&模板匹配
人工智能·python·opencv·计算机视觉·边缘检测·轮廓检测·模板匹配
DBBH2 小时前
DBBH的AI学习笔记
人工智能·笔记·学习
青衫码上行2 小时前
如何构建maven项目
java·学习·maven
Knight_AL2 小时前
Flink 核心算子详解:map / flatMap / filter / process
大数据·python·flink
FJW0208142 小时前
Python推导式与生成器
开发语言·python
June bug2 小时前
【实习笔记】正交实验法设计测试用例
笔记·学习·测试用例