【Python实现连续学习算法】复现2018年ECCV经典算法RWalk

Python实现连续学习Baseline 及经典算法RWalk

1 连续学习概念及灾难性遗忘

连续学习(Continual Learning)是一种模拟人类学习过程的机器学习方法,它旨在让模型在面对多个任务时能够连续学习,而不会遗忘已学到的知识。然而,大多数深度学习模型在连续学习多个任务时会出现"灾难性遗忘"(Catastrophic Forgetting)现象。灾难性遗忘指模型在学习新任务时会大幅度遗忘之前学到的任务知识,这是因为模型参数在新任务的训练过程中被完全覆盖。

解决灾难性遗忘问题是连续学习研究的核心。目前已有多种方法被提出,包括正则化方法、回放、架构等等的方法,其中EWC(Elastic Weight Consolidation)是一种经典的正则化方法。

2 PermutdMNIST数据集及模型

PermutedMNIST是连续学习领域的一种经典测试数据集。它通过对MNIST数据集中的像素进行随机置换生成不同的任务。每个任务都是一个由置换规则决定的分类问题,但所有任务共享相同的标签空间。

对于模型的选择,通常采用简单的全连接神经网络。网络结构可以包含若干个隐藏层,每个隐藏层具有一定数量的神经元,并使用ReLU作为激活函数。网络的输出层与标签类别数一致。

模型在训练每个任务时需要调整参数,研究灾难性遗忘问题的严重程度,并在引入算法时测试其对连续学习能力的改善效果。

python 复制代码
import random
import torch
from torchvision import datasets
import os
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
# Set seeds
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class PermutedMNIST(datasets.MNIST):
    def __init__(self, root="./data/mnist", train=True, permute_idx=None):
        super(PermutedMNIST, self).__init__(root, train, download=True)
        assert len(permute_idx) == 28 * 28
        if self.train:
            self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                      for img in self.data])
        else:
            self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                      for img in self.data])

    def __getitem__(self, index):
        if self.train:
            img, target = self.data[index], self.train_labels[index]
        else:
            img, target = self.data[index], self.test_labels[index]
        return img.view(1, 28, 28), target

    def get_sample(self, sample_size):
        random.seed(2024)
        sample_idx = random.sample(range(len(self)), sample_size)
        return [img.view(1, 28, 28) for img in self.data[sample_idx]]
def worker_init_fn(worker_id):
    # 确保每个 worker 的随机种子一致
    random.seed(2024 + worker_id)
    np.random.seed(2024 + worker_id)
def get_permute_mnist(num_task, batch_size):
    random.seed(2024)
    train_loader = {}
    test_loader = {}
    root_dir = './data/permuted_mnist'
    os.makedirs(root_dir, exist_ok=True)

    for i in range(num_task):
        permute_idx = list(range(28 * 28))
        random.shuffle(permute_idx)

        train_dataset_path = os.path.join(root_dir, f'train_dataset_{i}.pt')
        test_dataset_path = os.path.join(root_dir, f'test_dataset_{i}.pt')

        if os.path.exists(train_dataset_path) and os.path.exists(test_dataset_path):

            train_dataset = torch.load(train_dataset_path)
            test_dataset = torch.load(test_dataset_path)
        else:
            train_dataset = PermutedMNIST(train=True, permute_idx=permute_idx)
            test_dataset = PermutedMNIST(train=False, permute_idx=permute_idx)
            torch.save(train_dataset, train_dataset_path)
            torch.save(test_dataset, test_dataset_path)

        train_loader[i] = DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                    #  num_workers=1,
                                     worker_init_fn=worker_init_fn,
                                     pin_memory=True)
        test_loader[i] = DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    #  num_workers=1,
                                     worker_init_fn=worker_init_fn,
                                     pin_memory=True)

    return train_loader, test_loader

class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, num_classes_per_task=10, hidden_size=[400, 400, 400]):
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        
        # 初始化类别计数器
        self.total_classes = num_classes_per_task
        self.num_classes_per_task = num_classes_per_task
        
        # 定义网络结构
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.fc_before_last = nn.Linear(hidden_size[1], hidden_size[2])
        
        self.fc_out = nn.Linear(hidden_size[2], self.total_classes)
    
    def forward(self, input, task_id=-1):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc_before_last(x))
        x = self.fc_out(x)
        return x

3 Baseline代码

没有任何连续学习算法的Baseline代码实现仅仅是将任务逐个训练。具体过程为:依次加载每个任务的数据集,独立训练模型,而不考虑模型对前一个任务的记忆能力。

python 复制代码
class Baseline:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2, neurons=0):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28

        # Initialize model
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()


        # Get dataset
        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)
    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                # Move data to GPU in batches
                images = images.view(-1,self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        return 100.0 * correct / total


    def train_task(self, train_loader,optimizer, task_id):
        self.model.train()
        for images, labels in train_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def run(self):
        all_avg_acc = []
        
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader,optimizer, task_id)
            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)

            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
            all_avg_acc.append(mean_avg)
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc},AVG = {avg_acc}")

if __name__ == '__main__':
    print('Baseline'+"=" * 50)
    random.seed(2024)
    torch.manual_seed(2024)
    np.random.seed(2024)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    baseline = Baseline(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
    baseline.run()

Baseline==================================================

Task 0: Task Acc = [96.78],AVG=96.78

Task 1: Task Acc = [85.19, 97.0],AVG=91.1

Task 2: Task Acc = [52.66, 89.14, 97.27],AVG=79.69

Task AVG Acc: [96.78, 91.1, 79.69],AVG = 89.19

可以看到模型在学习新任务后,旧任务的准确率在下降,在学习完Task2后,第一个任务的准确率只有52.66,第二个任务的准确率只有89.14。

4 MAS算法

4.1 算法原理

RWalk算法是一种增量学习框架,它通过结合Fisher信息矩阵和优化路径上参数重要性的累积来平衡对旧任务的记忆保持(避免灾难性遗忘)和新任务的学习能力(减少固执性)。

论文《Chaudhry A, Dokania P K, Ajanthan T, et al. Riemannian walk for incremental learning: Understanding forgetting and intransigence[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 532-547.》Riemannian Walk for Incremental Learning (RWalk) 算法中,计算重要性权重和损失函数的公式如下:

  1. 重要性权重的计算:

    • Fisher 信息矩阵的更新:
      F t θ = α F t θ + ( 1 − α ) F t − 1 θ F_t^\theta = \alpha F_t^\theta + (1 - \alpha) F_{t-1}^\theta Ftθ=αFtθ+(1−α)Ft−1θ

      其中, F t θ F_t^\theta Ftθ 是在第 t t t 次迭代时的 Fisher 信息矩阵, α \alpha α 是一个超参数。

    • 参数重要性得分的累积:
      s t 2 t 1 ( θ i ) = ∑ t = t 1 t 2 Δ L t t + Δ t ( θ i ) 1 2 F t θ i Δ θ i ( t ) 2 + ϵ s_{t_2}^{t_1}(\theta_i) = \sum_{t=t_1}^{t_2} \frac{\Delta L_t^{t+\Delta t}(\theta_i)}{\frac{1}{2} F_t^{\theta_i} \Delta \theta_i(t)^2 + \epsilon} st2t1(θi)=t=t1∑t221FtθiΔθi(t)2+ϵΔLtt+Δt(θi)

      其中, Δ L t t + Δ t ( θ i ) \Delta L_t^{t+\Delta t}(\theta_i) ΔLtt+Δt(θi) 是参数 θ i \theta_i θi 从时间步 t t t 到 t + Δ t t + \Delta t t+Δt 的损失变化, F t θ i F_t^{\theta_i} Ftθi 是第 t t t 次迭代时 θ i \theta_i θi 的 Fisher 信息, Δ θ i ( t ) = θ i ( t + Δ t ) − θ i ( t ) \Delta \theta_i(t) = \theta_i(t + \Delta t) - \theta_i(t) Δθi(t)=θi(t+Δt)−θi(t), ϵ \epsilon ϵ 是一个正的常数。

  2. 损失函数的计算:

    • 最终目标函数 (RWalk):
      L ~ k ( θ ) = L k ( θ ) + λ ∑ i = 1 P ( F k − 1 θ i + s t 0 t k − 1 ( θ i ) ) ( θ i − θ k − 1 i ) 2 \tilde{L}k(\theta) = L_k(\theta) + \lambda \sum{i=1}^P \left( F_{k-1}^{\theta_i} + s_{t_0}^{t_{k-1}}(\theta_i) \right) (\theta_i - \theta_{k-1}^i)^2 L~k(θ)=Lk(θ)+λi=1∑P(Fk−1θi+st0tk−1(θi))(θi−θk−1i)2

    其中, L k ( θ ) L_k(\theta) Lk(θ) 是第 k k k 个任务的损失函数, λ \lambda λ 是一个超参数, F k − 1 θ i F_{k-1}^{\theta_i} Fk−1θi 是第 k − 1 k-1 k−1 个任务结束时 θ i \theta_i θi 的 Fisher 信息, s t 0 t k − 1 ( θ i ) s_{t_0}^{t_{k-1}}(\theta_i) st0tk−1(θi) 是从第 t 0 t_0 t0 次迭代到第 t k − 1 t_{k-1} tk−1 次迭代 θ i \theta_i θi 的重要性得分, θ k − 1 i \theta_{k-1}^i θk−1i 是第 k − 1 k-1 k−1 个任务结束时 θ i \theta_i θi 的值。

4.2 代码实现

python 复制代码
import torch
import torch.nn as nn
import random
import warnings
import numpy as np
import warnings
warnings.filterwarnings("ignore")

# Set seeds
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True  # Enable for GPU efficiency

class RWalk:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2, neurons=0):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28

        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.scaler = torch.cuda.amp.GradScaler()  # Enable mixed precision
        self.importance_dict = {}
        self.previous_params = {}
        self.path_integral = {}

        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)

        self.update_params()

    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images = images.view(-1,self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        return 100.0 * correct / total

    def train_task(self, train_loader,optimizer, task_id):
        self.model.train()

        for images, labels in train_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            if task_id > 0:
                loss = self.rwalk_multi_objective_loss(outputs, labels)
            else:
                loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def compute_importance(self, data_loader, task_id):
        # EWC++ 
        importance_dict = {name: torch.zeros_like(param, device=self.device) for name, param in self.model.named_parameters() if 'task' not in name}
        self.model.eval()

        for images, labels in data_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            self.model.zero_grad()

            outputs = self.model(images, task_id=task_id)
            loss = self.criterion(outputs, labels)
            loss.backward()
            for name, param in self.model.named_parameters():
                if name in importance_dict and param.requires_grad:
                    importance_dict[name] += param.grad ** 2 / len(data_loader)

        # 移动平均更新Fisher Matrix
        for name in importance_dict:
            if name in self.importance_dict:
                self.importance_dict[name] = 0.9 * self.importance_dict[name] + 0.1 * importance_dict[name]
            else:
                self.importance_dict[name] = importance_dict[name]

    def update_path_integral(self):
        # 计算累计重要性
        for name, param in self.model.named_parameters():
            if name in self.path_integral:
                self.path_integral[name] += (param.detach() - self.previous_params[name]) ** 2
            else:
                self.path_integral[name] = (param.detach() - self.previous_params[name]) ** 2

    def update_params(self):
        for name, param in self.model.named_parameters():
            self.previous_params[name] = param.clone().detach()

    def update(self, dataset, task_id):
        self.compute_importance(dataset, task_id)
        self.update_path_integral()
        self.update_params()

    def rwalk_multi_objective_loss(self, outputs, labels, lambda_=100):
        regularization_loss = 0.0
        for name, param in self.model.named_parameters():
            if name in self.importance_dict and name in self.previous_params and name in self.path_integral:
                fisher_importance = self.importance_dict[name]
                path_penalty = self.path_integral[name]
                previous_param = self.previous_params[name]
                regularization_loss += ((fisher_importance + path_penalty) * (param - previous_param).pow(2)).sum()
        loss = self.criterion(outputs, labels)
        total_loss = loss + lambda_ * regularization_loss
        return total_loss

    def run(self):
        all_avg_acc = []
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader,optimizer, task_id)
            self.update(train_loader, task_id)

            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)
            all_avg_acc.append(mean_avg)
            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc}, AVG = {avg_acc}")

if __name__ == '__main__':
    print('RWalk' + "=" * 50)
    for _ in range(1):
        random.seed(2024)
        torch.manual_seed(2024)
        np.random.seed(2024)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        rwalk = RWalk(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
        rwalk.run()

RWalk==================================================

Task 0: Task Acc = [96.78],AVG=96.78

Task 1: Task Acc = [94.91, 95.73],AVG=95.32

Task 2: Task Acc = [86.88, 89.66, 93.76],AVG=90.1

Task AVG Acc: [96.78, 95.32, 90.1], AVG = 94.06666666666666

在学习完每个任务后,旧任务的准确率只是轻微的下降,说明该算法有效的缓解了灾难性遗忘。

相关推荐
明月_清风8 分钟前
FastAPI 从入门到实战:3 分钟构建高性能异步 API
后端·python·fastapi
bellus-15 分钟前
ubuntu26测试win10的ollama大模型性能
python
水木流年追梦16 分钟前
大模型入门-Reward 奖励模型训练
开发语言·python·算法·leetcode·正则表达式
JavaWeb学起来17 分钟前
Python学习教程(六)数据结构List(列表)
数据结构·python·python基础·python教程
liuyunshengsir30 分钟前
PyTorch 动态量化(Dynamic Quantization)
人工智能·pytorch·python
电子云与长程纠缠38 分钟前
UE5制作六边形包裹球体效果
开发语言·python·ue5
DFT计算杂谈1 小时前
KPROJ编译教程
java·前端·python·算法·conda
念恒123061 小时前
Python(循环中断)
开发语言·python
tsfy20032 小时前
Python 处理中文文件名的3个坑(附 Flask 上传解决函数)
开发语言·python·flask·文件上传·中文编码
AI技术控2 小时前
KV Cache 缓存机制的原理和应用:从 Transformer 推理到大模型服务优化
人工智能·python·深度学习·缓存·自然语言处理·transformer