【Python学习打卡-Day44】站在巨人的肩膀上:玩转PyTorch预训练模型与迁移学习

📋 前言

各位伙伴们,大家好!如果说之前我们从零开始搭建 CNN 是在"手动锻造兵器",那么今天,我们将学习如何 wield (挥舞) 一把已经由无数大师千锤百炼过的"神器"------预训练模型 。Day 44 的主题是迁移学习 (Transfer Learning),这是一个能让我们以极低的成本,在自己的数据集上达到惊人效果的强大范式。

这不仅仅是一种技术,更是一种思想上的飞跃。它告诉我们,不必事事从零开始,善于利用前人的成果,才能更快地解决新问题。准备好,让我们一起站上巨人的肩膀,看得更远!


一、思想的飞跃:为什么要用预训练模型?

在深入代码之前,我们必须先理解其背后的深刻思想。为什么我们不总是从一个随机初始化的模型开始训练呢?

想象一下:从零训练一个模型,就像教一个刚出生的婴儿认识世界。你需要海量的数据和漫长的时间,他才能慢慢学会识别物体的边缘、纹理、形状,最后才能区分"猫"和"狗"。

而使用预训练模型,则像是请来一位在 ImageNet (一个包含上百万张图片、上千个分类的超大数据集) 上身经百战的"视觉专家"。这位专家已经具备了强大的、通用的视觉特征提取能力(比如识别纹理、轮廓、甚至是动物的眼睛和皮毛)。

我们的任务(比如在 CIFAR-10 上分类)对他来说,就像是给他一个新的、更具体的任务。我们不再需要从零教他看世界,只需要在他已有的知识基础上,进行微调 (Fine-tuning),让他适应我们的新任务即可。

这就是迁移学习的核心:

  • 起点更高:模型初始参数不是随机的,而是包含了大量通用知识,这使得训练起点非常高。
  • 收敛更快:因为起点高,模型能更快地找到解决新问题的方向,大大缩短了训练时间。
  • 性能更好:尤其是在我们自己的数据集不够大的情况下,预训练模型带来的通用知识可以有效防止过拟合,达到更高的准确率。

二、群星闪耀时:认识经典的预训练模型

深度学习的历史长河中,诞生了许多里程碑式的CNN架构。它们不仅在当年的 ImageNet 竞赛中大放异彩,也成为了后来无数研究的基础。

模型 年份 关键创新点 ImageNet Top-5错误率 特点
AlexNet 2012 ReLU激活函数、Dropout、GPU训练 15.3% 深度学习复兴的开创者
VGGNet 2014 统一3×3卷积核,结构简洁优美 7.3% "大力出奇迹"的典范,至今仍是很好的基线
GoogLeNet 2014 Inception模块,网络中的网络 6.7% 在保证精度的同时,极大降低了参数量
ResNet 2015 残差连接,解决了超深网络梯度消失问题 3.57% 影响至今的革命性架构,开启百层网络时代
MobileNet 2017 深度可分离卷积,极致轻量化 7.4% 移动端部署的首选
EfficientNet 2019 复合缩放,自动搜索最佳网络配置 2.6% 精度与效率的极致平衡

这些模型就像一个个武林门派,各有绝学。而我们今天要实战的 ResNet,无疑是其中名声最响、应用最广的"名门正派"。


三、实战演练:用ResNet18微调CIFAR-10分类任务

理论讲完,开练!我们将使用在 ImageNet 上预训练好的 ResNet18 模型,来解决我们的 CIFAR-10 分类问题。

3.1 核心策略:"先冻结,后解冻"

微调预训练模型有一个非常经典的策略,就像驯服一匹烈马:

  1. 第一阶段:冻结主干,只训"头"

    • 冻结 (Freeze) :我们将预训练模型中负责提取特征的卷积层(称为 Backbone 或主干网络)的参数全部冻结,使其在训练中不更新。
    • 训练 (Train) :我们只训练我们自己新加上去的分类层(称为 Head 或分类头)。
    • 目的:这是为了让新的分类头先快速适应主干网络输出的特征,而不会因为随机初始化的分类头产生巨大的梯度,从而破坏掉宝贵的预训练权重。
  2. 第二阶段:解冻全身,整体微调

    • 解冻 (Unfreeze):在分类头训练几轮稳定后,我们将整个网络的参数全部解冻。
    • 微调 (Fine-tune) :使用一个非常小的学习率,对整个网络进行训练。
    • 目的:让整个网络,包括主干,都对我们的新数据进行微小的调整,使其更加"专精"于我们的任务。

3.2 代码实现与分析

下面是本次实战的核心代码,它完美地实现了上述"两阶段"训练策略。

python 复制代码
# 【我的代码】
# Day 44 作业: 使用预训练的ResNet18在CIFAR-10上进行微调

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# --- 配置 ---
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# --- 1. 数据加载与增强 ---
# (与之前代码相同,此处省略)
train_transform = transforms.Compose([...])
test_transform = transforms.Compose([...])
train_dataset = datasets.CIFAR10(...)
test_dataset = datasets.CIFAR10(...)
train_loader = DataLoader(...)
test_loader = DataLoader(...)

# --- 2. 模型定义与修改 ---
def create_resnet18(pretrained=True, num_classes=10):
    # a. 加载预训练的ResNet18模型
    model = models.resnet18(pretrained=pretrained)
    
    # b. 替换最后一层全连接层以适应我们的10分类任务
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    
    return model.to(device)

# --- 3. 冻结/解冻层的工具函数 ---
def freeze_model(model, freeze=True):
    # 冻结/解冻除fc层外的所有参数
    for name, param in model.named_parameters():
        if 'fc' not in name:
            param.requires_grad = not freeze
    
    # 打印状态,方便观察
    if freeze:
        print("已冻结模型卷积层参数 (只训练分类头)")
    else:
        print("已解冻模型所有参数 (进行整体微调)")
    return model

# --- 4. 包含"两阶段"逻辑的训练函数 ---
def train_with_freeze_schedule(model, ..., epochs, freeze_epochs=5):
    # 初始冻结卷积层
    if freeze_epochs > 0:
        model = freeze_model(model, freeze=True)
    
    for epoch in range(epochs):
        # 在指定轮次后解冻所有层
        if epoch == freeze_epochs:
            model = freeze_model(model, freeze=False)
            # 解冻后通常需要降低学习率
            optimizer.param_groups[0]['lr'] = 1e-4  
        
        # ... (标准的训练和验证循环) ...
    # ... (省略具体循环代码,与笔记中一致)
    
# --- 5. 主函数 ---
def main():
    epochs = 40
    freeze_epochs = 5 # 前5轮冻结
    
    model = create_resnet18(pretrained=True, num_classes=10)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(...)
    
    train_with_freeze_schedule(model, ..., epochs=epochs, freeze_epochs=freeze_epochs)

if __name__ == "__main__":
    main()

四、见证奇迹:结果分析与洞察

从训练日志中,我们可以清晰地看到迁移学习的巨大威力:

  1. 火箭般的启动速度 :观察日志,在第5个epoch结束、刚刚解冻所有层 时,训练在第6个epoch的测试准确率瞬间从34%暴涨到67.23%!这正是预训练权重强大通用能力的体现。相比之下,我们从零训练的CNN需要20多个epoch才能达到类似水平。

  2. "谦虚"的训练过程 :在前期,我们可能会观察到训练集准确率低于测试集的现象。这是因为我们的训练集应用了大量的数据增强(随机裁剪、翻转、颜色抖动等),相当于给模型出了"难题";而测试集是标准的"考卷"。这证明了数据增强的有效性,让模型学习得更鲁棒。

  3. 突破性能瓶颈 :最终,我们的模型在测试集上达到了 86.30% 的准确率!这比我们之前从零训练的CNN(约80%)高出了一大截。这就是"巨人肩膀"的力量,让我们轻松地突破了自己模型的性能上限。


五、作业探索:百尺竿头,更进一步

今天的作业是对我们学习成果的最好检验。

作业一:尝试其他预训练模型(如MobileNetV2)

ResNet虽好,但如果我们的应用场景是手机等移动设备,可能就显得有些"重"了。这时,轻量级网络 MobileNetV2 就派上了用场。替换模型非常简单:

python 复制代码
def create_mobilenet_v2(pretrained=True, num_classes=10):
    model = models.mobilenet_v2(pretrained=pretrained)
    
    # 注意:MobileNetV2的分类器层名叫 'classifier'
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    
    return model.to(device)

# 在main函数中调用即可
# model = create_mobilenet_v2(pretrained=True, num_classes=10)

通过尝试不同的模型,我们可以直观地感受到不同架构在收敛速度、最终精度和参数量上的权衡(trade-off),这对于未来做技术选型至关重要。

作业二:深入ResNet内部,理解残差连接

ResNet 的精髓在于残差连接 (Residual Connection),也被称为"快捷连接 (Shortcut Connection)"。

  • 它解决了什么问题?

    在 ResNet 诞生之前,人们发现网络越深,性能反而会下降("退化问题"),且梯度容易消失,导致无法训练。

  • 它如何解决?

    残差连接允许信息"跳过"一层或多层直接流向后方。输入 x 不仅经过了卷积层的变换 F(x),还通过一条"捷径"直接与 F(x) 相加,得到最终输出 H(x) = F(x) + x

    这相当于模型在学习"残差"F(x),即与原始输入的差异。如果某个层学不到任何有用信息,F(x)可以趋近于0,信息x也能无损地传递下去,保证了网络不会因为加深而变差。

  • 如何观察?

    在 VSCode 或 PyCharm 中,我们可以按住 Ctrl 点击 resnet18 进入其源码,可以看到它是由多个 BasicBlock 组成的。而 BasicBlockforward 方法中,清晰地定义了 out = self.conv2(out)out += identity 这两个步骤,这就是残差连接的实现!

🌟 心得与总结

Day 44 是认知上的一次巨大刷新。我深刻地体会到:

  1. 不要重复造轮子:社区已经为我们提供了大量强大、可靠的预训练模型,善用它们是AI工程师的基本素养。
  2. 策略重于蛮力:"先冻结、后微调"的策略看似简单,却蕴含着深刻的训练智慧,它让我们能够平稳、高效地利用预训练知识。
  3. 代码背后是思想 :无论是 ResNet 的残差连接,还是迁移学习的整体范式,都闪耀着计算机科学家们解决问题的智慧之光。理解这些思想,比单纯会调用API重要得多。

今天,我们不仅学会了一项强大的技术,更学会了一种高效解决问题的思维方式。


再次感谢 @浙大疏锦行 老师的精彩课程,带领我们从"炼丹新手"向"高级炼丹师"迈出了坚实的一步!

相关推荐
Irene199121 小时前
Python 卸载与安装(以卸载3.13.3,装3.13.13为例)
python
予早21 小时前
使用 pyrasite-ng 和 guppy3 做内存分析
python·内存分析
hef2881 天前
如何生成特定SQL的AWR报告_@awrsqrpt.sql深度剖析单条语句性能
jvm·数据库·python
Jinkxs1 天前
从语法纠错到项目重构:Python+Copilot 的全流程开发效率提升指南
python·重构·copilot
技术专家1 天前
Stable Diffusion系列的详细讨论 / Detailed Discussion of the Stable Diffusion Series
人工智能·python·算法·推荐算法·1024程序员节
m0_488913011 天前
万字长文带你梳理Llama开源家族:从Llama-1到Llama-3,看这一篇就够了!
人工智能·学习·机器学习·大模型·产品经理·llama·uml
段一凡-华北理工大学1 天前
【大模型+知识图谱+工业智能体技术架构】~系列文章01:快速了解与初学入门!!!
人工智能·python·架构·知识图谱·工业智能体
IT小Qi1 天前
iperf3网络测试工具
网络·python·测试工具·信息与通信·ip
以神为界1 天前
Python入门实操:基础语法+爬虫入门+模块使用全指南
开发语言·网络·爬虫·python·安全·web
xcjbqd01 天前
Python API怎么加Token认证_JWT生成与验证拦截器实现
jvm·数据库·python