【论文精读】Mask-guided BERT for Few Shot Text Classification

Mask-guided BERT for Few Shot Text Classification

  • 前言
  • Abstract
  • [1. Introduction](#1. Introduction)
  • [2. Related Work](#2. Related Work)
    • [2.1 Masking model inputs](#2.1 Masking model inputs)
    • [2.2 Few-Shot Learning in NLP](#2.2 Few-Shot Learning in NLP)
    • [2.3 Contrastive learning](#2.3 Contrastive learning)
  • [3. Methodology](#3. Methodology)
    • [3.1 Problem definition](#3.1 Problem definition)
    • [3.2 BERT for sentence classification](#3.2 BERT for sentence classification)
    • [3.3 The Mask-BERT framework](#3.3 The Mask-BERT framework)
    • [3.4 Selection of anchor samples](#3.4 Selection of anchor samples)
    • [3.5 Generation of input masks](#3.5 Generation of input masks)
    • [3.6 Objective function](#3.6 Objective function)
  • [4. Experiments](#4. Experiments)
    • [4.1 Datasets](#4.1 Datasets)
    • [4.2 Experimental Setup](#4.2 Experimental Setup)
    • [4.3 Experimental Results](#4.3 Experimental Results)
    • [4.4 Result Analysis](#4.4 Result Analysis)
  • [5. Conclusions](#5. Conclusions)
  • [6. Limitations](#6. Limitations)
  • 阅读总结

前言

一篇应用在小样本设置下文本分类任务的文章,标题Mask-guided特别具有吸引力,其实对于小样本设置,主要的展开方向有三种,第一是数据,第二是模型,第三是算法,数据上可以采用数据增强,模型上可以采用不同架构的模型以尽可能缩小假设空间,算法上可以在损失函数上做文章,本篇文章实际上是数据增强和对比损失的结合,还是有一定的借鉴意义。


Paper: https://arxiv.org/pdf/2302.10447.pdf
Code:

Abstract

基于Transformer的语言模型已经在很多任务上表现出色,但是训练时需要大量监督数据,在低资源场景具有挑战性。本文提出Mask-BERT,帮助BERT解决小样本学习问题。其核心在于有选择对文本输入应用掩码操作,从而引导模型专注于具有判别性质的token。此外作者还引入对比学习损失函数,帮助模型更好分离不同类别的文本。

1. Introduction

尽管基于Transformer的模型取得了巨大的成功,但是训练时依赖大量的监督数据,这在很多场景下是难以满足的。很多工作采用模型设计、数据增强和一些特殊的训练策略来应对,如元学习、Prompt等。但是这些工作存在一定的局限性,Prompt需要进行Prompt工程,并且基于Prompt-tuning生成的Prompt缺乏可解释性,元学习方法调参困难。最重要的是,这两种方法在设计和部署上都较为复杂。

在之前的实验中,作者发现语言模型通常面对捷径学习的困扰,即并不是学习文本的语义信息,而是依赖于任务无关的信息。

受神经科学和BERT研究的启发,作者提出了Mask-BERT框架来强化BERT应对小样本学习的能力。本文主要的贡献如下:

  1. 提出了一个简单的、不同于Prompt和元学习的框架Mask-BERT,增强BERT小样本学习能力。
  2. 一个新颖的掩码策略,用于过滤输入文本中不相关信息,将模型的注意力引导到判别性质的token上。
  3. 应用对比学习方法, 实验结果证明其在泛化性能上的有效性。

2. Related Work

2.1 Masking model inputs

掩码思想来源于CV领域,特定掩码可以引导模型筛除任务无关的信息,专注于任务相关的token。这种思想其实和人类阅读时关注度不同不谋而合。

2.2 Few-Shot Learning in NLP

基于Transformer的预训练语言模型的FSL方法可以分为三类:

  1. 基于Prompt的方法。
  2. 元学习方法。
  3. 基于微调的方法。

本文的Mask-BERT就是基于微调的方法,去除了传统的最后一层预训练语言模型输出,替换为任务特定的MLP,通过小样本学习进行微调。

2.3 Contrastive learning

对比学习基于相似性学习策略,广泛应用与视觉表征、图表征和NLP任务。受到先前工作的启发,作者采用对比学习并利用锚样本让同一类别样本紧凑,不同类别样本表示远离。

3. Methodology

3.1 Problem definition

给定数据集 D b = { ( x i , y i ) } i = 1 N b D_b=\{(x_i,y_i)\}^{N_b}{i=1} Db={(xi,yi)}i=1Nb和小样本数据集 D n = { ( x i , y i ) } i = 1 N n D_n=\{(x_i,y_i)\}^{N_n}{i=1} Dn={(xi,yi)}i=1Nn,二者互不相交。本文目标是在基础数据集上预训练,在小样本数据集上得到良好的泛化性能。

3.2 BERT for sentence classification

在BERT上下文中,一个句子被定义为:
KaTeX parse error: Can't use function '' in math mode at position 26: ...S},w_1,...,w_n\]̲ 其中w_{CLS}被视为...

每层BERT都包括一个多头注意力块和MLP。输出可以表示如下:
z l ′ = MHSA ⁡ ( LN ⁡ ( z l − 1 ) ) + z l − 1 , l = 1 , ... , L z l = M L P ( L N ( z l ′ ) ) + z l ′ , l = 1 , ... , L \begin{array}{l} z^{l^{\prime}}=\operatorname{MHSA}\left(\operatorname{LN}\left(z^{l-1}\right)\right)+z^{l-1}, l=1, \ldots, L\\ z^{l}=M L P\left(L N\left(z^{l^{\prime}}\right)\right)+z^{l^{\prime}}, l=1, \ldots, L \end{array} zl′=MHSA(LN(zl−1))+zl−1,l=1,...,Lzl=MLP(LN(zl′))+zl′,l=1,...,L

对于文本分类,特殊token最后一层输出通常喂入MLP中做分类预测。

3.3 The Mask-BERT framework

FSL一个核心挑战在于如何高效将先验知识从源域转移到目标域。作者设计Mask-BERT过滤任务无关的输入,并指导模型专注于任务相关的关键token。

上图是模型的结构与算法,整体思路如下:

  1. 首先对BERT在base数据集上进行微调。
  2. 接着选取锚样本计算对应的mask。
  3. 最后在目标数据集和mask后的锚样本下进行微调。

作者只对base数据集进行mask有两点原因:

  1. 想要充分利用小样本数据集中的信息。
  2. 小样本数据分布稀疏,难以识别出重要的特征。

3.4 Selection of anchor samples

小样本数据集过小容易过拟合,因此采用从base数据集中采样得到的锚样本,可以提高模型的鲁棒性。选取的样本遵循两个原则:

  1. 锚样本尽量是中心样本,而不是噪声。
  2. 锚样本不能包含小样本数据集相关信息。

具体来说,作者使用在base数据集上微调的BERT作为特征提取器,定位每个类别的中心,计算每个base样本距离类别中心的距离 d b d_b db,以及小样本数据集距离 d n d_n dn,选取 K K K个 d b − d n d_b-d_n db−dn值最小的样本作为锚样本。

3.5 Generation of input masks

为了尽可能利用先验知识,作者设计了一个mask机制用于从文本中选取目标相关的文本片段,采用积分梯度的方法,可以计算输入token的贡献度。为了保证语义的连贯性,作者保留对分类任务贡献最大的连续文本片段。

mask操作后得到的文本可以减少源域和目标域之间的距离,不同的mask增加了模型的鲁棒性。

3.6 Objective function

预测部分,将特殊token的最后一层输出喂入全连接层,计算交叉熵损失:
y ^ = W T z C L S L + b L cross = − ∑ d ∈ D n ∪ D b s u b ∑ c = 1 C y d c ln ⁡ y ^ d c \begin{array}{c} \hat{y}=W^{T} z_{C L S}^{L}+b \\ L_{\text {cross }}=-\sum_{d \in D_{n} \cup D_{b}^{s u b}} \sum_{c=1}^{C} y_{d c} \ln \hat{y}_{d c} \end{array} y^=WTzCLSL+bLcross =−∑d∈Dn∪Dbsub∑c=1Cydclny^dc

为了更好分离出不同类样本,聚合同类样本,作者加入了对比损失,如下所示:
L c t r a = − log ⁡ ∑ e cos ⁡ ( z i , z i ′ ) ∑ e cos ⁡ ( z i , z i ′ ) + ∑ e cos ⁡ ( z i , z j ) L_{c t r a}=-\log \frac{\sum e^{\cos \left(z_{i}, z_{i^{\prime}}\right)}}{\sum e^{\cos \left(z_{i}, z_{i^{\prime}}\right)}+\sum e^{\cos \left(z_{i}, z_{j}\right)}} Lctra=−log∑ecos(zi,zi′)+∑ecos(zi,zj)∑ecos(zi,zi′)

最后目标损失函数如下:
L t o t a l = L c r o s s + L c t r a L_{total}=L_{cross}+L_{ctra} Ltotal=Lcross+Lctra

4. Experiments

实验部分,作者与三种小样本学习方法进行对比,并执行了消融实验以验证模型各个组成部分的作用。

4.1 Datasets

实验在6个公开的数据集上进行,数据集相关信息见下表:

4.2 Experimental Setup

Mask-BERT将与如下NLP模型进行对比:

  • BERT
  • FPT-BERT,在BERT基础上进一步预训练。
  • Re-init-BERT,重新初始化BERT的顶层。
  • CPFT,一个对比学习框架。
  • CNN-BERT,应用CNN对BERT的输出进行分类。
  • SN-FT,基于度量的元学习方法。
  • NSP-BERT,基于提示学习的SOTA方法。

4.3 Experimental Results

实验结果如上表所示。总的来说,Mask-BERT和NSP-BERT在开放数据集上表现出相似的性能,这可能是因为BERT在开放数据集上预训练的结果。Mask-BERT在医学领域数据集上表现最佳,表明模型适合具有挑战的领域。此外,mask比例从0.05---0.85变化,模型性能表现稳定。

4.4 Result Analysis

消融实验结果见下表:

分析上表,有如下结论:

  1. 加入对比损失函数可以有效提高模型的性能。
  2. 锚样本可以更好利用源域的知识。
  3. Mask操作可以引导模型专注于重要的token。

一些中间结果的可视化如下图所示:

可视化表明BERT和Mask-BERT都可以高效分离样本,Mask-BERT可以让类别分布更均匀,避免BERT中出现的样本集群分布。测试集结果表明BERT难以分离不同的类别,而Mask-BERT可以解决这个问题,让同类样本更紧凑。

5. Conclusions

本文提出了一个简单的模块化框架Mask-BERT,旨在提高BERT模型的小样本能力。作者使用mask的锚样本用于引导模型学习重要token信息,并采用对比损失让相同标签样本更紧凑,不同标签样本远离。

6. Limitations

作者只证明了在BERT系列模型上的先进性,并没有和其他先进模型作对比。

阅读总结

作者虽然没强调,但是本质上这是一篇在小样本场景应用数据增强的方法。整篇文章没有特别的创新点,并且标题中的"mask-guided"也特别容易误导读者,让读者以为模型真正在学习不同token的重要性。在我看来,整篇文章有以下几点不足或疑惑:

  1. 如果mask后的长度太短,目标域文本长度很长,不可能对实验结果没有影响。
  2. 实验结果没有明显提升,缺少理论的证明如t检验。
  3. 好奇mask比例发生这么明显的变化,为什么实验结果也很稳定,是没什么作用吗?
  4. 通过数据增强的方式引导模型更专注于重要token的说法过于牵强,因为这种方法并没有从本质上提高模型专注于重要token的能力。
  5. 没有和当前最先进的其他架构模型进行比较,如ChatGPT、GPT-4等。

当然了, 对于这样的看似是缝合的工作,不能一棒子打死,其实很多有建树性的工作都是基于前人的工作展开的,只不过不能只是简单的缝合,而要思考怎么让以前的方法在现在的领域可用,怎么才能够针对性去解决问题,拿着锤子找钉子不如先想想钉子放在哪更为合适。

相关推荐
qq_529025296 分钟前
Torch.gather
python·深度学习·机器学习
IT古董39 分钟前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
凯哥是个大帅比39 分钟前
人工智能ACA(五)--深度学习基础
人工智能·深度学习
m0_748232921 小时前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
szxinmai主板定制专家1 小时前
【国产NI替代】基于FPGA的32通道(24bits)高精度终端采集核心板卡
大数据·人工智能·fpga开发
海棠AI实验室1 小时前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
机器懒得学习1 小时前
基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成
人工智能·yolo·目标检测
QQ同步助手2 小时前
如何正确使用人工智能:开启智慧学习与创新之旅
人工智能·学习·百度
AIGC大时代2 小时前
如何使用ChatGPT辅助文献综述,以及如何进行优化?一篇说清楚
人工智能·深度学习·chatgpt·prompt·aigc