Mask-guided BERT for Few Shot Text Classification
- 前言
- Abstract
- [1. Introduction](#1. Introduction)
- [2. Related Work](#2. Related Work)
-
- [2.1 Masking model inputs](#2.1 Masking model inputs)
- [2.2 Few-Shot Learning in NLP](#2.2 Few-Shot Learning in NLP)
- [2.3 Contrastive learning](#2.3 Contrastive learning)
- [3. Methodology](#3. Methodology)
-
- [3.1 Problem definition](#3.1 Problem definition)
- [3.2 BERT for sentence classification](#3.2 BERT for sentence classification)
- [3.3 The Mask-BERT framework](#3.3 The Mask-BERT framework)
- [3.4 Selection of anchor samples](#3.4 Selection of anchor samples)
- [3.5 Generation of input masks](#3.5 Generation of input masks)
- [3.6 Objective function](#3.6 Objective function)
- [4. Experiments](#4. Experiments)
-
- [4.1 Datasets](#4.1 Datasets)
- [4.2 Experimental Setup](#4.2 Experimental Setup)
- [4.3 Experimental Results](#4.3 Experimental Results)
- [4.4 Result Analysis](#4.4 Result Analysis)
- [5. Conclusions](#5. Conclusions)
- [6. Limitations](#6. Limitations)
- 阅读总结
前言
一篇应用在小样本设置下文本分类任务的文章,标题Mask-guided特别具有吸引力,其实对于小样本设置,主要的展开方向有三种,第一是数据,第二是模型,第三是算法,数据上可以采用数据增强,模型上可以采用不同架构的模型以尽可能缩小假设空间,算法上可以在损失函数上做文章,本篇文章实际上是数据增强和对比损失的结合,还是有一定的借鉴意义。
Paper: https://arxiv.org/pdf/2302.10447.pdf
Code: 无
Abstract
基于Transformer的语言模型已经在很多任务上表现出色,但是训练时需要大量监督数据,在低资源场景具有挑战性。本文提出Mask-BERT,帮助BERT解决小样本学习问题。其核心在于有选择对文本输入应用掩码操作,从而引导模型专注于具有判别性质的token。此外作者还引入对比学习损失函数,帮助模型更好分离不同类别的文本。
1. Introduction
尽管基于Transformer的模型取得了巨大的成功,但是训练时依赖大量的监督数据,这在很多场景下是难以满足的。很多工作采用模型设计、数据增强和一些特殊的训练策略来应对,如元学习、Prompt等。但是这些工作存在一定的局限性,Prompt需要进行Prompt工程,并且基于Prompt-tuning生成的Prompt缺乏可解释性,元学习方法调参困难。最重要的是,这两种方法在设计和部署上都较为复杂。
在之前的实验中,作者发现语言模型通常面对捷径学习的困扰,即并不是学习文本的语义信息,而是依赖于任务无关的信息。
受神经科学和BERT研究的启发,作者提出了Mask-BERT框架来强化BERT应对小样本学习的能力。本文主要的贡献如下:
- 提出了一个简单的、不同于Prompt和元学习的框架Mask-BERT,增强BERT小样本学习能力。
- 一个新颖的掩码策略,用于过滤输入文本中不相关信息,将模型的注意力引导到判别性质的token上。
- 应用对比学习方法, 实验结果证明其在泛化性能上的有效性。
2. Related Work
2.1 Masking model inputs
掩码思想来源于CV领域,特定掩码可以引导模型筛除任务无关的信息,专注于任务相关的token。这种思想其实和人类阅读时关注度不同不谋而合。
2.2 Few-Shot Learning in NLP
基于Transformer的预训练语言模型的FSL方法可以分为三类:
- 基于Prompt的方法。
- 元学习方法。
- 基于微调的方法。
本文的Mask-BERT就是基于微调的方法,去除了传统的最后一层预训练语言模型输出,替换为任务特定的MLP,通过小样本学习进行微调。
2.3 Contrastive learning
对比学习基于相似性学习策略,广泛应用与视觉表征、图表征和NLP任务。受到先前工作的启发,作者采用对比学习并利用锚样本让同一类别样本紧凑,不同类别样本表示远离。
3. Methodology
3.1 Problem definition
给定数据集 D b = { ( x i , y i ) } i = 1 N b D_b=\{(x_i,y_i)\}^{N_b}{i=1} Db={(xi,yi)}i=1Nb和小样本数据集 D n = { ( x i , y i ) } i = 1 N n D_n=\{(x_i,y_i)\}^{N_n}{i=1} Dn={(xi,yi)}i=1Nn,二者互不相交。本文目标是在基础数据集上预训练,在小样本数据集上得到良好的泛化性能。
3.2 BERT for sentence classification
在BERT上下文中,一个句子被定义为:
KaTeX parse error: Can't use function '' in math mode at position 26: ...S},w_1,...,w_n\]̲ 其中w_{CLS}被视为...
每层BERT都包括一个多头注意力块和MLP。输出可以表示如下:
z l ′ = MHSA ( LN ( z l − 1 ) ) + z l − 1 , l = 1 , ... , L z l = M L P ( L N ( z l ′ ) ) + z l ′ , l = 1 , ... , L \begin{array}{l} z^{l^{\prime}}=\operatorname{MHSA}\left(\operatorname{LN}\left(z^{l-1}\right)\right)+z^{l-1}, l=1, \ldots, L\\ z^{l}=M L P\left(L N\left(z^{l^{\prime}}\right)\right)+z^{l^{\prime}}, l=1, \ldots, L \end{array} zl′=MHSA(LN(zl−1))+zl−1,l=1,...,Lzl=MLP(LN(zl′))+zl′,l=1,...,L
对于文本分类,特殊token最后一层输出通常喂入MLP中做分类预测。
3.3 The Mask-BERT framework
FSL一个核心挑战在于如何高效将先验知识从源域转移到目标域。作者设计Mask-BERT过滤任务无关的输入,并指导模型专注于任务相关的关键token。
上图是模型的结构与算法,整体思路如下:
- 首先对BERT在base数据集上进行微调。
- 接着选取锚样本计算对应的mask。
- 最后在目标数据集和mask后的锚样本下进行微调。
作者只对base数据集进行mask有两点原因:
- 想要充分利用小样本数据集中的信息。
- 小样本数据分布稀疏,难以识别出重要的特征。
3.4 Selection of anchor samples
小样本数据集过小容易过拟合,因此采用从base数据集中采样得到的锚样本,可以提高模型的鲁棒性。选取的样本遵循两个原则:
- 锚样本尽量是中心样本,而不是噪声。
- 锚样本不能包含小样本数据集相关信息。
具体来说,作者使用在base数据集上微调的BERT作为特征提取器,定位每个类别的中心,计算每个base样本距离类别中心的距离 d b d_b db,以及小样本数据集距离 d n d_n dn,选取 K K K个 d b − d n d_b-d_n db−dn值最小的样本作为锚样本。
3.5 Generation of input masks
为了尽可能利用先验知识,作者设计了一个mask机制用于从文本中选取目标相关的文本片段,采用积分梯度的方法,可以计算输入token的贡献度。为了保证语义的连贯性,作者保留对分类任务贡献最大的连续文本片段。
mask操作后得到的文本可以减少源域和目标域之间的距离,不同的mask增加了模型的鲁棒性。
3.6 Objective function
预测部分,将特殊token的最后一层输出喂入全连接层,计算交叉熵损失:
y ^ = W T z C L S L + b L cross = − ∑ d ∈ D n ∪ D b s u b ∑ c = 1 C y d c ln y ^ d c \begin{array}{c} \hat{y}=W^{T} z_{C L S}^{L}+b \\ L_{\text {cross }}=-\sum_{d \in D_{n} \cup D_{b}^{s u b}} \sum_{c=1}^{C} y_{d c} \ln \hat{y}_{d c} \end{array} y^=WTzCLSL+bLcross =−∑d∈Dn∪Dbsub∑c=1Cydclny^dc
为了更好分离出不同类样本,聚合同类样本,作者加入了对比损失,如下所示:
L c t r a = − log ∑ e cos ( z i , z i ′ ) ∑ e cos ( z i , z i ′ ) + ∑ e cos ( z i , z j ) L_{c t r a}=-\log \frac{\sum e^{\cos \left(z_{i}, z_{i^{\prime}}\right)}}{\sum e^{\cos \left(z_{i}, z_{i^{\prime}}\right)}+\sum e^{\cos \left(z_{i}, z_{j}\right)}} Lctra=−log∑ecos(zi,zi′)+∑ecos(zi,zj)∑ecos(zi,zi′)
最后目标损失函数如下:
L t o t a l = L c r o s s + L c t r a L_{total}=L_{cross}+L_{ctra} Ltotal=Lcross+Lctra
4. Experiments
实验部分,作者与三种小样本学习方法进行对比,并执行了消融实验以验证模型各个组成部分的作用。
4.1 Datasets
实验在6个公开的数据集上进行,数据集相关信息见下表:
4.2 Experimental Setup
Mask-BERT将与如下NLP模型进行对比:
- BERT
- FPT-BERT,在BERT基础上进一步预训练。
- Re-init-BERT,重新初始化BERT的顶层。
- CPFT,一个对比学习框架。
- CNN-BERT,应用CNN对BERT的输出进行分类。
- SN-FT,基于度量的元学习方法。
- NSP-BERT,基于提示学习的SOTA方法。
4.3 Experimental Results
实验结果如上表所示。总的来说,Mask-BERT和NSP-BERT在开放数据集上表现出相似的性能,这可能是因为BERT在开放数据集上预训练的结果。Mask-BERT在医学领域数据集上表现最佳,表明模型适合具有挑战的领域。此外,mask比例从0.05---0.85变化,模型性能表现稳定。
4.4 Result Analysis
消融实验结果见下表:
分析上表,有如下结论:
- 加入对比损失函数可以有效提高模型的性能。
- 锚样本可以更好利用源域的知识。
- Mask操作可以引导模型专注于重要的token。
一些中间结果的可视化如下图所示:
可视化表明BERT和Mask-BERT都可以高效分离样本,Mask-BERT可以让类别分布更均匀,避免BERT中出现的样本集群分布。测试集结果表明BERT难以分离不同的类别,而Mask-BERT可以解决这个问题,让同类样本更紧凑。
5. Conclusions
本文提出了一个简单的模块化框架Mask-BERT,旨在提高BERT模型的小样本能力。作者使用mask的锚样本用于引导模型学习重要token信息,并采用对比损失让相同标签样本更紧凑,不同标签样本远离。
6. Limitations
作者只证明了在BERT系列模型上的先进性,并没有和其他先进模型作对比。
阅读总结
作者虽然没强调,但是本质上这是一篇在小样本场景应用数据增强的方法。整篇文章没有特别的创新点,并且标题中的"mask-guided"也特别容易误导读者,让读者以为模型真正在学习不同token的重要性。在我看来,整篇文章有以下几点不足或疑惑:
- 如果mask后的长度太短,目标域文本长度很长,不可能对实验结果没有影响。
- 实验结果没有明显提升,缺少理论的证明如t检验。
- 好奇mask比例发生这么明显的变化,为什么实验结果也很稳定,是没什么作用吗?
- 通过数据增强的方式引导模型更专注于重要token的说法过于牵强,因为这种方法并没有从本质上提高模型专注于重要token的能力。
- 没有和当前最先进的其他架构模型进行比较,如ChatGPT、GPT-4等。
当然了, 对于这样的看似是缝合的工作,不能一棒子打死,其实很多有建树性的工作都是基于前人的工作展开的,只不过不能只是简单的缝合,而要思考怎么让以前的方法在现在的领域可用,怎么才能够针对性去解决问题,拿着锤子找钉子不如先想想钉子放在哪更为合适。