迁移学习与对抗迁移学习 - 完整学习笔记
目录
1. 背景与动机
1.1 传统机器学习的局限
核心问题 :传统监督学习假设训练数据和测试数据服从相同的分布(i.i.d. 假设)
现实挑战:
- 📊 数据分布偏移:训练集和测试集来自不同场景(如不同医院的医疗影像)
- 💰 标注成本高昂:新任务需要重新收集大量标注数据
- ⏰ 训练时间长:从零开始训练大模型耗时耗力
- 🌍 跨域需求:模型需要在多个相关但不同的场景下工作
1.2 迁移学习的核心思想
定义 :利用在**源域(Source Domain)上学到的知识,帮助解决目标域(Target Domain)**上的问题。
关键公式:
源域: D_s = {(x_s^i, y_s^i)}_{i=1}^{n_s},分布 P_s(X, Y)
目标域: D_t = {(x_t^j, y_t^j)}_{j=1}^{n_t},分布 P_t(X, Y)
挑战: P_s ≠ P_t(分布不一致)
核心假设 :源域和目标域之间存在某种关联性,使得知识可以迁移。
1.3 为什么需要对抗迁移学习?
传统迁移学习方法(如微调)的问题:
- ❌ 域差异显式处理不足:仅通过参数共享,未明确对齐分布
- ❌ 负迁移风险:当域差异过大时,迁移反而降低性能
- ❌ 特征表示域相关:提取的特征仍带有源域特性
对抗迁移学习的创新:
- ✅ 通过对抗训练显式最小化域之间的差异
- ✅ 学习域不变特征(Domain-Invariant Features)
- ✅ 理论上有更强的泛化保证
2. 迁移学习基础
2.1 核心概念
2.1.1 域(Domain)与任务(Task)
域的定义:
Domain D = {X, P(X)}
- X: 特征空间
- P(X): 边缘概率分布
任务的定义:
Task T = {Y, P(Y|X)}
- Y: 标签空间
- P(Y|X): 条件概率分布(学习目标)
迁移学习分类:
| 类型 | 特征空间 | 边缘分布 | 标签空间 | 条件分布 |
|---|---|---|---|---|
| 同构迁移 | X_s = X_t | P_s(X) ≠ P_t(X) | Y_s = Y_t | P_s(Y|X) ≠ P_t(Y|X) |
| 异构迁移 | X_s ≠ X_t | - | Y_s ≠ Y_t | - |
2.1.2 迁移学习的设定
根据目标域标签的可用性:
- 监督迁移学习:目标域有标注数据
- 半监督迁移学习:目标域有少量标注数据
- 无监督迁移学习 :目标域完全无标注(对抗迁移学习的主要场景)
2.2 迁移学习的主要方法
2.2.1 基于实例的迁移(Instance-based)
核心思想:重加权源域样本,使其分布接近目标域
方法:
- 重要性加权 :为每个源域样本 x s i x_s^i xsi 分配权重 w i = P t ( x s i ) P s ( x s i ) w_i = \frac{P_t(x_s^i)}{P_s(x_s^i)} wi=Ps(xsi)Pt(xsi)
- 样本选择:仅使用与目标域相似的源域样本
优点 :简单直观
缺点:需要估计密度比,在高维空间困难
2.2.2 基于特征的迁移(Feature-based)
核心思想 :学习一个共同特征表示,使源域和目标域在该空间中分布相似
典型方法:
- 深度域适应(Deep Domain Adaptation)
- 对抗域适应(Adversarial Domain Adaptation) ← 本笔记重点
优势:
- 适用于深度学习
- 可端到端训练
- 理论基础较强
2.2.3 基于参数的迁移(Parameter-based)
核心思想:在源任务和目标任务间共享模型参数
典型方法:
-
预训练 + 微调(Pre-training & Fine-tuning)
python# 伪代码示例 model = PretrainedModel(weights='imagenet') # 源域预训练 model.classifier = NewClassifier(num_classes_target) model.train(target_data) # 目标域微调 -
多任务学习(Multi-task Learning):同时训练多个相关任务
优点 :工程上最常用,效果稳定
缺点:对域差异大的情况效果有限
2.3 理论基础:域适应的泛化界
Ben-David 等人的理论(2010):
目标域的期望风险上界:
ε_t(h) ≤ ε_s(h) + d_H(D_s, D_t) + λ
其中:
- ε_t(h): 目标域误差
- ε_s(h): 源域误差
- d_H(D_s, D_t): H-散度,衡量域之间的差异
- λ: 理想联合假设的误差(最优情况下的误差)
关键启示:
- 需要在源域上表现好(最小化 ε_s)
- 需要显式减小域差异 (最小化 d_H)← 对抗迁移学习的核心
- 源域和目标域需要有共同的理想假设(λ 较小)
3. 对抗迁移学习
3.1 核心原理
基本思想:借鉴**生成对抗网络(GAN)**的思想,通过对抗训练学习域不变特征。
3.1.1 架构组成
对抗迁移学习通常包含三个组件:
┌─────────────────────────────────────────────────────┐
│ 输入: x_s (源域) / x_t (目标域) │
└─────────────────────┬───────────────────────────────┘
│
▼
┌─────────────────────────┐
│ Feature Extractor (F) │ ← 特征提取器
│ 提取域不变特征 f │
└─────────┬─────────┬─────┘
│ │
┌────────▼──────┐ │
│ Classifier │ │ ← 任务分类器(预测标签)
│ (C) │ │
└───────────────┘ │
│
┌────────▼──────────┐
│ Domain Classifier │ ← 域判别器(区分源域/目标域)
│ (D) │
└───────────────────┘
3.1.2 对抗训练机制
目标函数:
min_F,C max_D L = L_task(F, C) - λ * L_domain(F, D)
其中:
1. L_task: 任务损失(分类交叉熵)
L_task = -E_{(x_s,y_s)~D_s} [log C(F(x_s))]
2. L_domain: 域分类损失
L_domain = -E_{x_s~D_s}[log D(F(x_s))]
-E_{x_t~D_t}[log(1-D(F(x_t)))]
3. λ: 平衡系数
训练过程:
- 特征提取器 F 和任务分类器 C :最小化任务损失,同时欺骗域判别器(使其无法区分源域和目标域)
- 域判别器 D:最大化域分类准确性(正确区分源域和目标域)
直观理解:
- F 学习的特征要"欺骗" D,让 D 无法判断特征来自哪个域
- 当 D 无法区分域时,说明特征已经域不变
3.1.3 梯度反转层(Gradient Reversal Layer, GRL)
实现技巧:通过 GRL 实现对抗训练的"max"操作
python
# PyTorch 伪代码
class GradientReversalLayer(torch.autograd.Function):
@staticmethod
def forward(ctx, x, lambda_):
ctx.lambda_ = lambda_
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
# 反转梯度并乘以系数
return grad_output.neg() * ctx.lambda_, None
# 使用时
features = feature_extractor(x)
reversed_features = GradientReversalLayer.apply(features, lambda_param)
domain_pred = domain_classifier(reversed_features)
工作原理:
- 前向传播:正常传递特征
- 反向传播 :将域判别器的梯度取反后传给特征提取器
- 结果:F 更新方向与 D 相反,实现对抗
3.2 与传统方法的区别
| 特性 | 微调(Fine-tuning) | 对抗迁移学习 |
|---|---|---|
| 域差异处理 | 隐式(通过参数共享) | 显式(对抗训练对齐) |
| 目标域标签 | 需要 | 可无监督 |
| 理论保证 | 弱 | 较强(基于 H-散度) |
| 训练复杂度 | 低 | 中等(需平衡对抗) |
| 适用场景 | 域差异小 | 域差异大 |
4. 经典方法详解
4.1 DANN (Domain-Adversarial Neural Network)
论文 :Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
4.1.1 核心创新
- 首次将 GAN 思想应用于域适应
- 引入梯度反转层(GRL),优雅实现对抗训练
4.1.2 完整算法
网络结构:
Input (x) → Feature Extractor (Gf)
├→ Label Predictor (Gy) → Class Prediction
└→ GRL → Domain Classifier (Gd) → Domain Prediction
损失函数:
L(Gf, Gy, Gd) = L_y(Gf, Gy) - λ * L_d(Gf, Gd)
L_y = -∑_{x_s,y_s} y_s * log(Gy(Gf(x_s))) # 源域分类损失
L_d = -∑_{x_s} log(Gd(Gf(x_s)))
-∑_{x_t} log(1-Gd(Gf(x_t))) # 域分类损失
训练策略:
- λ 从 0 逐渐增加(退火策略):
λ_p = 2/(1+exp(-γ*p)) - 1,其中 p 为训练进度
4.1.3 代码实现(PyTorch)
python
import torch
import torch.nn as nn
class GRL(torch.autograd.Function):
"""梯度反转层"""
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return grad_output.neg() * ctx.alpha, None
class DANN(nn.Module):
def __init__(self, num_classes=10):
super(DANN, self).__init__()
# 特征提取器(示例:简单CNN)
self.feature_extractor = nn.Sequential(
nn.Conv2d(3, 64, 5),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 5),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(128*5*5, 1024),
nn.ReLU()
)
# 标签分类器
self.label_classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Linear(256, num_classes)
)
# 域判别器
self.domain_classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x, alpha=1.0):
features = self.feature_extractor(x)
# 标签预测
class_output = self.label_classifier(features)
# 域预测(使用GRL)
reversed_features = GRL.apply(features, alpha)
domain_output = self.domain_classifier(reversed_features)
return class_output, domain_output
# 训练循环示例
def train_dann(model, source_loader, target_loader, epochs=100):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.BCELoss()
for epoch in range(epochs):
# 计算当前的 lambda (退火)
p = epoch / epochs
alpha = 2. / (1. + np.exp(-10 * p)) - 1
for (x_s, y_s), (x_t, _) in zip(source_loader, target_loader):
# 前向传播
class_s, domain_s = model(x_s, alpha)
_, domain_t = model(x_t, alpha)
# 计算损失
loss_class = class_criterion(class_s, y_s)
loss_domain_s = domain_criterion(domain_s, torch.ones_like(domain_s))
loss_domain_t = domain_criterion(domain_t, torch.zeros_like(domain_t))
loss_domain = loss_domain_s + loss_domain_t
total_loss = loss_class + loss_domain
# 反向传播
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
关键技巧:
- ⚙️ Lambda 调度:避免训练初期域对齐过强导致特征退化
- ⚙️ Batch 平衡:确保每个 batch 同时包含源域和目标域样本
- ⚙️ 学习率调整:域判别器和特征提取器可使用不同学习率
4.2 ADDA (Adversarial Discriminative Domain Adaptation)
论文 :Adversarial Discriminative Domain Adaptation (Tzeng et al., 2017)
4.2.1 与 DANN 的区别
DANN :共享特征提取器,通过 GRL 实现对抗
ADDA :分离源域和目标域的特征提取器,更类似标准 GAN
架构:
阶段1: 预训练
Source Images → Source Encoder → Classifier (使用源域标签训练)
阶段2: 对抗适应
Source Images → Source Encoder (固定) ──┐
├→ Domain Discriminator
Target Images → Target Encoder (优化) ──┘
4.2.2 训练流程
步骤 1:在源域上预训练
python
# 伪代码
source_encoder = Encoder()
classifier = Classifier()
train_on_source_data(source_encoder, classifier, source_data)
步骤 2:初始化目标编码器
python
target_encoder = copy.deepcopy(source_encoder) # 权重初始化
步骤 3:对抗训练(固定 classifier 和 source_encoder)
python
# 目标:让 target_encoder 提取的特征与 source_encoder 相似
for epoch in range(num_epochs):
# 训练判别器 D
feat_s = source_encoder(x_source) # 固定
feat_t = target_encoder(x_target)
loss_D = -log(D(feat_s)) - log(1 - D(feat_t))
# 训练目标编码器(欺骗判别器)
feat_t = target_encoder(x_target)
loss_G = -log(D(feat_t)) # 希望 D 判断为源域
4.2.3 优缺点分析
优点:
- ✅ 更灵活:可为每个域定制编码器架构
- ✅ 训练稳定:交替优化,类似标准 GAN
- ✅ 避免灾难性遗忘:源域编码器固定
缺点:
- ❌ 参数量翻倍(两个编码器)
- ❌ 需要两阶段训练
4.3 其他重要方法
4.3.1 基于距离度量的方法
Deep CORAL (Correlation Alignment):
- 直接最小化源域和目标域特征的二阶统计量差异(协方差)
python
def coral_loss(source_features, target_features):
d = source_features.size(1)
# 计算协方差矩阵
cov_s = torch.mm(source_features.t(), source_features) / (len(source_features) - 1)
cov_t = torch.mm(target_features.t(), target_features) / (len(target_features) - 1)
# Frobenius 范数
loss = torch.norm(cov_s - cov_t, p='fro') ** 2 / (4 * d * d)
return loss
MMD (Maximum Mean Discrepancy):
-
衡量两个分布在 RKHS(再生核希尔伯特空间)中的距离
MMD(D_s, D_t) = ||E[φ(x_s)] - E[φ(x_t)]||²_H
4.3.2 生成式方法
CycleGAN for Domain Adaptation:
- 学习源域到目标域的图像风格转换
- 保持语义内容不变,仅改变域特征(如素描↔照片)
优点 :可视化效果好,适合图像域转换
缺点:计算开销大,需要无配对图像
5. 实践应用
5.1 计算机视觉
5.1.1 经典任务:MNIST → MNIST-M
场景:手写数字(灰度)→ 彩色背景数字
python
# 完整示例(简化版)
import torch
from torchvision import datasets, transforms
# 数据加载
source_dataset = datasets.MNIST(root='./data', train=True,
transform=transforms.ToTensor())
target_dataset = MNIST_M(root='./data', train=True) # 自定义数据集
# 模型训练
model = DANN(num_classes=10)
train_dann(model, source_dataset, target_dataset, epochs=50)
# 在目标域上评估
target_test = MNIST_M(root='./data', train=False)
accuracy = evaluate(model, target_test)
print(f"Target Domain Accuracy: {accuracy:.2%}")
典型结果:
- 无适应:~60% 准确率
- DANN:~85-90% 准确率
5.1.2 跨域物体识别
任务:合成图像(渲染) → 真实图像
| 数据集对 | 源域 | 目标域 | 应用 |
|---|---|---|---|
| Office-31 | Amazon (产品图) | DSLR (相机拍摄) | 物体分类 |
| VisDA | 3D 渲染 | 真实照片 | 物体识别 |
| GTA5 → Cityscapes | 游戏场景 | 真实街景 | 语义分割 |
工程建议:
- 使用预训练 ResNet/ViT 作为 backbone
- 冻结早期层,仅对齐高层特征
- 数据增强(RandomCrop, ColorJitter)提升鲁棒性
5.2 自然语言处理
5.2.1 情感分析跨域
场景:电影评论 → 产品评论
python
# 伪代码示例
from transformers import BertModel
class BertDANN(nn.Module):
def __init__(self):
super().__init__()
self.bert = BertModel.from_pretrained('bert-base-uncased')
self.classifier = nn.Linear(768, 2) # 正面/负面
self.domain_classifier = nn.Linear(768, 1)
def forward(self, input_ids, attention_mask, alpha=1.0):
outputs = self.bert(input_ids, attention_mask=attention_mask)
pooled = outputs.pooler_output # [CLS] token
sentiment = self.classifier(pooled)
reversed = GRL.apply(pooled, alpha)
domain = self.domain_classifier(reversed)
return sentiment, domain
5.2.2 命名实体识别(NER)
挑战:新闻文本 → 社交媒体文本(语言风格差异大)
解决方案:
- 对齐词嵌入空间(如 word2vec)
- 使用对抗训练对齐上下文表示(LSTM/Transformer 输出)
5.3 医疗影像
场景:不同医院/设备的医学影像差异
示例:
- 源域:医院 A 的 CT 扫描(Siemens 设备)
- 目标域:医院 B 的 CT 扫描(GE 设备)
- 任务:肺结节检测
特殊考虑:
- ⚠️ 隐私保护:使用联邦学习 + 域适应
- ⚠️ 小样本:结合 few-shot learning
- ⚠️ 可解释性:添加注意力机制可视化
6. 方法对比与选择
6.1 方法选择决策树
开始
│
├─ 目标域有标签?
│ ├─ 是 → 监督微调(Fine-tuning)
│ └─ 否 ↓
│
├─ 域差异大小?
│ ├─ 小 → 预训练模型 + 少量微调
│ ├─ 中 → DANN / Deep CORAL
│ └─ 大 → ADDA / CycleGAN(图像)
│
├─ 计算资源?
│ ├─ 充足 → 生成式方法(CycleGAN)
│ └─ 有限 → 判别式方法(DANN)
│
└─ 可解释性需求?
├─ 高 → MMD + 特征可视化
└─ 低 → 端到端对抗方法
6.2 性能对比(典型数据集)
Office-31 数据集(Amazon → Webcam)
| 方法 | 准确率 | 训练时间 | 参数量 |
|---|---|---|---|
| 无适应 | 61.3% | - | - |
| Fine-tuning | 68.5% | 1× | 1× |
| Deep CORAL | 74.2% | 1.2× | 1× |
| DANN | 82.0% | 1.5× | 1× |
| ADDA | 86.2% | 2× | 2× |
注:时间和参数量相对于基线模型
6.3 常见问题与解决方案
问题 1:负迁移(Negative Transfer)
现象:迁移后性能低于不迁移
原因:
- 源域和目标域关联性弱
- 对抗训练过强,破坏了任务相关特征
解决:
- 使用多源域迁移(选择最相关的源域)
- 逐步增加对抗强度(lambda 调度)
- 添加域相似性预检测
问题 2:训练不稳定
现象:损失震荡,域判别器准确率过高/过低
解决:
python
# 1. 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 2. 判别器正则化
loss_domain += 0.01 * torch.norm(domain_classifier.weight)
# 3. 使用 Spectral Normalization
from torch.nn.utils import spectral_norm
domain_classifier = spectral_norm(nn.Linear(1024, 1))
问题 3:模式崩溃(Mode Collapse)
现象:所有目标域样本被映射到相似特征
解决:
- 使用 条件对抗 (CDAN):判别器同时考虑特征和预测标签
- 添加 diversity loss
- 使用多个判别器
7. 进阶话题
7.1 理论保证
7.1.1 H-散度与泛化界
定理(Ben-David et al., 2010):
对于假设空间 H,目标域误差满足:
ε_t(h) ≤ ε_s(h) + d_H(D_s, D_t) + C
其中:
- d_H(D_s, D_t) = 2 sup_{h∈H} |P_s(h(x)=1) - P_t(h(x)=1)|
- C = min_{h*∈H} [ε_s(h*) + ε_t(h*)]
启示:
- ✅ 对抗训练直接最小化 d_H
- ✅ 理论上保证了泛化性能
7.1.2 对抗训练的收敛性
挑战:GAN 式训练可能不收敛
理论工作:
- Wasserstein GAN:使用 Wasserstein 距离代替 JS 散度
- 梯度惩罚(Gradient Penalty):稳定训练
7.2 多源域适应
场景:有多个源域 {D_s^1, D_s^2, ..., D_s^k}
方法 1:加权融合
python
# 为每个源域学习权重
weights = attention_network([f_s1, f_s2, ..., f_sk])
combined_feature = sum(w_i * f_i for w_i, f_i in zip(weights, features))
方法 2:Domain-Specific Batch Normalization
python
# 每个域使用独立的 BN 参数
class MultiDomainBN(nn.Module):
def __init__(self, num_domains, num_features):
self.bn_layers = nn.ModuleList([
nn.BatchNorm2d(num_features) for _ in range(num_domains)
])
def forward(self, x, domain_id):
return self.bn_layers[domain_id](x)
7.3 部分域适应(Partial Domain Adaptation)
挑战:目标域类别是源域的子集
示例:
- 源域:ImageNet(1000 类)
- 目标域:Office(31 类)
解决方案:
- 类别权重重加权:降低源域独有类别的权重
- 渐进式对齐:先对齐共享类别
7.4 开放集域适应(Open Set Domain Adaptation)
挑战:目标域包含源域中不存在的类别("未知"类)
方法:
- 引入拒绝选项:对不确定样本输出"未知"
- 使用对抗开放集识别:区分"已知"和"未知"类
7.5 最新进展(2023-2024)
7.5.1 基于 Transformer 的域适应
- Vision Transformer (ViT) + 域适应
- 自注意力机制捕捉长距离依赖,提升域不变性
7.5.2 自监督 + 域适应
- 使用对比学习(SimCLR, MoCo)预训练
- 无标注情况下学习更好的初始表示
7.5.3 持续域适应(Continual Domain Adaptation)
- 在线适应不断变化的目标域
- 防止灾难性遗忘
8. 总结与资源
8.1 核心要点总结
-
迁移学习的本质:利用源域知识解决目标域问题,突破 i.i.d. 假设限制
-
对抗迁移学习的优势:
- ✅ 显式对齐域分布
- ✅ 无监督/半监督场景
- ✅ 理论基础扎实
-
实践建议:
- 先尝试简单方法(微调、CORAL)
- 根据域差异选择方法
- 注意超参数调优(λ 调度最关键)
-
未来方向:
- 大模型时代的域适应
- 多模态迁移学习
- 高效域适应(少样本、轻量化)
8.2 推荐资源
📚 经典论文
-
理论基础:
- A Theory of Learning from Different Domains (Ben-David et al., 2010)
-
对抗方法:
- Domain-Adversarial Training of Neural Networks (Ganin et al., 2016)
- Adversarial Discriminative Domain Adaptation (Tzeng et al., 2017)
-
综述:
- A Survey on Transfer Learning (Pan & Yang, 2010)
- Deep Visual Domain Adaptation: A Survey (Wang & Deng, 2018)
💻 代码库
- Transfer Learning Library:https://github.com/thuml/Transfer-Learning-Library
- DALIB:https://github.com/thuml/Domain-Adaptation-Library
- Awesome Domain Adaptation:https://github.com/zhaoxin94/awesome-domain-adaptation
📊 数据集
| 数据集 | 领域 | 规模 | 任务 |
|---|---|---|---|
| Office-31 | 物体识别 | 4,652 | 31 类物体 |
| Office-Home | 物体识别 | 15,500 | 65 类物体,4 域 |
| VisDA | 物体识别 | 280K | 12 类,合成→真实 |
| DomainNet | 物体识别 | 600K | 345 类,6 域 |
🛠️ 工具与框架
- PyTorch Domain Library:集成常用域适应方法
- MMClassification:包含多种域适应实现
- Detectron2:支持域适应目标检测
附录:常见符号表
| 符号 | 含义 |
|---|---|
| D s , D t D_s, D_t Ds,Dt | 源域、目标域 |
| P s ( X ) , P t ( X ) P_s(X), P_t(X) Ps(X),Pt(X) | 源域、目标域边缘分布 |
| P ( Y ∣ X ) P(Y|X) P(Y∣X) | 条件分布(任务) |
| f = F ( x ) f = F(x) f=F(x) | 特征表示 |
| y ^ = C ( f ) \hat{y} = C(f) y^=C(f) | 分类器输出 |
| d = D ( f ) d = D(f) d=D(f) | 域判别器输出 |
| λ \lambda λ | 对抗损失权重 |
| d H d_H dH | H-散度(域差异度量) |