Bert基础(十三)--Bert变体之知识蒸馏训练

B站视频

1、 训练学生BERT模型(TinyBERT模型)

在TinyBERT模型中,我们使用两阶段学习框架。

  • 通用蒸馏
  • 特定任务蒸馏
    这种两阶段学习框架能够蒸馏预训练阶段和微调阶段的知识。下面,让我们看看每个阶段的详细情况。

1.1 通用蒸馏

通用蒸馏基于预训练阶段。这里,我们使用大型的预训练BERT模型(BERT-base模型)作为教师,并通过蒸馏将知识迁移到小型的学生BERT模型(TinyBERT模型)。需要注意的是,所有层上的知识都得到了蒸馏。

我们知道,教师BERT模型是在通用数据集(英语维基百科和多伦多图书语料库数据集)上预训练的。因此,在进行蒸馏时,也就是在将知识从教师(BERT-base模型)迁移到学生(TinyBERT模型)时,我们使用相同的数据集。

经过蒸馏,学生BERT模型将获得教师BERT模型的知识,我们可以把预训练过的学生BERT模型称为通用TinyBERT模型

在通用蒸馏后,我们得到了一个通用TinyBERT模型,它只是预训练过的学生BERT模型。下面,我们将为下游任务微调通用TinyBERT模型。

1.2 特定任务蒸馏

特定任务蒸馏基于微调阶段。在这里,我们将为一项具体的任务对通用TinyBERT模型进行微调。与DistilBERT模型不同的是,除了在预训练阶段进行蒸馏外,TinyBERT模型还在微调阶段进行蒸馏。

首先采用预训练的BERT-base模型,并为特定任务对其进行微调,然后将这个微调后的BERT-base模型作为教师。我们开始进行蒸馏,将知识从微调后的BERT-base模型迁移到通用TinyBERT模型中。经过蒸馏,通用TinyBERT模型将包含来自教师的特定任务的知识。我们可以把这个通用TinyBERT模型称为微调的TinyBERT模型。

通用蒸馏和特定任务蒸馏之间的区别。

请注意,为了在微调阶段进行蒸馏,需要更多特定任务的数据。也就是说,特定任务蒸馏对数据量有更大的需求。因此,我们将使用一种数据增强方法来获得更大的数据集,并基于该数据集来微调通用TinyBERT模型。下面,我们将展示如何进行数据增强。

1.3 数据增强方法

首先,让我们看一个例句。

假设有一个句子Paris is a beautiful city。使用BERT模型词元分析器对该句进行分词,并将标记存储在X中,如下所示:X = [Paris, is, a, beautiful, city]。

然后将X复制到另一个名为X_masked的列表中,可得X_masked = [Paris, is, a, beautiful, city]。

现在,对于列表X中的每个元素(单词)i,做如下处理。

(1) 检查X[i]是否是一个单词。如果它是一个单词,那么就用[MASK]标记替代X_masked[i]的值。然后,使用BERT-base模型来预测被掩盖的单词。我们将预测概率最大的K个单词作为候选单词列表,将其存储在名为candidates的列表中。假设K = 5,即我们预测出5个最有可能的单词,并将它们存储在candidates列表中。

(2) 如果X[i]不是一个单词,那么将不对其进行掩码处理,而是使用GloVe嵌入查找与X[i]最相似的K个单词,并将它们存储在candidates列表中。然后,从一个均匀分布 p ∼ U n i f o r m ( 0 , 1 ) p\sim Uniform(0,1) p∼Uniform(0,1)中随机抽取一个值p,并引入一个新变量,称为阈值 p t p_t pt。将阈值设为 p t = 0.4 p_t = 0.4 pt=0.4。

(3) 如果p小于或等于 p t p_t pt,那么就用candidates列表中的任何一个随机抽取的单词替换X_masked[i]。

(4) 如果p大于 p t p_t pt,那么就用实际的词X[i]替换X_masked[i]。

我们对句子中的每一个单词执行前面的步骤,并将更新后的X_masked列表添加到一个名为data_aug的列表中。对数据集中的每个句子重复应用这种数据增强方法N次。假设N = 10,那么对于每一个句子,都执行数据增强步骤,并得到10个新句子。

现在,我们了解了数据增强方法的工作原理。让我们来看一个例子,假设我们有如下列表。

python 复制代码
X = [Paris, is, a, beautiful, city]

将X复制到一个名为X_masked的新列表中,如下所示。

python 复制代码
X_masked = [Paris, is, a, beautiful, city]

现在对X[i]做如下处理。

当i = 0时,得到X[0] = Paris。判断X[0]是否是一个单词。由于它是一个单词,因此我们用[MASK]标记替换X_masked[0],如下所示。

python 复制代码
X_masked = [[MASK], is, a, beautiful, city]

然后,我们使用BERT-base模型预测K个最有可能是掩码标记的单词,并将它们存储在candidates列表中。假设K = 3,那么BERT-base模型将预测的3个最有可能的单词存储在candidates列表中。以下是BERT-base模型对掩码标记所预测的3个最有可能的单词。

接着,从一个均匀分布 p ∼ U n i f o r m ( 0 , 1 ) p\sim Uniform(0,1) p∼Uniform(0,1)中中随机抽取一个值 p p p。假设 p t = 0.3 p_t = 0.3 pt=0.3,检查 p p p是否小于或等于阈值 p t p_t pt。假设 p t = 0.4 p_t = 0.4 pt=0.4, p p p小于阈值 p t p_t pt,我们就把X_masked[0]替换为candidates列表中的一个随机词。假设从candidates列表中随机抽选了it这个单词,那么X_masked列表变为:

python 复制代码
X_masked = [it, is, a, beautiful, city]

最后,我们将X_masked添加到data_aug列表中。

以这种方式,我们重复以上步骤N次,以获得更多的数据。有了这样的增强数据集后,就可以对通用TinyBERT模型进行微调了。

简而言之,在TinyBERT模型中,我们不仅对所有层进行知识蒸馏,也在预训练阶段和微调阶段应用了蒸馏方法。

与BERT-base模型相比,TinyBERT模型的运算效率提升了96%,速度快9.4倍。我们可以在GitHub上下载预训练的TinyBERT模型。

到目前为止,我们已经学会了如何将知识从一个预训练的大型BERT模型迁移到小型BERT模型中。那么能否将知识从预训练的BERT模型迁移到简单的神经网络中呢?

2、将知识从BERT模型迁移到神经网络中

在本节中,让我们先看一篇有趣的论文:滑铁卢大学的"Distilling Task-Specific Knowledge from BERT into Simple Neural Networks"。在论文中,研究人员阐述了如何通过知识蒸馏将特定任务的知识从BERT模型迁移到一个简单的神经网络中。下面,我们将仔细分析它是如何实现的。

2.1 教师−学生架构

为了了解如何将特定任务的知识从BERT模型迁移到神经网络中,首先,让我们看看教师BERT模型和学生网络的细节。

  • 教师BERT模型

    同样使用预训练的BERT模型作为教师BERT模型。在这里,我们使用预训练的BERT-large模型。请注意,我们是将特定任务的知识从教师迁移给学生,因此,要先针对特定任务微调预训练的BERT-large模型,然后将其作为教师。

    假设我们要让学生网络做情感分析,那么预训练的BERT-large模型就需要为情感分析任务进行微调。

  • 学生网络

    学生网络是一个简单的双向LSTM,可以简单表示为BiLSTM。学生网络架构根据不同任务而变化,让我们先看看学生网络在单句分类任务中的架构。

    假设对句子I love Paris进行情感分析。首先,得到句子的嵌入,然后将嵌入送入双向LSTM。双向LSTM从两个方向(向前和向后)读取句子,可以得到前向和后向的隐藏状态。

    接着,将前向隐藏状态和后向隐藏状态送入带有ReLU激活的全连接层,它将返回logit向量作为输出。将logit向量送入softmax函数,得到该句是正面还是负面的概率,如图

    现在,让我们来看看句子匹配任务的学生网络架构。假设我们想了解给定的两个句子是否相似。在这种情况下,学生网络使用连体BiLSTM。

    首先,得到句子1和句子2的嵌入,并将其送入双向LSTM 1(BiLSTM 1)和双向LSTM 2(BiLSTM 2)。从BiLSTM 1和BiLSTM 2中获得前向隐藏状态和后向隐藏状态。假设 h s 1 h_{s1} hs1表示从BiLSTM 1得到的前向隐藏状态和后向隐藏状态, h s 2 h_{s2} hs2是从BiLSTM 2得到的前向隐藏状态和后向隐藏状态。然后使用一个连接−比较操作将 h s 1 h_{s1} hs1和 h s 2 h_{s2} hs2串联起来,得到如下公式。
    f ( h s 1 , h s 2 ) = [ h s 1 , h s 2 , h s 1 ⊙ h s 2 ∣ h s 1 , h s 2 ∣ ] f(h_{s1},h_{s2}) = [h_{s1},h_{s2},h_{s1}⊙h_{s2}|h_{s1},h_{s2}|] f(hs1,hs2)=[hs1,hs2,hs1⊙hs2∣hs1,hs2∣]

    在上面的公式中,⊙表示逐元素相乘。接下来,将串联的结果送入带有ReLU激活的全连接层,得到logit向量。然后将logit向量送入softmax函数,该函数返回给定句子对相似或不相似的概率,如图所示。

2.2 训练学生网络

我们将特定任务的知识从教师迁移给学生,因此,如前所述,将采用为特定任务微调后的预训练的BERT模型作为教师。也就是说,教师是预训练且经过微调的BERT模型,学生则使用BiLSTM。

通过最小化损失来训练学生网络,损失是学生损失 L s t u d e n t L_{student} Lstudent和蒸馏损失 L d i s t i l l L_{distill} Ldistill的加权和。这与在介绍知识蒸馏时所学到的内容相似,其公式如下。
L = α ⋅ L s t u d e n t + β ⋅ L d i s t i l l L = \alpha ·L_{student}+\beta·L_{distill} L=α⋅Lstudent+β⋅Ldistill

假设 β \beta β的值为 ( 1 − α ) (1-\alpha) (1−α),则上面的公式如下所示。
L = α ⋅ L s t u d e n t + ( 1 − α ) ⋅ L d i s t i l l L = \alpha ·L_{student}+(1-\alpha)·L_{distill} L=α⋅Lstudent+(1−α)⋅Ldistill

我们知道,蒸馏损失一般是软目标和软预测之间的交叉熵损失。但在这里,我们使用均方损失作为蒸馏损失,因为它比交叉熵损失的表现更好,其公式如下。

L d i s t i l l = M S E ( Z T , Z S ) L_{distill} = MSE(Z^T,Z^S) Ldistill=MSE(ZT,ZS)

在上面的公式中, Z T Z^T ZT表示教师网络的logit向量, Z S Z^S ZS表示学生网络的logit向量。

学生损失还是硬目标和硬预测之间的标准交叉熵损失。

我们通过最小化损失函数[插图]来训练学生网络。为了从教师BERT模型中蒸馏知识,将其迁移至学生网络,我们需要一个大型数据集。因此,需要使用一种与任务无关的数据增强方法来增加数据量。

2.3 数据增强方法

这里,我们使用以下方法来进行与任务无关的数据增强。

  • 掩码方法

  • 基于词性的词汇替换方法

  • n-gram采样方法

  • 掩码方法

    在掩码方法中,我们用[MASK]随机掩盖句子中的一个单词,概率为 P m a s k P_{mask} Pmask,并用掩码标记创建一个新句子。假设执行一个情感分析任务,在数据集中有句子I was listening to music。基于概率 P m a s k P_{mask} Pmask,随机掩盖一个单词。假设掩盖了music,那么得到一个新句子I was listening to [MASK]。

    由于[MASK]是一个未知标记,因此模型将无法产生确信的logit向量。带有[MASK]标记的I was listening to [MASK]句子产生的是一个相对欠确信的logit向量,也就是比原句I was listening to music的置信度要低,这将有助于模型理解每个单词对于所属标记的关联度。

  • 基于词性的词汇替换方法

    在基于词性的词汇替换方法中,根据概率 P p o s P_{pos} Ppos,我们用其他单词代替句子中的某一个单词,但词性必须一致。

    以Where did you go这个句子为例,did是一个动词。假如用另一个动词来代替did,原句变成Where do you go。可以看到,我们使用do替换了did,得到了一个新句子。

  • n-gram采样方法

    在n-gram采样方法中,我们只从句子中以概率 P n g P_{ng} Png随机抽取n个单词,n的值在1到5中随机选择。

相关推荐
IE061 分钟前
深度学习系列76:流式tts的一个简单实现
人工智能·深度学习
GIS数据转换器6 分钟前
城市生命线安全保障:技术应用与策略创新
大数据·人工智能·安全·3d·智慧城市
一水鉴天1 小时前
为AI聊天工具添加一个知识系统 之65 详细设计 之6 变形机器人及伺服跟随
人工智能
m0_743106464 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106464 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
井底哇哇7 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证7 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩8 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控8 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
一水鉴天8 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python