深度学习笔记之BERT(四)DistilBERT

深度学习笔记之DistilBERT

引言

本节将介绍一种参数、消耗计算资源少的 BERT \text{BERT} BERT改进模型------ DistilBERT \text{DistilBERT} DistilBERT模型。

回顾:BERT模型的弊端

虽然 BERT \text{BERT} BERT性能优秀并且对各类 NLP \text{NLP} NLP下游任务通用,但依然存在一些弊端:

  • RoBERTa \text{RoBERTa} RoBERTa模型中提到过下句预测策略( Next Sentence Prediction,NSP \text{Next Sentence Prediction,NSP} Next Sentence Prediction,NSP)在训练任务过程中表现得并不优秀,并删除了该策略;
  • ALBERT \text{ALBERT} ALBERT模型中也提到过 BERT \text{BERT} BERT模型的参数量过大,从而消耗更多的时间和计算资源。

本节就从计算资源 开始,介绍一种新的算法模式------ DistilBERT \text{DistilBERT} DistilBERT,一种基于 BERT \text{BERT} BERT的知识蒸馏版本。

什么是知识蒸馏

即便是使用预训练好 的 BERT \text{BERT} BERT模型,我们在使用其执行下游任务时,依然需要消耗相当多的计算资源。例如:想要将一个模型迁移到更小的硬件上,例如手机等移动设备、笔记本电脑,它使用的计算资源、空间占用依然很高,计算效率较差。如果能够得到一个和预训练 BERT \text{BERT} BERT模型相差不大,但模型体量更小,参数更少、运行得更快、占用空间更少的模型,在使用过程中会更加方便。这体现了知识蒸馏的必要性

什么是知识蒸馏 呢 ? ? ? 它是指:基于一个已预训练好的模型作为教师模型,训练一个学生模型模仿教师模型,使学生模型的性能尽可能接近教师模型的过程。后续使用学生模型执行相关下游任务时,由于学生模型的体量更小,从而达到运行更快、占用空间更少的目的。

DistilBERT模型架构

基于上述理念, DistilBERT \text{DistilBERT} DistilBERT的模型架构表示如下:

结合论文中作者的描述观察:
论文链接在文章末尾~

  • 教师模型是一个 BERT-base \text{BERT-base} BERT-base模型 ( param:110 M ) (\text{param:110 M}) (param:110 M),它是由若干相互堆叠的注意力层构成。由于它已经是预训练好的,因而它并不是我们关注的重点;
  • 对学生模型 ( param:66 M ) (\text{param:66 M}) (param:66 M)的设计是:层内维度 (神经元数量)与教师模型相同的基础上,将 Encoder \text{Encoder} Encoder层数量减半,并且在初始化过程中从教师的 Encoder \text{Encoder} Encoder层中每两层中选择一层作为学生对应 Encoder \text{Encoder} Encoder层的初始化。
  • 整个 Inference \text{Inference} Inference过程中,教师模型没有参与;只有学生模型在反向传播过程中存在梯度更新。

那么如何实现将有效信息从教师模型蒸馏 到学生模型呢 ? ? ? 这意味着模型的训练过程将不同于传统的训练过程,因为训练学生模型的主要目标是模仿教师模型,从而训练策略发生一系列变化。

softmax温度函数

在介绍 DistilBERT \text{DistilBERT} DistilBERT模型策略之前,先介绍一下 Softmax \text{Softmax} Softmax温度函数。它的函数表达式如下所示:
P i = exp ⁡ ( Z i / T ) ∑ j exp ⁡ ( Z j / T ) \mathcal P_i = \frac{\exp (\mathcal Z_i / \mathcal T)}{\sum_{j} \exp (\mathcal Z_j / \mathcal T)} Pi=∑jexp(Zj/T)exp(Zi/T)

其中 T \mathcal T T表示温度系数。当 T = 1 \mathcal T = 1 T=1时,该表达式就是标准的 Softmax \text{Softmax} Softmax函数。而 T \mathcal T T的作用是增加归一化后概率分布的平滑程度,当 T \mathcal T T值越大时,可以让其他类别的信息也能够体现出来。

简单示例:

python 复制代码
import numpy as np
np.random.seed(42)

def softmax_temperature(l_input, temp):

    sum_result = sum([np.exp(j / temp) for j in l_input])
    return [np.exp(i / temp) / sum_result for i in l_input]


if __name__ == '__main__':
    l = np.random.randn(5)
    print(l)
    print(softmax_temperature(l, temp=1))
    print(softmax_temperature(l, temp=2))
    print(softmax_temperature(l, temp=3))

返回结果如下:
相比于 T = 1 \mathcal T = 1 T=1时各项元素之间差异较大的情况, T = 2 , 3 \mathcal T=2,3 T=2,3各项元素之间的差异明显缩小很多

python 复制代码
# randn_result
[ 0.49671415 -0.1382643   0.64768854  1.52302986 -0.23415337]

# softmax_temp_result(t = 1)
[0.1676398230582659, 0.08884020560809631, 0.19495956004254214, 0.4678433281532621, 0.0807170831378335]
# softmax_temp_result(t = 2)
[0.1933922588815045, 0.14078463743730418, 0.20855603294950892, 0.3230730334611907, 0.13419403727049165]
# softmax_temp_result(t = 3)
[0.19792006953161892, 0.16016487849367927, 0.2081352388053251, 0.27865333836010286, 0.15512647480927383]

而在知识蒸馏 过程中,我们通常也使用这种 Softmax \text{Softmax} Softmax温度函数将隐藏知识从教师模型迁移至学生模型中。

DistilBERT模型策略

介绍 DistilBERT \text{DistilBERT} DistilBERT模型策略之前,可以对原始 BERT \text{BERT} BERT模型策略做一些简单变动。在RoBERTa \text{RoBERTa} RoBERTa模型一节中,提到了下句预测任务的效果不佳,在这里完全可以直接将其移除 ,仅使用掩码语言模型 ( Masked Language Model,MLM ) (\text{Masked Language Model,MLM}) (Masked Language Model,MLM)进行训练,并且使用动态掩码进行训练。

对于 DistilBERT \text{DistilBERT} DistilBERT模型,它包含三种训练策略:

  • 掩码语言模型策略 ( Masked Language Model,MLM ) (\text{Masked Language Model,MLM}) (Masked Language Model,MLM):该策略是从 BERT \text{BERT} BERT模型中继承过来的策略,仅使用在学生模型中;此时掩码样本输入对应的标签 被称作硬目标,该策略中学生模型使用标准 Softmax \text{Softmax} Softmax函数 ( T = 1 ) (\mathcal T = 1) (T=1)进行预测,称作硬预测。
  • 蒸馏策略 ( Distillation loss ) (\text{Distillation loss}) (Distillation loss):该策略使用 Softmax \text{Softmax} Softmax温度函数分别对教师模型和学生模型的输出进行预测,其中教师模型的预测结果被称作软目标,对应学生模型的预测被称作软预测,通过将两者做交叉熵操作来使两种概率分布尽可能地逼近:
    对于教师模型经过 Softmax \text{Softmax} Softmax温度函数的软预测输出 T ( x ) ∈ R 1 × N \mathcal T(x) \in \mathbb R^{1 \times N} T(x)∈R1×N和对应学生模型的软预测输出 S ( x ) ∈ R 1 × N \mathcal S(x) \in \mathbb R^{1 \times N} S(x)∈R1×N,其中 N N N表示词表长度。
    T ( x ) = ( t 1 , t 2 , ⋯   , t N ) S ( x ) = ( s 1 , s 2 , ⋯   . s N ) L c r o s s = − ∑ i = 1 N t i ∗ log ⁡ ( s i ) \begin{aligned} & \mathcal T(x) = (t_1,t_2,\cdots,t_{N}) \\ & \mathcal S(x) = (s_1,s_2,\cdots.s_{N}) \\ & \mathcal L_{cross} = - \sum_{i=1}^{N} t_i * \log(s_i) \end{aligned} T(x)=(t1,t2,⋯,tN)S(x)=(s1,s2,⋯.sN)Lcross=−i=1∑Nti∗log(si)
  • 余弦嵌入策略 ( Cosine Embedding loss ) (\text{Cosine Embedding loss}) (Cosine Embedding loss):该策略思想与蒸馏策略 基本相同,都是为了使软目标和软预测这两种概率分布尽可能逼近,具体方式是计算两预测分布向量之间夹角的余弦值:
    L c o s i n e = 1 − cos ⁡ [ T ( x ) , S ( x ) ] \mathcal L_{cosine} = 1 - \cos \left[\mathcal T(x), \mathcal S(x)\right] Lcosine=1−cos[T(x),S(x)]
    理想状态下, L c o s i n e \mathcal L_{cosine} Lcosine取到最小值 0 0 0时,此时向量 T ( x ) \mathcal T(x) T(x)和 S ( x ) \mathcal S(x) S(x)已经被完全对齐。这也意味着通过一系列参数调整后,教师和学生模型的隐藏知识已被对齐。

最终使用的完整损失就是上述三种策略结果进行整合:
L D i s t i l = L M L M + L c r o s s + L c o s i n e 3 \mathcal L_{Distil} = \frac{\mathcal L_{MLM} + \mathcal L_{cross} + \mathcal L_{cosine}}{3} LDistil=3LMLM+Lcross+Lcosine

效果展示

根据论文中的内容, DistilBERT \text{DistilBERT} DistilBERT模型在各下游任务中的效果表示如下:

DistilBERT \text{DistilBERT} DistilBERT模型可以达到 BERT-base \text{BERT-base} BERT-base模型几乎 97 97 97%的准确度。在准确度没有丢失过多的情况下,并且该模型更加轻便,我们可以将其部署到终端设备上,与 BERT-base \text{BERT-base} BERT-base模型相比, DistilBERT \text{DistilBERT} DistilBERT模型的运算速度提高了 60 60 60%。

并且 Hugging Face \text{Hugging Face} Hugging Face工作人员针对问答任务对 DistilBERT \text{DistilBERT} DistilBERT模型进行微调,将其部署在 IPhone 7 Plus \text{IPhone 7 Plus} IPhone 7 Plus上,发现该模型大小仅有 207 MB 207\text{MB} 207MB,并且该模型的运算速度相比 BERT-base \text{BERT-base} BERT-base快了 71 71 71%。

Reference \text{Reference} Reference:
论文连接

《BERT基础教程------Transformer大模型实战》

相关推荐
程序猿000001号21 分钟前
使用PyTorch Lightning简化深度学习模型开发
pytorch·深度学习·目标检测
不如语冰1 小时前
pytorch学习笔记汇总
人工智能·pytorch·笔记·python·深度学习·神经网络·学习
OpenBayes1 小时前
OpenBayes 教程上新丨腾讯 Hunyuan3D-1.0 上线,10s 实现 3D 图像生成
人工智能·深度学习·3d·ai·3d 模型·腾讯混元·教程上新
scdifsn1 小时前
动手学深度学习11.1. 优化和深度学习-笔记&练习(PyTorch)
pytorch·笔记·深度学习·深度学习优化
知来者逆1 小时前
计算机视觉单阶段实例分割实践指南与综述
人工智能·深度学习·机器学习·计算机视觉·目标跟踪·目标分割
Charge_A1 小时前
深度学习作业 - 作业十一 - LSTM
人工智能·深度学习·lstm
MarkHD3 小时前
第二十三天 神经网络构建-多层感知机(MLP)
人工智能·深度学习·神经网络
FreedomLeo14 小时前
Python机器学习笔记(七、深度学习-神经网络)
python·深度学习·神经网络·机器学习
MYT_flyflyfly5 小时前
LRM-典型 Transformer 在视觉领域的应用,单个图像生成3D图像
人工智能·深度学习·transformer
明月醉窗台5 小时前
深度学习(15)从头搭建模型到训练、预测示例总结
人工智能·python·深度学习·目标检测·计算机视觉