人工智能【第40篇】终身学习入门:让AI持续进化

作者的话 :在前面的文章中,我们学习了如何让AI在固定数据集 上学习。但在现实世界中,数据是源源不断 产生的------今天识别猫狗,明天要识别鸟类,后天可能还要识别新的动物品种。传统深度学习面临一个致命问题:灾难性遗忘(Catastrophic Forgetting) ------当模型学习新任务时,会迅速遗忘之前学过的任务。人类可以在学习法语时不遗忘英语,但神经网络不行。**终身学习(Lifelong Learning / Continual Learning)**就是要解决这个问题,让AI像人类一样持续学习、终身进化。本文将带你深入理解终身学习的原理与算法!


一、为什么需要终身学习?

1.1 灾难性遗忘:神经网络的致命弱点

现象演示

复制代码
# 实验:先学任务A,再学任务B
# 阶段1:训练MNIST分类(数字0-4)
model.train_on_task_A()  # 准确率:98%

# 阶段2:继续训练MNIST分类(数字5-9)
model.train_on_task_B()  # 任务B准确率:97%

# 测试任务A的准确率
accuracy_A = model.evaluate_task_A()  # 准确率:23%!!!

发生了什么?

  • 学习数字5-9时,网络权重被更新
  • 原来用于识别0-4的特征被覆盖
  • 灾难性遗忘发生!

1.2 人类 vs 神经网络

特性 人类 神经网络(传统训练)
学习新任务 学习法语时不遗忘英语 学习新类别时遗忘旧类别
增量学习 可以不断积累知识 必须重新训练所有数据
存储效率 不需要重看旧书 需要存储所有历史数据
适应性

1.3 终身学习的应用场景

领域 场景 说明
推荐系统 用户兴趣持续变化 新兴趣不能覆盖旧兴趣
自动驾驶 新场景适应 新城市的路况学习
医疗诊断 新疾病识别 学习新病症同时不忘旧病症
机器人 新技能学习 学会开门后仍要会端杯子
NLP 领域适应 法律+医疗+新闻领域

二、灾难性遗忘的原理

2.1 为什么会发生灾难性遗忘?

神经网络的权重共享

复制代码
任务A(识别猫/狗):
  特征提取器参数:θ_shared
  分类器参数:θ_A
  
任务B(识别鸟/鱼):
  特征提取器参数:θ_shared(与任务A共享!)
  分类器参数:θ_B

问题:当训练任务B时,θ_shared被更新,破坏了任务A学到的特征

2.2 遗忘的数学分析

复制代码
损失函数变化:
  L_B(θ):新任务的损失

在任务A上的性能下降:
  L_A(θ_{t+1}) - L_A(θ_t) ≈ ∇_θ L_A(θ_t)^T Δθ

如果 ∇_θ L_A · ∇_θ L_B > 0:梯度方向一致,不会遗忘
如果 ∇_θ L_A · ∇_θ L_B < 0:梯度方向冲突,会发生遗忘

三、终身学习的分类

3.1 四大类方法

复制代码
终身学习方法
├── 正则化方法(Regularization-based)
│   ├── EWC(Elastic Weight Consolidation)
│   ├── SI(Synaptic Intelligence)
│   ├── LwF(Learning without Forgetting)
│   └── MAS(Memory Aware Synapses)
│
├── 回放方法(Replay-based)
│   ├── Experience Replay
│   ├── GEM(Gradient Episodic Memory)
│   ├── A-GEM
│   └── iCaRL
│
├── 动态架构方法(Dynamic Architecture)
│   ├── Progressive Neural Networks
│   ├── PackNet
│   └── HAT
│
└── 参数隔离方法(Parameter Isolation)

3.2 方法对比

方法类别 核心思想 优点 缺点 代表算法
正则化 限制重要参数的变化 存储开销小 可能限制学习能力 EWC, LwF
回放 保存旧样本重放 效果通常最好 需要存储旧数据 GEM, iCaRL
动态架构 为新任务分配新参数 理论上无遗忘 模型不断增长 Progressive NN
参数隔离 不同任务用不同参数 完全无遗忘 参数利用率低 PackNet

四、正则化方法

4.1 EWC:弹性权重巩固

核心思想

保护对旧任务重要的参数,允许不重要的参数变化。

Fisher信息矩阵

复制代码
F_{i,i} = E[(∂log p/∂θ_i)²]

Fisher信息越大,参数对任务越重要。

EWC损失函数

复制代码
L(θ) = L_B(θ) + λ/2 * Σ F_{i,i} (θ_i - θ_{A,i}*)^2

其中:
- L_B(θ):新任务的损失
- F_{i,i}:参数i的Fisher信息
- θ_{A,i}*:任务A训练后的最优参数
- λ:正则化强度

4.2 EWC实现

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

class EWC:
    """Elastic Weight Consolidation"""
    def __init__(self, model, dataset, device='cuda'):
        self.model = model
        self.device = device
        
        # 保存旧任务的最优参数
        self.params = {n: p.clone().detach() 
                      for n, p in model.named_parameters() if p.requires_grad}
        
        # 计算Fisher信息矩阵
        self.fisher = self._compute_fisher(dataset)
    
    def _compute_fisher(self, dataset, num_samples=200):
        """计算Fisher信息矩阵的对角线"""
        fisher = {n: torch.zeros_like(p) 
                 for n, p in self.model.named_parameters() if p.requires_grad}
        
        self.model.eval()
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
        
        for i, (x, y) in enumerate(dataloader):
            if i >= num_samples:
                break
            
            x, y = x.to(self.device), y.to(self.device)
            self.model.zero_grad()
            output = self.model(x)
            log_probs = F.log_softmax(output, dim=1)
            log_likelihood = log_probs.max(1)[0]
            log_likelihood.backward()
            
            for n, p in self.model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.clone().detach() ** 2
        
        for n in fisher:
            fisher[n] /= num_samples
        
        return fisher
    
    def penalty(self, model):
        """计算EWC惩罚项"""
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.fisher:
                _loss = self.fisher[n] * (p - self.params[n]) ** 2
                loss += _loss.sum()
        return loss

五、回放方法

5.1 经验回放

核心思想

保存一部分旧任务的样本,在学习新任务时混合回放。

复制代码
class ExperienceReplay:
    """经验回放"""
    def __init__(self, memory_size=2000):
        self.memory_size = memory_size
        self.buffer = []
    
    def add_samples(self, x, y):
        """添加样本到缓冲区"""
        for i in range(len(x)):
            self.buffer.append((x[i].cpu(), y[i].cpu()))
        
        if len(self.buffer) > self.memory_size:
            self.buffer = self.buffer[-self.memory_size:]
    
    def get_batch(self, batch_size):
        """从缓冲区采样"""
        if len(self.buffer) == 0:
            return None, None
        
        indices = np.random.choice(len(self.buffer), 
                                  min(batch_size, len(self.buffer)), 
                                  replace=False)
        
        x_batch = torch.stack([self.buffer[i][0] for i in indices])
        y_batch = torch.tensor([self.buffer[i][1] for i in indices])
        
        return x_batch, y_batch

5.2 GEM:梯度情景记忆

核心思想

限制梯度更新方向,使其不与旧任务的梯度方向冲突。

复制代码
约束优化问题:
  min_g 1/2 ||g - g_new||^2
  s.t.  g^T g_old^k >= 0, ∀k

其中:
- g_new:新任务的梯度
- g_old^k:旧任务k的梯度
- 约束:新梯度与所有旧梯度的夹角小于90度

六、实战项目:MNIST增量学习

6.1 实验设置

复制代码
任务划分:
  任务1:识别数字 0, 1
  任务2:识别数字 2, 3
  任务3:识别数字 4, 5
  任务4:识别数字 6, 7
  任务5:识别数字 8, 9

目标:依次学习这5个任务,评估在所有任务上的性能

6.2 完整实现

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

class SimpleMLP(nn.Module):
    """简单的多层感知机"""
    def __init__(self, input_size=784, hidden_size=256, num_classes=10):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class ContinualLearningExperiment:
    """终身学习实验框架"""
    def __init__(self, device='cuda'):
        self.device = device
        
        # 加载MNIST数据
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        
        self.train_dataset = datasets.MNIST('./data', train=True, 
                                           download=True, transform=transform)
        self.test_dataset = datasets.MNIST('./data', train=False, 
                                          transform=transform)
        
        # 任务划分
        self.tasks = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
        self.results = {'naive': [], 'ewc': [], 'replay': []}
    
    def get_task_data(self, task_classes, train=True):
        """获取特定任务的数据"""
        dataset = self.train_dataset if train else self.test_dataset
        indices = [idx for idx, (_, label) in enumerate(dataset) 
                  if label in task_classes]
        return Subset(dataset, indices)
    
    def evaluate(self, model, test_loader):
        """评估模型"""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(self.device), y.to(self.device)
                output = model(x)
                pred = output.argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        return correct / total

七、评估指标

指标 公式 说明
平均准确率(ACC) 1/T * Σ a_{T,i} 学习完所有任务后的平均准确率
遗忘率(FOR) 1/(T-1) * Σ (a_{i,i} - a_{T,i}) 旧任务性能的下降程度
后向迁移(BWT) 1/(T(T-1)) * Σ(a_{j,i} - a_{i,i}) 学习新任务对旧任务的帮助

八、总结

8.1 终身学习的核心要点

  1. 核心问题:灾难性遗忘------学习新任务时遗忘旧任务
  2. 主要方法:正则化方法(EWC)、回放方法(GEM)、动态架构、参数隔离
  3. 评估指标:平均准确率、遗忘率、后向迁移、前向迁移

8.2 学习路径总结

复制代码
第33-39篇:学习算法
  ├── 强化学习(DQN、PPO、MARL)
  ├── 模型预测控制(MPC)
  ├── 模仿学习(BC、DAgger)
  └── 元学习(MAML、Prototypical Networks)

第40篇:终身学习(本篇文章)
  ├── 灾难性遗忘问题
  ├── 正则化方法(EWC、LwF)
  ├── 回放方法(GEM、Experience Replay)
  └── 评估指标

未来方向
  └── 通用人工智能(AGI):持续学习、适应、进化

下一篇预告:【第41篇】神经架构搜索NAS:自动设计神经网络

我们将探讨如何让AI自动设计神经网络架构,摆脱手工设计的局限,发现更高效的模型结构!


本文为系列第40篇,详细介绍了终身学习的原理与实战。有任何问题欢迎在评论区交流!

标签:终身学习、Continual Learning、灾难性遗忘、Catastrophic Forgetting、EWC、Experience Replay、增量学习

相关推荐
这张生成的图像能检测吗7 个月前
(论文速读)视觉语言模型的无遗忘学习
人工智能·深度学习·计算机视觉·clip·持续学习·灾难性遗忘
大千AI助手8 个月前
灾难性遗忘:神经网络持续学习的核心挑战与解决方案
人工智能·深度学习·神经网络·大模型·llm·持续学习·灾难性遗忘
Better Bench2 年前
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
知识蒸馏·持续学习·连续学习·终身学习
deardao2 年前
持续学习的综述: 理论、方法与应用(三:泛化分析)
人工智能·深度学习·持续学习·增量学习·终身学习
Better Bench2 年前
【博士每天一篇文献-算法】Fearnet Brain-inspired model for incremental learning
持续学习·灾难性遗忘·脑启发·内存效率·海马体网络·前额叶皮层·外侧杏仁核
Better Bench2 年前
【博士每天一篇文献-综述】A survey on few-shot class-incremental learning
元学习·小样本学习·持续学习·灾难性遗忘·增量学习·过拟合·少量样本增量学习
uncle_ll2 年前
机器学习——终身学习
人工智能·机器学习·ai·llm·终身学习·lll
Better Bench2 年前
【博士每天一篇论文-算法】Continual Learning Through Synaptic Intelligence,SI算法
算法·持续学习·连续学习·终身学习·正则化·突触智能·脑科学
阿航6262 年前
机器人持续学习基准LIBERO系列6——获取并显示实际深度图
迁移学习·持续学习·增量学习·机器人学习·终身学习·libero