[论文日读]Intel新研究揭秘:如何让预训练模型持续学习而不遗忘?全面解读预训练时代的持续学习方法

引言:持续学习的挑战与PTM的兴起

在现实世界的应用中,数据通常以流式的形式出现,这要求学习系统能够随着数据的演化不断吸收新知识。持续学习(Continual Learning, CL)的目标就是实现这一点,在学习新知识的同时克服对旧知识的灾难性遗忘。

传统的CL方法从零开始利用流入的数据构建模型,模型的能力随着流入数据的增加而增强。然而,预训练模型(Pre-trained Models, PTMs)时代的到来改变了已有的模型范式,特别是在利用PTMs强大的表征能力进行CL方面。

本文是一个综述性的文章,涵盖了基于PTM的CL的最新进展。其将现有的方法分为三类,并提供了方法之间的相似性、差异性以及各自的优缺点的比较分析。此外还提供了实验结果,对比了各种最新方法。

论文概览与贡献

1. 论文标题、机构、论文链接和项目地址

  • 论文标题:Continual Learning with Pre-Trained Models: A Survey
  • 机构:National Key Laboratory for Novel Software Technology, Nanjing University; School of Artificial Intelligence, Nanjing University
  • 论文链接:arxiv.org/pdf/2401.16...
  • 项目地址:github.com/sun-hailong...

2. 论文主要贡献

  • 提出了第一个全面的基于预训练模型的持续学习的调查,包括问题定义、基准数据集和评估协议。
  • 对每类中的代表性方法进行了评估,涵盖了七个基准数据集。
  • 强调了预训练模型持续学习当前的挑战和未来方向,旨在揭示未被充分研究的方面。

持续学习(CL)的定义与重要性

1. CL的基本概念

持续学习(Continual Learning,简称CL)是一种旨在模拟人类学习过程的机器学习范式,它关注于在连续到来的数据流中学习新知识的同时,保持对既有知识的记忆 。在现实应用中,数据通常以 的形式出现,这就要求学习系统能够不断适应和演化 。CL的目标是通过拟合一个模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) f(x) </math>f(x),来最小化所有已见任务 的期望风险。在这个过程中,模型应该能够在学习新任务的同时,不遗忘之前学习的任务

2. CL中的灾难性遗忘问题

CL面临的一个关键问题是灾难性遗忘(catastrophic forgetting) ,这是指在学习新知识时,对先前学习的任务的性能显著下降的现象。为了解决这个问题,许多研究致力于在CL中寻找方法来减轻或避免灾难性遗忘。

预训练模型(PTM)时代的CL方法

1. PTM的优势与CL的结合

预训练模型(Pre-trained Models,简称PTMs)是从大规模数据集中学习的模型,它们在多种下游任务中表现出了强大的泛化能力。这些模型通常具有丰富的表示能力,可以作为CL的良好基础。

在PTM时代,基于PTM的CL方法能够利用预训练模型的知识。因此已有的基于预训练模型(PTM)的持续学习(CL)方法通常在预训练的模型基础上进行微调,以适应新的任务,同时尽量减少对旧知识的遗忘。

2. PTM基础:ViT与其在CL中的应用

Vision Transformer(ViT)是一种基于Transformer架构的图像处理模型,它通过将输入图像划分为多个不重叠的块,并将这些块送入Transformer网络来提取特征。

在CL中,ViT是一种代表性的PTM。例如,通过冻结预训练的ViT权重,可以使用轻量级的调整技术(如视觉提示(vesion prompt)调整或适配器(adaptor)学习)快速适应下游任务,同时保持泛化能力。

在基于预训练模型(PTM)的CL方法中,已有多个方法来利用PTM的优势,包括基于提示(prompt)的方法、基于表示的方法和基于模型混合的方法。这些方法利用PTM不同的特性来促进持续学习:

  • 提示基方法通过提示调整来轻量级更新PTM,
  • 基于表示的方法直接利用PTM的泛化能力来构建分类器。
  • 模型混合方法则在学习过程中设计一组模型,并在推理时使用模型合并、模型集成等混合技术来得出最终预测。

基于提示(Prompt)的方法

基于提示(Prompt)的方法 是PTM基于CL的一种方法,它利用了预训练模型(PTM)的强大泛化能力。这些方法通过提示微调(prompt tuning)来轻量级地更新PTM,以适应下游任务,同时保持预训练权重不变,从而减轻遗忘问题。

视觉提示调整(VPT)通过在图像块特征前添加一组可学习的参数(即提示)来实现,这些提示在训练过程中被优化,以编码特定任务的信息。

为了解决单个提示随新任务的顺序优化可能导致的灾难性遗忘问题,一些研究提出了提示池(prompt pool)的概念,它收集一组提示,允许在训练和推理期间进行实例特定的提示。然而,这需要一个合适的提示选择机制。例如,L2P设计了一个键-查询匹配策略,为每个提示分配一个可学习的键,并使用余弦距离选择与查询实例最相似的键。

尽管基于提示的方法在桥接领域差距和编码任务特定知识方面具有许多优势,但也存在一些缺点。例如,提示选择过程可能会收敛到一个点,使得提示选择仅集中在特定子集上。此外,由于键和提示值在学习过程中不断变化,这些参数的更新将抹去先前任务的更新,导致匹配级别和提示级别的遗忘。

基于表示(Representation)的方法

基于表示(Representation)的方法 利用PTM的强大表示能力来分类新任务。例如,SimpleCIL方法通过冻结预训练权重并提取每个类的中心(即原型),然后直接使用这些原型作为分类器权重进行分类。这种方法的优势在于它直观、可解释,并且更新成本低,适合实际应用。

然而,这种方法也有缺点。例如,从不同模型中提取特征以形成类原型时,忽略了模型间的冗余。此外,当下游任务涉及多个领域时,仅在第一阶段适应模型可能不足以跨数据集桥接领域差距。

基于模型混合(Model Mixture)的方法

基于模型混合(Model Mixture)的方法 在连续学习过程中创建一组模型 ,并在推理期间进行模型合并或模型集成。例如,ESN创建了一组基于相同PTM的分类器,并在推理时设计了一个投票策略。另一种方法是模型合并,如LAE通过指数移动平均(EMA)来合并在线模型和离线模型的参数。

模型混合方法的优点包括多样化的决策、强调不同阶段知识的重要性以及最终推理成本不随模型集的增加而增加。然而,这些方法也有缺点,例如设计模型集成需要保存所有历史模型,消耗大量内存缓冲区;而模型合并方法虽然不需要这么大的成本,但合并大型模型的权重也需要大量额外的计算。

实验设计与评估

数据集选择与分割

在实验设计中,选择了CIFAR100、CUB200、ImageNet-R、ImageNet-A、ObjectNet、Omnibenchmark和VTAB等七个基准数据集。这些数据集与ImageNet存在较大的领域差距,增加了CL的难度。数据集分割遵循"B-m, Inc-n"的格式,即第一个数据集包含m个类别,每个后续数据集包含n个类别。在分割之前,我们使用相同的随机种子随机打乱所有类别,以确保公平比较。

训练细节与性能衡量

所有模型都使用PyTorch和Pilot工具箱部署,并采用相同的网络骨架。我们选择在ImageNet21K上预训练的最具代表性的ViT-B/16作为网络骨架。性能衡量采用在第b阶段后的Top-1准确率Ab,以及平均性能¯A作为性能指标。实验结果显示,在典型的CL基准数据集上,几乎所有方法都表现良好,而在与预训练数据集领域差距较大的基准上,一些方法表现出问题。这表明在PTM时代,应该提出更具挑战性的基准数据集。此外,我们还发现,基于表示的方法(例如ADAM和RanPAC)比其他方法(除了DAP)表现更具竞争力,这表明提示基和模型混合方法中的表示可以进一步开发以提高性能。

实验结果与分析

1. 不同方法在各数据集上的表现

实验结果显示,在典型的持续学习基准数据集(如CIFAR100)上,几乎所有方法都表现良好。然而,在与预训练数据集(如ImageNet)存在较大领域差异的数据集(例如ImageNet-A)上,一些方法表现出问题。这表明在预训练模型(PTM)时代,需要提出更具挑战性的基准来作为持续学习的基准。

代表性方法的比较结果表明,基于表示的方法(如ADAM和RanPAC)通常比其他方法(除了DAP,稍后将讨论)表现更好。这表明基于提示和模型混合的方法中的表示可以进一步开发以提高性能。

值得注意的是,简单的基线方法SimpleCIL的性能优于典型的基于提示的方法(如L2P和DualPrompt),这验证了预训练模型的强大表示能力。这意味着更复杂的学习系统并不保证更好的性能,甚至可能在不兼容的模块之间引入噪声。

2. 对比公平性的讨论

从实验结果中可以看出,除了DAP之外,基于提示的方法表现不佳。然而,在DAP中发现了一个可能影响未来比较公平性的致命问题。具体来说,DAP通过特定方程生成特定实例的提示,但该方程中的参数依赖于同一批次的投票。在推理过程中,它将来自同一任务的实例聚集在同一批次中,并为同一批次使用相同的参数生成。换句话说,这等同于直接标注任务标识,简化了难度。当我们将测试批次大小设置为1时,即去除DAP中的批次信息(表示为DAP w/o BI),我们观察到性能急剧下降。DAP w/o BI甚至比典型的基于提示的方法L2P表现更差,验证了核心改进来自批次投票信息。由于机器学习模型应该独立测试,利用此类上下文信息显然会导致不公平的比较。本文旨在指出这种不公平性,并使持续学习比较回归正轨。

未来方向与挑战

1. 预训练大型语言模型的持续学习

在当前由预训练模型(PTM)主导的环境中,大型语言模型(如GPT)的持续学习能力变得越来越重要。这些模型需要适应不断变化的信息,例如全球事件的变化。例如,在2020年选举之后,GPT需要从"当前美国总统是谁?→唐纳德·特朗普"更新为"乔·拜登"。通常,这将需要使用更新的数据集进行全面重训练,因为增量微调可能会导致其他相关知识的覆盖。

2. 计算资源受限下的学习

大型PTM的持续调整通常会产生显著的计算成本。在PTM的背景下,模型的部署不仅限于基于云的环境,还扩展到边缘设备。一个相关的例子是为智能手机个人助理应用程序训练LLM,这需要本地训练和推理。这种情况需要计算效率高的持续学习算法。

3. 新基准的需求与PTM的理论优势

CL的本质是使学习系统能够获得它以前缺乏的知识。然而,鉴于PTM使用的广泛训练数据集(如ImageNet),这些模型很少遇到不熟悉的信息。因此,在PTM的预训练数据集的子集上训练PTM可能是多余的。

总结与展望

本文探讨了预训练模型(PTM)在持续学习(CL)领域的最新进展。通过对现有方法的分类和比较分析,发现PTM基于CL的方法可以分为三大类:

  1. 基于提示(prompt-based)的方法
  2. 基于表示(representation-based)的方法
  3. 基于模型混合(model mixture-based)的方法。

每种方法都有其独特的优势和局限性,但总体而言,它们都在解决持续学习中的灾难性遗忘问题方面取得一定效果。

相关推荐
达柳斯·绍达华·宁几秒前
CNN中的平移不变性和平移等变性
人工智能·神经网络·cnn
技术无疆43 分钟前
【Python】Streamlit:为数据科学与机器学习打造的简易应用框架
开发语言·人工智能·python·深度学习·神经网络·机器学习·数据挖掘
xuehaishijue1 小时前
红外画面空中目标检测系统源码分享
人工智能·目标检测·计算机视觉
羊小猪~~1 小时前
机器学习/数据分析--用通俗语言讲解时间序列自回归(AR)模型,并用其预测天气,拟合度98%+
人工智能·python·机器学习·数据挖掘·数据分析·回归·时序数据库
浊酒南街1 小时前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)2.7-2.8
人工智能·深度学习·神经网络
DuoRuaiMiFa2 小时前
ChatGPT全新功能Canvas上线:开启智能编程与写作新篇章
人工智能·chatgpt
DisonTangor2 小时前
Windows 11将新增基于AI的搜索、生成式填充和其它AI功能
人工智能
soso19682 小时前
【AI自然语言处理应用】通过API调用通义晓蜜CCAI-对话分析AIO应用
人工智能·自然语言·ccai
网安-搬运工2 小时前
RAG再总结之如何使大模型更好使用外部数据:四个不同层级及查询-文档对齐策略
人工智能·自然语言处理·大模型·llm·大语言模型·ai大模型·rag