Transformer论文讲解

Transformer论文讲解

论文:Attention Is All You Need

作者:Ashish Vaswani, Noam Shazeer, Niki Parmar et al.

期刊/会议:NIPS 2017

PDF:arxiv.org/pdf/1706.03...

Github:Transformer包含在tensorflow

前言

Transformer摒弃了CNN和RNN,使用纯注意力机制构建的模型。Transformer的出现对NLP和CV领域产生深远影响,打破了NLP和CV之间的边界,大量基于该模型的变体层出不穷,如:BERT、ViT和CLIP等模型。可见,Transformer的重要性。图1展示的是整个Transformer模型的核心内容,本文将从模型结构缩放点积注意力多头注意力位置编码 和总结五部分来讲解。

1. 模型结构

图2展示了Transformer模型的结构,主要由编码器、解码器和输出层 组成。文章中作者将N设置为6层,为了便于残差连接,作者设定所有子层以及嵌入层都是512维度的输出,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l = 512 d_{model}=512 </math>dmodel=512。

编码器 由6个相同层组成,每一层由相同的两部分组成:一个多头注意力层和一个Feed Forward层 。这两个部分后面都进行残差连接和LayerNorm归一化。Feed Forward层其实就是简单的MLP层,由两个线性层组成,中间用LeRU函数进行激活。多头注意力层在后续第3节会讲解。

编码器 也是由6个相同层组成,每一层由相同的三部分组成:一个带掩码的多头注意力层、一个多头注意力层和一个MLP层 。与编码器一样都在各部分后添加残差和LayerNorm归一化,不同的是在多头注意力层中输入的是由编码器输出的key、value和经过带掩码的多头注意力层输出的query

输出层 在Transformer中输出层由softmax和线性层组成。输出层可以针对不同下游任务进行更换和调整,以适应不同的任务需要。

在代码实现中,Transformer的输入的queries、keys和values实际都是同一个源文本或者目标文本的词编码。Input Embedding输入的是(X)词向量,而Output Embedding的右移一位(Y)词向量(第T时刻的输出作为预测第T+1时刻的输入,这样不断生成词右移)

我相信很多新人第一次看到这个模型结构后,会想由6层的编码器和解码器组成的模样是长啥样的?特别是编码器的输出怎么传递到解码器中,是编码器中每一层的输出对应解码器中的每一层,还是经过6层的编码器后再输入到解码器中的每一层。一开始我的想法是如图3(a)所示的图,但其实正确的是图3中的(b)所示。

2. 缩放点积注意力

缩放点积注意力(Scaled Dot-Product Attention)的结构如图4中(a)所示,输入维度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk的queries和keys和维度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> d v d_v </math>dv的values,计算queries和keys的点积(自主性提示与非自主性提示配对),将每个键除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk ,并应用 softmax 函数来获得values的权重,与values点积,引导得出最配的values。其公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V </math>Attention(Q,K,V)=softmax(dk QKT)V

与之前的点积注意力机制不同的是,除以一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 。作者在论文中解释 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 的作用:在维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk较小的时候,加性注意力机制和点击注意力机制的表现相同,但在较大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk下,加性注意力优于点积注意力,作者怀疑是由于较大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk值,点积数量级增长,除以 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k \sqrt{d_k} </math>dk 来缩放抵消这种影响

3. 多头注意力

作者发现将queries、keys和values分别线性投影到 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk、 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> d v d_v </math>dv 维,而不是使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel维的queries、keys和values,执行单个注意力函数是有益。多头注意力允许模型共同关注不同位置的不同表示子空间的信息。使用单个注意力头,平均抑制这一点。多头注意力公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . , h e a d h ) W O w h e r e h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O \\where \quad head_i = Attention(QW_i^Q,KW_i^K,VW_i^V) </math>MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhereheadi=Attention(QWiQ,KWiK,VWiV)

其中投影的参数矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W i Q ∈ R d m o d e l ⋅ d k , W i K ∈ R d m o d e l ⋅ d k , W i V ∈ R d m o d e l ⋅ d v 和 W O ∈ R h d v ⋅ d m o d e l W_i^Q \in R^{d_{model}\cdot d_k},W_i^K \in R^{d_{model}\cdot d_k},W_i^V \in R^{d_{model}\cdot d_v} 和 W^O \in R^{hd_v \cdot d_{model}} </math>WiQ∈Rdmodel⋅dk,WiK∈Rdmodel⋅dk,WiV∈Rdmodel⋅dv和WO∈Rhdv⋅dmodel。在论文中,作者使用h=8个并行注意力头,设置 <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_k </math>dk= <math xmlns="http://www.w3.org/1998/Math/MathML"> d v d_v </math>dv= <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l / h d_{model}/h </math>dmodel/h=64

4. 位置编码

注意力机制没有像RNN那样记录序列中词元的顺序,因此,需要注入与序列中词元的相对位置和绝对位置。在输入中添加位置编码(Positional Encoding),位置编码维度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> d m o d e l d_{model} </math>dmodel,与序列维度一致,便于求和。作者选择的位置编码函数为带频率的正弦和余弦函数。其公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i)=\sin(pos/10000^{2i/d_model})\\ PE(pos,2i+1)=\cos(pos/10000^{2i/d_model}) </math>PE(pos,2i)=sin(pos/100002i/dmodel)PE(pos,2i+1)=cos(pos/100002i/dmodel)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> p o s pos </math>pos 是位置, <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i是维度。作者选择这个函数是因为他们假设它将能使模型轻松学习关注相对位置,波长形成从2π到10000·2π的几何级数。因为对于任何固定的偏移 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K,可以将 <math xmlns="http://www.w3.org/1998/Math/MathML"> P E p o s + k PE_{pos+k} </math>PEpos+k 表示为 <math xmlns="http://www.w3.org/1998/Math/MathML"> P E p o s PE_{pos} </math>PEpos 的线性函数。对于位置编码的理解,个人推荐看Transformer Architecture: The Positional Encoding的英文文章。

5. 总结

作者提出了 Transformer模型,这是第一个完全基于注意力的序列转导模型,用多头自注意力替换了编码器-解码器架构中最常用的RNN层。对于翻译任务,Transformer 可以比基于RNN或CNN的架构快得多(2017年之前)。在 WMT 2014 英德和 WMT 2014 英法翻译任务上达到了新的技术水平。作者展望了模型可以运用到其他NLP或CV的任务,开启了Transformer时代。


以上是我个人论文解读,可能存在很多不足。

若有问题,可以在评论区指出,大家一起交流讨论学习!

  • 希望这篇文章能够帮到你
  • 如有遇到问题,可以在评论区留言交流
  • 若有侵权,联系必删!

参考文章

动手学深度学习-注意力机制

Transformer Architecture: The Positional Encoding

相关推荐
卷心菜小温15 小时前
【BUG】P-tuningv2微调ChatGLM2-6B时所踩的坑
python·深度学习·语言模型·nlp·bug
龙的爹233319 小时前
论文翻译 | Generated Knowledge Prompting for Commonsense Reasoning
人工智能·gpt·机器学习·语言模型·自然语言处理·nlp·prompt
龙的爹233320 小时前
论文翻译 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·语言模型·自然语言处理·nlp·prompt
萱仔学习自我记录2 天前
常用大语言模型简单介绍
人工智能·python·自然语言处理·nlp
龙的爹23332 天前
论文翻译 | LLaMA-Adapter :具有零初始化注意的语言模型的有效微调
人工智能·gpt·语言模型·自然语言处理·nlp·prompt·llama
爱敲代码的小崔2 天前
NLP自然语言处理
人工智能·自然语言处理·nlp
OptimaAI4 天前
【LLM论文日更】| 通过指令调整进行零样本稠密检索的无监督文本表示学习
人工智能·深度学习·语言模型·自然语言处理·nlp
龙的爹23334 天前
论文翻译 | ReWOO: 高效增强语言模型的解耦推理
人工智能·语言模型·自然语言处理·nlp·prompt·agi
Ven%5 天前
深度学习速通系列:强大的中文自然语言处理工具之Pyltp的使用
人工智能·python·深度学习·自然语言处理·nlp
源大模型7 天前
源2.0全面适配百度PaddleNLP,大模型开发开箱即用
人工智能·语言模型·开源·nlp·源大模型