本文目录:
- [一、 什么是迁移学习?](#一、 什么是迁移学习?)
- [二、 为什么需要迁移学习?](#二、 为什么需要迁移学习?)
- [三、 迁移学习的主要方法](#三、 迁移学习的主要方法)
- [四、 什么时候使用迁移学习?](#四、 什么时候使用迁移学习?)
- [五、 一个具体的实例:基于CNN的图像分类](#五、 一个具体的实例:基于CNN的图像分类)
-
- [(一) **选择预训练模型**:](#(一) 选择预训练模型:)
- (二)**模型改造**:
- (三)**训练(微调)**:
- (四)**训练与评估**:
- 总结
前言 :本文主要讲述迁移学习的思想。
一、 什么是迁移学习?
核心思想:迁移学习是一种机器学习技术,其核心在于将一个领域(称为"源领域")中学习到的知识(例如模型参数、特征表示),应用于另一个相关但不同的领域(称为"目标领域"),以提升目标领域任务的学习效率和性能。
一个生动的比喻:学习弹钢琴与学习弹电子琴。
- 源任务:弹钢琴(有大量乐谱和老师)。
- 源领域知识:识谱能力、乐理知识、手指灵活度、节奏感。
- 目标任务:弹电子琴。
- 迁移过程:你不需要从零开始学习电子琴。你已经具备的乐理、识谱能力可以直接应用。你需要学习的新知识主要是电子琴的琴键布局、音色切换等特定操作。
- 结果:你学习弹电子琴的速度远远快于一个完全没有音乐基础的人。
在机器学习中,这个过程就体现为:一个在大型图像数据集(如ImageNet)上预训练好的模型,其学到的"识图能力"(如边缘、纹理、形状等基础特征),可以被用来快速学习一个新的、数据量较小的任务(如识别特定种类的花卉)。
二、 为什么需要迁移学习?
迁移学习的兴起主要源于深度学习面临的几个关键挑战:
数据依赖与数据稀缺:
* 深度学习模型通常是"数据饥渴"的,需要海量标注数据才能达到高性能。
* 但在许多实际应用场景(如医疗影像、工业质检)中,获取大量高质量的标注数据成本极高、非常困难,甚至不可能。迁移学习可以缓解"小数据"困境。
计算资源与时间成本:
* 从零开始训练一个大型深度学习模型(如ResNet, BERT)需要强大的计算资源(多个GPU/TPU)和数天甚至数周的时间。
* 迁移学习可以利用现成的、在大型数据集上预训练好的模型,只需对其进行微调,计算成本和训练时间大大降低。
模型性能与泛化能力:
* 在源领域大数据上学到的模型,通常已经具备了非常通用且强大的特征提取能力。将这些知识迁移到目标领域,往往能获得比从零训练更好的泛化性能,特别是当目标领域数据有限时,可以有效防止过拟合。
三、 迁移学习的主要方法
迁移学习的技术实现多种多样,可以从不同维度进行分类。以下是几种主流的方法:
按技术手段分类:
1. 基于模型的迁移
* **思想**:直接复用预训练模型的全部或部分结构和权重。
* **常见做法**:
* **特征提取器**:将预训练模型(去掉最后的分类层)作为一个固定的特征提取器,然后为新的任务训练一个简单的分类器(如线性层、SVM)。
* **微调**:在特征提取器的基础上,不固定其权重,而是用目标领域的小批量数据,以较小的学习率对整个网络或最后几层进行端到端的再训练。这是目前最常用、最有效的方法。
2. 基于特征的迁移
**思想**:学习一个特征映射函数,将源领域和目标领域的特征映射到同一个特征空间,使得在这个空间里,两个领域的数据分布尽可能相似。
**常见做法**:
* **领域自适应**:通过最小化源域和目标域之间的分布差异(如使用MMD损失、对抗训练)来学习领域不变的特征表示。
* **特征嵌入**:使用自编码器等技术学习数据的低维表示,然后将其用于新任务。
3. 基于关系的迁移
**思想**:假设源领域和目标领域之间共享某种相似的关系逻辑(如图网络中的连接关系),并将这种关系知识进行迁移。
**应用场景**:多见于非IID数据,如社交网络分析、知识图谱等。
按学习情境分类:
归纳式迁移学习
源任务和目标任务不同,但领域可以相同或不同。这是最常见的场景,例如ImageNet预训练模型用于医疗影像分析。
直推式迁移学习
源任务和目标任务相同,但领域不同。例如,从一个电商平台的用户评论情感分析,迁移到另一个电商平台。
无监督迁移学习
源领域和目标领域都没有标签。侧重于学习数据的本质结构特征。
四、 什么时候使用迁移学习?
满足以下条件时,迁移学习通常能取得很好的效果:
- 目标领域数据量小:这是使用迁移学习最典型的场景。
- 源领域与目标领域高度相关:源模型学习的基础特征对目标任务有帮助。例如,用自然图像预训练的模型处理卫星图像是有效的,但用它来处理语音信号则效果不佳。
- 源模型在大规模高质量数据集上预训练过:例如ImageNet(图像)、Wikipedia/BookCorpus(文本)。这样的模型学到的特征泛化能力极强。
五、 一个具体的实例:基于CNN的图像分类
以使用PyTorch框架,将一个在ImageNet上预训练好的ResNet模型,迁移到"猫狗分类"任务为例。
步骤:
(一) 选择预训练模型:
```python
import torch
import torch.nn as nn
from torchvision import models
# 加载预训练的ResNet-18模型,并获取其权重
model = models.resnet18(weights='IMAGENET1K_V1')
```
(二)模型改造:
* ResNet原始输出是1000类(对应ImageNet的1000个类别)。
* 我们的新任务只有2类(猫和狗)。
* 需要替换最后的全连接层。
```python
# 获取全连接层输入特征的维度
num_features = model.fc.in_features
# 用一个新的、未初始化的全连接层替换原来的
# 输出维度改为2
model.fc = nn.Linear(num_features, 2)
```
(三)训练(微调):
* **选项A:仅训练分类器(特征提取)**:冻结所有预训练层的参数,只训练新替换的全连接层。
```python
# 冻结所有网络参数
for param in model.parameters():
param.requires_grad = False
# 只解锁最后一层的参数,使其可训练
for param in model.fc.parameters():
param.requires_grad = True
# 然后配置优化器,只对需要梯度的参数进行更新
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
```
* **选项B:微调全部层**:以较低的学习率训练整个网络,包括预训练部分和新加部分。这是更常用的方法。
```python
# 所有参数默认都是 requires_grad = True
# 使用较小的学习率,避免破坏预训练好的权重
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
```
(四)训练与评估:
* 使用目标领域的数据集(猫狗图片)进行训练。
* 在验证集上评估模型性能。
通过这种方式,我们利用了ResNet在ImageNet上学到的强大视觉特征,只用少量的猫狗图片和较短的训练时间,就能得到一个高性能的猫狗分类器。
总结
特性 | 传统机器学习 | 迁移学习 |
---|---|---|
数据假设 | 训练和测试数据独立同分布 | 源领域和目标领域数据分布可以不同 |
数据量 | 需要足够多的标注数据 | 目标领域只需少量标注数据 |
起点 | 从零开始学习 | 从预训练知识开始 |
效率 | 计算成本高,训练时间长 | 计算成本低,收敛快 |
适用场景 | 大数据、通用任务 | 小数据、垂直领域、快速部署 |
迁移学习已经成为当今人工智能,尤其是深度学习领域的标准实践。它打破了"每个任务都必须从零开始"的范式,让AI技术能够更高效、更普惠地应用到各行各业中,是推动AI落地的关键技术之一。
今日的第一篇分享到此结束。