基于提示的少样本语言学习的对比学习方法10.25

基于提示的少样本语言学习的对比学习方法

摘要

GPT-3在使用自然语言提示和上下文学习方面展示出的令人印象深刻的性能,激发了在这一范式下更好地微调中等规模模型的相关工作。沿着这一研究方向,本文提出了一种对比学习 框架,该框架通过对同一类别的输入进行聚类,以提高使用有限示例训练的模型的泛化能力。具体而言,提出了一种监督对比框架 ,该框架通过不同的增强"视图"对同一类别的输入进行聚类,并将来自不同类别的输入进行排斥。通过添加不同的语言提示和上下文演示来创建一个示例的不同"视图"。将对比损失与基于提示的少样本学习中的标准掩码语言建模(MLM)损失(standard masked language modeling (MLM) loss) 相结合,实验结果表明,我们的方法在15个不同的语言任务中可以超越当前最先进的方法。我们的框架对任务或基础模型的要求很少,并且可以应用于许多最新的方法,只需进行少量修改。

对比学习(Contrastive Learning)是一种无监督学习方法,旨在学习数据的表示或特征,使得相似样本在表示空间中更加接近,而不相似样本则更加远离。其核心思想是通过比较样本之间的相似性和差异性来学习有用的表示。

对比学习通常使用正样本对和负样本对来进行训练。正样本对是指来自同一类别或相似性较高的样本对,而负样本对则是来自不同类别或相似性较低的样本对。训练过程中,模型被要求将正样本对的表示靠近,而将负样本对的表示推开。
监督对比框架(Supervised Contrastive framework)是一种结合了监督学习和对比学习的方法。它在对比学习的基础上引入了有监督的标签信息,旨在进一步提高模型的性能和表示学习能力。

在监督对比框架中,与传统的对比学习不同,每个样本都有一个与之关联的监督标签。模型的训练目标是通过对比损失来优化样本对的相似性,同时利用监督标签进行监督信号的引导。
标准掩码语言建模(Standard Masked Language Modeling,MLM)损失是一种用于训练语言模型的损失函数,常用于预训练阶段。

在标准MLM中,输入序列中的一部分词汇会被随机掩盖(通常使用特殊的"掩码"符号表示)。模型的任务是根据上下文中的其他词汇来预测这些被掩盖的词汇。

具体而言,对于输入序列中的每个位置,有一定的概率将其掩盖。然后,模型需要根据上下文中的其他词汇来预测被掩盖的词汇。模型会输出一个概率分布,表示每个词汇在该位置的可能性。标准MLM损失使用交叉熵损失来比较模型的预测分布与真实的被掩盖词汇。

引言

基于提示的微调方法 通过将微调任务形成掩码语言问题,缩小了预训练和微调之间的差距。语言提示是附加到查询输入的文本片段,使模型能够提供更好的预测。例如,通过向语言模型提供"这个故事不值得一读,真的很___",模型会更有可能将空白处填入"糟糕(terrible)"而不是"伟大(great)"。在这里,"真的很__"被称为提示的模板,"糟糕(terrible)"或"伟大(great)"是标签词。最近的LM-BFF表明,在输入中附加演示(例如"这是一部了不起的电影,真的很棒")可以帮助模型更好地理解标签词,从而进一步改善结果。

旨在通过给模型提供特定的提示(prompt)来引导其生成期望的输出。

在传统的微调过程中,通常使用大量的标注数据来调整预训练模型的参数,以适应特定的下游任务。然而,对于某些任务,特别是在数据有限的情况下,收集大量标注数据可能是昂贵或困难的。

Prompt-based fine-tuning方法则提供了一种更有效的方式来解决这个问题。在Prompt-based fine-tuning中,一个特定的提示被添加到输入序列中,以指导模型生成期望的输出。这个提示可以是一个问题、一段描述、或者是一个完整的句子模板,具体取决于任务的需求。通过设计合适的提示,可以引导模型产生与任务相关的输出,而不需要大量的标注数据。
将微调任务形成掩码语言问题(Converting Fine-tuning Task into Masked Language Problem)是一种在微调预训练语言模型时的技术,通过将微调任务转化为掩码语言问题来进行训练。这种方法的基本思想是将微调任务转化为一种掩码语言建模(Masked Language Modeling,MLM)问题。

具体而言,将微调任务形成掩码语言问题的方法是在微调数据中随机选择一部分词汇,并将其用特殊的掩码符号进行替换。然后,模型需要根据上下文中的其他词汇来预测被掩盖的词汇。通过使用掩码语言问题作为微调任务,模型可以在微调过程中学习到更好的语义表示,并适应特定任务的需求。

在这项工作中展示了在特征空间中应用监督对比学习(SupCon)可以在基于提示的少样本语言学习的微调过程中带来益处,前提是进行适当的数据增强。

特征空间(Feature Space)指的是将原始数据映射到的一个高维空间,其中每个维度对应于一个特征或特征表示。在特征空间中,每个样本可以由一组特征向量表示。
数据增强(Data Augmentation)用于通过对原始数据进行一系列变换和扩充来增加训练数据的多样性。

在训练机器学习模型时,通常需要大量的标注数据来训练一个准确和鲁棒的模型。然而,有时候获取大量标注数据可能是困难或昂贵的。这时,数据增强技术可以通过在原始数据上进行一系列变换和扩充,生成新的数据样本,从而扩充训练数据的规模。

数据增强的目的是通过对数据进行合理的变换,使得变换后的数据在保持标签不变的同时,呈现出与原始数据类似但略有差异的特征。这样可以帮助模型更好地泛化和适应各种不同的输入情况。

数据增强是SupCon(监督对比学习)的关键组成部分。虽然存在许多数据增强技术,如Cutmix、Mixup用于计算机视觉,以及EDA、AEDA用于文本,但数据增强仍然具有挑战性。

然而,具有演示的基于提示的少样本学习实际上提供了一种自然的方式来创建单个示例的多个"视图"(增强),即对于一组固定的标签词,可以采样不同的模板和不同的演示来附加到输入文本中(如图1所示)。

这使得能够构建一致而完整的多样化输入文本。通过将SupCon应用于具有非常不同内容但具有相同标签的两个示例输入进行聚类,本文的方法能够在特征空间获得额外的监督,这在只给出少量标记示例时至关重要。

本文的主要贡献包括:

• 用于基于提示的少样本学习的监督对比学习框架。

• 使用提示进行对比学习的有效数据增强方法,适用于基于提示的学习者。

相关工作&背景

少样本学习 通常通过元学习、数据增强等方法来解决。受到GPT-3的上下文学习启发,基于提示的微调最近在自然语言处理领域占据主导地位。Basu等在他们的少样本半监督意图分类中应用了对比学习,使用EDA作为数据增强方法。与Basu等不同,本文的方法适用于基于提示的微调,并且实验证明我们提出的增强方法优于EDA。

监督对比损失 SupCon是对比学习的一种特殊形式,它在特征空间中以类别级别对两个增强批次进行聚类。设 ˜ x 2 k − 1 、˜ x 2 k ˜x_{2k−1}、˜x_{2k} ˜x2k−1、˜x2k为输入批次x~k~的两个增强视图, z 2 k 、 z 2 k − 1 为˜ x 2 k − 1 、˜ x 2 k z_{2k}、z_{2k−1}为˜x_{2k−1}、˜x_{2k} z2k、z2k−1为˜x2k−1、˜x2k的特征。则SupCon损失可以计算为

其中 y k y_k yk是批次 x k x_k xk的标签。

方法

问题表述 在LM-BFF中遵循少样本设置,假设可以访问预训练语言模型M,带有标签空间Y的训练数据集D~train~和测试数据集D~test~。在D~train~中,每个类别只有K = 16个示例。

基于提示和演示的微调 基于提示的方法将分类问题视为掩码语言建模(MLM)问题。它们的输入包括一个句子(sent)和一个掩码模板(temp)

(即,x~prompt~ = sent,temp([mask])),并找到最佳的标记来填充[mask]。这导致了一个MLM损失L~MLM~ = MLM(x~prompt~,y),其中y是与x~prompt~对应的标签词。LM-BFF进一步附加了标签词的演示以改善结果:x~prompt+demo~ = sent~0~,temp~0~([mask]),sent~i~,temp~0~(word~i~),其中word~i~是sent~i~的标签词,而sent~i~是从训练集中采样的。然后,分类损失变为:

在LM-BFF或附录B中可以找到更多的数学公式。

基于语言的监督对比损失 为了在输入文本的多个视图上应用SupCon,首先需要获得文本的两个视图:

其中x~1~与LM-BFF中的x~prompt+demo~相同。采样一个新的模板(temp~j~)、演示(sent~k~)和相应的标签词(word~k~),用它们替换x~1~中的内容,以创建输入x~2~的第二个视图。通过公式(1),可以计算出x~1~和x~2~的SupCon损失。总损失则为:

请参考我们的附录C获取更多的数学细节。

计算开销 算法1中展示了算法。一般来说,本文的方法通过L~total~ = L~MLM~ + L~SupCon~进行学习,而基线方法LM-BFF仅通过LMLM进行学习。学习LSupCon需要进行额外的前向传播和反向传播(在算法1中用蓝色突出显示),这会导致计算成本增加1.5倍。

实验

实验评估数据集和协议 在LM-BFF中研究的15个分类任务上评估,并遵循相同的设置,以便进行公平比较。对比学习算法受益于大批量训练。因此,报告的基线结果使用与相同的大批量大小。

本方法针对每个任务使用单个提示/模板(主要提示)进行预测,并使用一组提示(辅助提示)生成用于对比学习的输入的多个视图。使用的主要提示在附录D中展示。

辅助提示可以是手动设计的,也可以由搜索算法生成。在这项工作中使用LM-BFF项目页面中生成的前20个提示,并从这20个提示中随机选择模板来生成输入的第二个视图。除非另有说明,同时使用随机模板和随机演示来创建对比学习的输入的第二个视图。

15个任务的主要结果

使用RoBERTa-base模型(RoBERTa-large请参见附录E)。将本文提出的方法与带有演示的LM-BFF方法以及不带演示的PET方法进行比较。

表1显示,SupCon损失可以持续提升基线的基于提示的微调方法LM-BFF的性能。引入SupCon损失在QQP任务中最大提升了6.3%,在15个任务中平均提升了2.5%,这可能是由于SupCon学习到的更加泛化的表示。平均而言,我们的模型在更困难的任务上有更大的改进。

要强调的是,基线LM-BFF的输入已经在每次调优迭代中附加了不同的随机采样演示。因此,本方法的改进不能归因于学习方程3中的LMLM时输入的多样性,而是归因于LSupCon。表1还显示,即使对于没有演示的基于提示的方法,本方法也能很好地工作。PET是一种没有演示的方法,其性能一直比LM-BFF差。然而,通过额外的SupCon损失,PET的少样本性能平均可以提高2.3%。并且具有和没有演示之间的差距可以大大缩小。

在某些任务中,例如SST-2、SST-5、QNLI、QQP、RTE、MRPC、MR和CR,SupCon损失对性能的贡献甚至可能大于仅使用演示的标签词。

SupCon vs. other losses

进一步展示了本文的方法优于两种最新的旨在改进基于提示的语言模型的方法。在ADAPET中,作者将传统的交叉熵损失替换为基于提示的微调方法PET中的解耦标签损失和标签条件损失,而没有演示。Contextual Calibration通过考虑无上下文的输入(即空格或"N/A")来校准输出概率。(详见附录I)

从表2中可以观察到,在12个任务中,LSupCon优于其他损失函数,而在其他任务中表现相当。Contextual Calibration在整体上并没有取得好的结果。

猜测有两个原因。首先,Contextual Calibration是为没有微调的大型模型(零-shot设置)设计的。其次,Contextual Calibration中的上下文学习形式与这里研究的演示不同。

Ensemblevs. 本模型

集成模型与我们的单一模型相比,本方法使用20个生成的模板(辅助提示)构建输入句子的多个视图。但是,主要预测只使用单个提示(主要提示)和一组标签词。因此,只有一个模型。在这里,将本文提出的模型与由20个单独训练的模型组成的集成模型进行比较。

从表3中,本方法甚至在参数数量增加了20倍的情况下也优于集成模型,这表明本方法更高效地利用了生成的提示。推测由于少样本学习器的过拟合特性,集成模型的成员未能产生实质性的多样化预测分布。

改进与任务难度的关系

在这里展示了在任务难度较高的任务上取得的改进更大。为了证明这一点,首先按照基线(LM-BFF)的性能对15个任务进行排序,并将此排名作为任务难度的代理指标。接下来,报告了在前K个最难任务上的平均改进,其中K从1到15。图2显示了这些结果。第一个柱子表示在最难任务上的改进,第二个柱子表示在最难和次难任务上的平均改进,依此类推。最后一个柱子表示在所有15个任务上的平均改进。

对比实验

相关推荐
奶香臭豆腐27 分钟前
C++ —— 模板类具体化
开发语言·c++·学习
dundunmm1 小时前
机器学习之scikit-learn(简称 sklearn)
python·算法·机器学习·scikit-learn·sklearn·分类算法
古希腊掌管学习的神1 小时前
[机器学习]sklearn入门指南(1)
人工智能·python·算法·机器学习·sklearn
波音彬要多做1 小时前
41 stack类与queue类
开发语言·数据结构·c++·学习·算法
m0_748256782 小时前
WebGIS实战开源项目:智慧机场三维可视化(学习笔记)
笔记·学习·开源
Schwertlilien2 小时前
图像处理-Ch5-图像复原与重建
c语言·开发语言·机器学习
南七澄江3 小时前
各种网站(学习资源及其他)
开发语言·网络·python·深度学习·学习·机器学习·ai
IT古董6 小时前
【漫话机器学习系列】014.贝叶斯法则(Bayes Theorem)
人工智能·机器学习
Crossoads7 小时前
【汇编语言】端口 —— 「从端口到时间:一文了解CMOS RAM与汇编指令的交汇」
android·java·汇编·深度学习·网络协议·机器学习·汇编语言
AAA.建材批发刘哥7 小时前
Linux快速入门-Linux文件系统管理
linux·运维·服务器·c语言·学习方法