作者的话 :在前面的文章中,我们学习了如何让AI在固定数据集 上学习。但在现实世界中,数据是源源不断 产生的------今天识别猫狗,明天要识别鸟类,后天可能还要识别新的动物品种。传统深度学习面临一个致命问题:灾难性遗忘(Catastrophic Forgetting) ------当模型学习新任务时,会迅速遗忘之前学过的任务。人类可以在学习法语时不遗忘英语,但神经网络不行。**终身学习(Lifelong Learning / Continual Learning)**就是要解决这个问题,让AI像人类一样持续学习、终身进化。本文将带你深入理解终身学习的原理与算法!
一、为什么需要终身学习?
1.1 灾难性遗忘:神经网络的致命弱点
现象演示:
# 实验:先学任务A,再学任务B
# 阶段1:训练MNIST分类(数字0-4)
model.train_on_task_A() # 准确率:98%
# 阶段2:继续训练MNIST分类(数字5-9)
model.train_on_task_B() # 任务B准确率:97%
# 测试任务A的准确率
accuracy_A = model.evaluate_task_A() # 准确率:23%!!!
发生了什么?
- 学习数字5-9时,网络权重被更新
- 原来用于识别0-4的特征被覆盖
- 灾难性遗忘发生!
1.2 人类 vs 神经网络
| 特性 | 人类 | 神经网络(传统训练) |
|---|---|---|
| 学习新任务 | 学习法语时不遗忘英语 | 学习新类别时遗忘旧类别 |
| 增量学习 | 可以不断积累知识 | 必须重新训练所有数据 |
| 存储效率 | 不需要重看旧书 | 需要存储所有历史数据 |
| 适应性 | 强 | 弱 |
1.3 终身学习的应用场景
| 领域 | 场景 | 说明 |
|---|---|---|
| 推荐系统 | 用户兴趣持续变化 | 新兴趣不能覆盖旧兴趣 |
| 自动驾驶 | 新场景适应 | 新城市的路况学习 |
| 医疗诊断 | 新疾病识别 | 学习新病症同时不忘旧病症 |
| 机器人 | 新技能学习 | 学会开门后仍要会端杯子 |
| NLP | 领域适应 | 法律+医疗+新闻领域 |
二、灾难性遗忘的原理
2.1 为什么会发生灾难性遗忘?
神经网络的权重共享:
任务A(识别猫/狗):
特征提取器参数:θ_shared
分类器参数:θ_A
任务B(识别鸟/鱼):
特征提取器参数:θ_shared(与任务A共享!)
分类器参数:θ_B
问题:当训练任务B时,θ_shared被更新,破坏了任务A学到的特征
2.2 遗忘的数学分析
损失函数变化:
L_B(θ):新任务的损失
在任务A上的性能下降:
L_A(θ_{t+1}) - L_A(θ_t) ≈ ∇_θ L_A(θ_t)^T Δθ
如果 ∇_θ L_A · ∇_θ L_B > 0:梯度方向一致,不会遗忘
如果 ∇_θ L_A · ∇_θ L_B < 0:梯度方向冲突,会发生遗忘
三、终身学习的分类
3.1 四大类方法
终身学习方法
├── 正则化方法(Regularization-based)
│ ├── EWC(Elastic Weight Consolidation)
│ ├── SI(Synaptic Intelligence)
│ ├── LwF(Learning without Forgetting)
│ └── MAS(Memory Aware Synapses)
│
├── 回放方法(Replay-based)
│ ├── Experience Replay
│ ├── GEM(Gradient Episodic Memory)
│ ├── A-GEM
│ └── iCaRL
│
├── 动态架构方法(Dynamic Architecture)
│ ├── Progressive Neural Networks
│ ├── PackNet
│ └── HAT
│
└── 参数隔离方法(Parameter Isolation)
3.2 方法对比
| 方法类别 | 核心思想 | 优点 | 缺点 | 代表算法 |
|---|---|---|---|---|
| 正则化 | 限制重要参数的变化 | 存储开销小 | 可能限制学习能力 | EWC, LwF |
| 回放 | 保存旧样本重放 | 效果通常最好 | 需要存储旧数据 | GEM, iCaRL |
| 动态架构 | 为新任务分配新参数 | 理论上无遗忘 | 模型不断增长 | Progressive NN |
| 参数隔离 | 不同任务用不同参数 | 完全无遗忘 | 参数利用率低 | PackNet |
四、正则化方法
4.1 EWC:弹性权重巩固
核心思想:
保护对旧任务重要的参数,允许不重要的参数变化。
Fisher信息矩阵:
F_{i,i} = E[(∂log p/∂θ_i)²]
Fisher信息越大,参数对任务越重要。
EWC损失函数:
L(θ) = L_B(θ) + λ/2 * Σ F_{i,i} (θ_i - θ_{A,i}*)^2
其中:
- L_B(θ):新任务的损失
- F_{i,i}:参数i的Fisher信息
- θ_{A,i}*:任务A训练后的最优参数
- λ:正则化强度
4.2 EWC实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
class EWC:
"""Elastic Weight Consolidation"""
def __init__(self, model, dataset, device='cuda'):
self.model = model
self.device = device
# 保存旧任务的最优参数
self.params = {n: p.clone().detach()
for n, p in model.named_parameters() if p.requires_grad}
# 计算Fisher信息矩阵
self.fisher = self._compute_fisher(dataset)
def _compute_fisher(self, dataset, num_samples=200):
"""计算Fisher信息矩阵的对角线"""
fisher = {n: torch.zeros_like(p)
for n, p in self.model.named_parameters() if p.requires_grad}
self.model.eval()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
for i, (x, y) in enumerate(dataloader):
if i >= num_samples:
break
x, y = x.to(self.device), y.to(self.device)
self.model.zero_grad()
output = self.model(x)
log_probs = F.log_softmax(output, dim=1)
log_likelihood = log_probs.max(1)[0]
log_likelihood.backward()
for n, p in self.model.named_parameters():
if p.grad is not None:
fisher[n] += p.grad.clone().detach() ** 2
for n in fisher:
fisher[n] /= num_samples
return fisher
def penalty(self, model):
"""计算EWC惩罚项"""
loss = 0
for n, p in model.named_parameters():
if p.requires_grad and n in self.fisher:
_loss = self.fisher[n] * (p - self.params[n]) ** 2
loss += _loss.sum()
return loss
五、回放方法
5.1 经验回放
核心思想:
保存一部分旧任务的样本,在学习新任务时混合回放。
class ExperienceReplay:
"""经验回放"""
def __init__(self, memory_size=2000):
self.memory_size = memory_size
self.buffer = []
def add_samples(self, x, y):
"""添加样本到缓冲区"""
for i in range(len(x)):
self.buffer.append((x[i].cpu(), y[i].cpu()))
if len(self.buffer) > self.memory_size:
self.buffer = self.buffer[-self.memory_size:]
def get_batch(self, batch_size):
"""从缓冲区采样"""
if len(self.buffer) == 0:
return None, None
indices = np.random.choice(len(self.buffer),
min(batch_size, len(self.buffer)),
replace=False)
x_batch = torch.stack([self.buffer[i][0] for i in indices])
y_batch = torch.tensor([self.buffer[i][1] for i in indices])
return x_batch, y_batch
5.2 GEM:梯度情景记忆
核心思想:
限制梯度更新方向,使其不与旧任务的梯度方向冲突。
约束优化问题:
min_g 1/2 ||g - g_new||^2
s.t. g^T g_old^k >= 0, ∀k
其中:
- g_new:新任务的梯度
- g_old^k:旧任务k的梯度
- 约束:新梯度与所有旧梯度的夹角小于90度
六、实战项目:MNIST增量学习
6.1 实验设置
任务划分:
任务1:识别数字 0, 1
任务2:识别数字 2, 3
任务3:识别数字 4, 5
任务4:识别数字 6, 7
任务5:识别数字 8, 9
目标:依次学习这5个任务,评估在所有任务上的性能
6.2 完整实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
class SimpleMLP(nn.Module):
"""简单的多层感知机"""
def __init__(self, input_size=784, hidden_size=256, num_classes=10):
super(SimpleMLP, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, num_classes)
self.relu = nn.ReLU()
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
class ContinualLearningExperiment:
"""终身学习实验框架"""
def __init__(self, device='cuda'):
self.device = device
# 加载MNIST数据
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
self.train_dataset = datasets.MNIST('./data', train=True,
download=True, transform=transform)
self.test_dataset = datasets.MNIST('./data', train=False,
transform=transform)
# 任务划分
self.tasks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
self.results = {'naive': [], 'ewc': [], 'replay': []}
def get_task_data(self, task_classes, train=True):
"""获取特定任务的数据"""
dataset = self.train_dataset if train else self.test_dataset
indices = [idx for idx, (_, label) in enumerate(dataset)
if label in task_classes]
return Subset(dataset, indices)
def evaluate(self, model, test_loader):
"""评估模型"""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for x, y in test_loader:
x, y = x.to(self.device), y.to(self.device)
output = model(x)
pred = output.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
return correct / total
七、评估指标
| 指标 | 公式 | 说明 |
|---|---|---|
| 平均准确率(ACC) | 1/T * Σ a_{T,i} | 学习完所有任务后的平均准确率 |
| 遗忘率(FOR) | 1/(T-1) * Σ (a_{i,i} - a_{T,i}) | 旧任务性能的下降程度 |
| 后向迁移(BWT) | 1/(T(T-1)) * Σ(a_{j,i} - a_{i,i}) | 学习新任务对旧任务的帮助 |
八、总结
8.1 终身学习的核心要点
- 核心问题:灾难性遗忘------学习新任务时遗忘旧任务
- 主要方法:正则化方法(EWC)、回放方法(GEM)、动态架构、参数隔离
- 评估指标:平均准确率、遗忘率、后向迁移、前向迁移
8.2 学习路径总结
第33-39篇:学习算法
├── 强化学习(DQN、PPO、MARL)
├── 模型预测控制(MPC)
├── 模仿学习(BC、DAgger)
└── 元学习(MAML、Prototypical Networks)
第40篇:终身学习(本篇文章)
├── 灾难性遗忘问题
├── 正则化方法(EWC、LwF)
├── 回放方法(GEM、Experience Replay)
└── 评估指标
未来方向
└── 通用人工智能(AGI):持续学习、适应、进化
下一篇预告:【第41篇】神经架构搜索NAS:自动设计神经网络
我们将探讨如何让AI自动设计神经网络架构,摆脱手工设计的局限,发现更高效的模型结构!
本文为系列第40篇,详细介绍了终身学习的原理与实战。有任何问题欢迎在评论区交流!
标签:终身学习、Continual Learning、灾难性遗忘、Catastrophic Forgetting、EWC、Experience Replay、增量学习