基于 Transformer 的编码器和解码器结构如图所示,左侧和右侧分别对应着编码器(Encoder)和解码器(Decoder)结构。它们均由若干个基本的 Transformer 块(Block)组成(对应着图中的灰色框)。这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × N \times </math>N× 表示进行了 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N 次堆叠。每个 Transformer 块都接收一个向量序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> { x i } i = 1 t \{x_i\}{i=1}^t </math>{xi}i=1t 作为输入,并输出一个等长的向量序列作为输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> { y i } i = 1 t \{y_i\}{i=1}^t </math>{yi}i=1t。这里的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 分别对应着文本序列中的一个词元的表示。而 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 是当前 Transformer 块对输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 进一步整合其上下文语义后对应的输出。在从输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> { x i } i = 1 t \{x_i\}{i=1}^t </math>{xi}i=1t 到输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> { y i } i = 1 t \{y_i\}{i=1}^t </math>{yi}i=1t 的语义抽象过程中,主要涉及到如下几个模块:
- 注意力层:使用多头注意力(Multi-Head Attention)机制整合上下文语义,它使得序列中任意两个单词之间的依赖关系可以直接被建立而不基于传统的循环结构,从而更好地解决文本的长程依赖。
- 位置感知前馈层(Position-wise FFN):通过全连接层对输入文本序列中的每个单词表示进行更复杂的变换。
- 残差连接:对应图中的 Add 部分。它是一条分别作用在上述两个子层当中的直连通路,被用于连接它们的输入与输出。从而使得信息流动更加高效,有利于模型的优化。
- 层归一化:对应图中的 Norm 部分。作用于上述两个子层的输出表示序列中,对表示序列进行层归一化操作,同样起到稳定优化的作用。
嵌入表示层
对于输入文本序列,首先通过输入嵌入层(Input Embedding)将每个单词转换为其相应的向量表示。通常,直接对每个单词创建一个向量表示。由于 Transformer 结构不再使用基于循环的方式建模文本输入,序列中不再有任何信息能够提示模型单词之间的相对位置关系。在送入编码器端建模其上下文语义之前,一个非常重要的操作是在词嵌入中加入位置编码(Positional Encoding)这一特征。具体来说,序列中每一个单词所在的位置都对应一个向量。这个向量会与单词表示对应相加并送入到后续模块中做进一步处理。在训练的过程当中,模型会自动地学习到如何利用这部分位置信息。
为了得到不同位置对应的编码,Transformer 结构使用不同频率的正余弦函数如下所示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE(pos, 2i) = sin ( pos 1000 0 2 i / d ) \text{PE(pos, 2i)} = \sin\left(\frac{\text{pos}}{10000^{2i/d}}\right) </math>PE(pos, 2i)=sin(100002i/dpos)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE(pos, 2i + 1) = cos ( pos 1000 0 2 i / d ) \text{PE(pos, 2i + 1)} = \cos\left(\frac{\text{pos}}{10000^{2i/d}}\right) </math>PE(pos, 2i + 1)=cos(100002i/dpos)
其中,pos 表示单词所在的位置,2i 和 2i + 1 表示位置编码向量中的对应维度,d 则对应位置编码的总维度(即d_model词嵌入维度 默认512)。通过上面这种方式计算位置编码有这样几个好处:首先,正余弦函数的范围是在 [-1,+1],导出的位置编码与原词嵌入相加不会使得结果偏离过远而破坏原有单词的语义信息。其次,依据三角函数的基本性质,可以得到第 pos + k 个位置编码是第 pos 个位置编码的线性组合,这就意味着位置编码中蕴含着单词之间的距离信息。
数学理解,利用三角函数的加法定理,我们可以将这些表达式写成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE(pos+k, 2i) = sin ( pos 1000 0 2 i / d + k 1000 0 2 i / d ) \text{PE(pos+k, 2i)} = \sin\left(\frac{\text{pos}}{10000^{2i/d}} + \frac{k}{10000^{2i/d}}\right) </math>PE(pos+k, 2i)=sin(100002i/dpos+100002i/dk)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE(pos+k, 2i+1) = cos ( pos 1000 0 2 i / d + k 1000 0 2 i / d ) \text{PE(pos+k, 2i+1)} = \cos\left(\frac{\text{pos}}{10000^{2i/d}} + \frac{k}{10000^{2i/d}}\right) </math>PE(pos+k, 2i+1)=cos(100002i/dpos+100002i/dk)根据三角函数的加法定理:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> sin ( pos 1000 0 2 i / d + k 1000 0 2 i / d ) = sin ( pos 1000 0 2 i / d ) cos ( k 1000 0 2 i / d ) + cos ( pos 1000 0 2 i / d ) sin ( k 1000 0 2 i / d ) \sin\left(\frac{\text{pos}}{10000^{2i/d}} + \frac{k}{10000^{2i/d}}\right) = \sin\left(\frac{\text{pos}}{10000^{2i/d}}\right) \cos\left(\frac{k}{10000^{2i/d}}\right) + \cos\left(\frac{\text{pos}}{10000^{2i/d}}\right) \sin\left(\frac{k}{10000^{2i/d}}\right) </math>sin(100002i/dpos+100002i/dk)=sin(100002i/dpos)cos(100002i/dk)+cos(100002i/dpos)sin(100002i/dk)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> cos ( pos 1000 0 2 i / d + k 1000 0 2 i / d ) = cos ( pos 1000 0 2 i / d ) cos ( k 1000 0 2 i / d ) − sin ( pos 1000 0 2 i / d ) sin ( k 1000 0 2 i / d ) \cos\left(\frac{\text{pos}}{10000^{2i/d}} + \frac{k}{10000^{2i/d}}\right) = \cos\left(\frac{\text{pos}}{10000^{2i/d}}\right) \cos\left(\frac{k}{10000^{2i/d}}\right) - \sin\left(\frac{\text{pos}}{10000^{2i/d}}\right) \sin\left(\frac{k}{10000^{2i/d}}\right) </math>cos(100002i/dpos+100002i/dk)=cos(100002i/dpos)cos(100002i/dk)−sin(100002i/dpos)sin(100002i/dk)这表明,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> pos + k \text{pos} + k </math>pos+k 个位置编码的每个维度都可以通过第 <math xmlns="http://www.w3.org/1998/Math/MathML"> pos \text{pos} </math>pos 个位置编码的相应维度进行线性组合得到。
代码实现:
python
class PositionalEncoder(nn.Module):
def __init__(self, d_model, max_seq_len = 80):
super().__init__()
self.d_model = d_model
# 根据 pos 和 i 创建一个常量 PE 矩阵
pe = torch.zeros(max_seq_len, d_model)
for pos in range(max_seq_len):
for i in range(0, d_model, 2):
pe[pos, i] = math.sin(pos / (10000 ** (i/d_model)))
pe[pos, i + 1] = math.cos(pos / (10000 ** (i/d_model)))
pe = pe.unsqueeze(0)
# pe的维度是 [1, max_seq_len, d_model]
self.register_buffer('pe', pe) #缓冲区,不参与梯度更新
def forward(self, x):
# 使得单词嵌入表示相对大一些
x = x * math.sqrt(self.d_model)
# 增加位置常量到单词嵌入表示中
seq_len = x.size(1)
x = x + self.pe[:, :seq_len] # self.pe[:,:seq_len] 的形状是 [1, seq_len, d_model]。
return x
注意力层
自注意力(Self-Attention)操作是基于 Transformer 的机器翻译模型的基本操作,在源语言的编码和目标语言的生成中频繁地被使用以建模源语言、目标语言任意两个单词之间的依赖关系。给定由单词语义嵌入及其位置编码叠加得到的输入表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> { x i ∈ R d } i = 1 t \{x_i \in \mathbb{R}^d\}_{i=1}^t </math>{xi∈Rd}i=1t,为了实现对上下文语义依赖的建模,进一步引入在自注意力机制中涉及到的三个元素:查询 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i q_i </math>qi (Query),键 <math xmlns="http://www.w3.org/1998/Math/MathML"> k i k_i </math>ki (Key),值 <math xmlns="http://www.w3.org/1998/Math/MathML"> v i v_i </math>vi (Value)。在编码输入序列中每一个单词的表示过程中,这三个元素用于计算上下文单词对应的权重得分。直观地说,这些权重反映了在编码当前单词的表示时,对于上下文不同部分所需要的关注程度。具体来说,如图所示,通过三个线性变换 <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q ∈ R d × d q , W K ∈ R d × d k , W V ∈ R d × d v W^Q \in \mathbb{R}^{d \times d_q}, W^K \in \mathbb{R}^{d \times d_k}, W^V \in \mathbb{R}^{d \times d_v} </math>WQ∈Rd×dq,WK∈Rd×dk,WV∈Rd×dv,将输入序列中的每一个单词表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 转换为其对应的 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ∈ R d q , k i ∈ R d k , v i ∈ R d v q_i \in \mathbb{R}^{d_q}, k_i \in \mathbb{R}^{d_k}, v_i \in \mathbb{R}^{d_v} </math>qi∈Rdq,ki∈Rdk,vi∈Rdv 向量。
为了得到编码单词 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 时所需要关注的上下文信息,通过位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 查询向量与其他位置的键向量做点积得到匹配分数 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i ⋅ k 1 , q i ⋅ k 2 , . . . , q i ⋅ k t q_i \cdot k_1, q_i \cdot k_2, ..., q_i \cdot k_t </math>qi⋅k1,qi⋅k2,...,qi⋅kt。为防止过大的匹配分数在后续 Softmax 计算过程中导致的梯度爆炸以及收敛效率差的问题,这些得分会除放缩因子 <math xmlns="http://www.w3.org/1998/Math/MathML"> d \sqrt{d} </math>d 以稳定优化。放缩后的得分经过 Softmax 归一化为概率之后,与其他位置的值向量相乘来聚合希望关注的上下文信息,并最小化不相关信息的干扰。上述计算过程可以被形式化地表达如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Z = Attention ( Q , K , V ) = Softmax ( Q K ⊤ d ) V Z = \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V </math>Z=Attention(Q,K,V)=Softmax(d QK⊤)V
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> Q ∈ R L × d q , K ∈ R L × d k , V ∈ R L × d v Q \in \mathbb{R}^{L \times d_q}, K \in \mathbb{R}^{L \times d_k}, V \in \mathbb{R}^{L \times d_v} </math>Q∈RL×dq,K∈RL×dk,V∈RL×dv 分别表示输入序列中的不同单词的 <math xmlns="http://www.w3.org/1998/Math/MathML"> q , k , v q, k, v </math>q,k,v 向量拼接组成的矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L 表示序列长度, <math xmlns="http://www.w3.org/1998/Math/MathML"> Z ∈ R L × d v Z \in \mathbb{R}^{L \times d_v} </math>Z∈RL×dv 表示自注意力操作的输出。为了进一步增强自注意力机制整合上下文信息的能力,提出了多头自注意力(Multi-head Attention)的机制,以关注上下文的不同侧面。具体来说,上下文中每一个单词的表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> x i x_i </math>xi 经过多组线性 <math xmlns="http://www.w3.org/1998/Math/MathML"> { W j Q , W j K , W j V } j = 1 N \{W_j^Q, W_j^K, W_j^V\}{j=1}^N </math>{WjQ,WjK,WjV}j=1N 映射到不同的表示子空间中。公式 2.3 会在不同的子空间中分别计算并得到不同的上下文相关的单词序列表示 <math xmlns="http://www.w3.org/1998/Math/MathML"> { Z j } j = 1 N \{Z_j\}{j=1}^N </math>{Zj}j=1N。最终,线性变换 <math xmlns="http://www.w3.org/1998/Math/MathML"> W O ∈ R ( N d v ) × d W^O \in \mathbb{R}^{(Nd_v) \times d} </math>WO∈R(Ndv)×d 用于综合不同子空间中的上下文表示并形成自注意力层最终的输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> { x i ∈ R d } i = 1 t \{x_i \in \mathbb{R}^d\}_{i=1}^t </math>{xi∈Rd}i=1t。
使用 PyTorch 实现的自注意力层参考代码如下:
python
class MultiHeadAttention(nn.Module):
def __init__(self, heads, d_model, dropout = 0.1):
super().__init__()
self.d_model = d_model
self.d_k = d_model // heads #每个头的维度
self.h = heads #头数
self.q_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.out = nn.Linear(d_model, d_model)
# q:查询矩阵,形状为 [batch_size, num_heads, seq_len, d_k]。
# k:键矩阵,形状为 [batch_size, num_heads, seq_len, d_k]。
def attention(self, q, k, v, d_k, mask=None, dropout=None):
# scores [batch_size, num_heads, seq_len, seq_len]
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, -1e9)
scores = F.softmax(scores, dim=-1)
if dropout is not None:
scores = dropout(scores)
# output [batch_size, num_heads, seq_len, d_k]
output = torch.matmul(scores, v)
return output
def forward(self, q, k, v, mask=None):
bs = q.size(0) # batch size
# # 进行线性操作划分为成 h 个头
k = self.k_linear(k).view(bs, -1, self.h, self.d_k) # [batch_size, seq_len, num_heads, d_head]
q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
# 矩阵转置
k = k.transpose(1,2)
q = q.transpose(1,2)
v = v.transpose(1,2)
# scores [batch_size, num_heads, seq_len, d_head]
scores = self.attention(q, k, v, self.d_k, mask, self.dropout)
# cancat -> [batch_size, seq_len, num_heads, d_head] -> [batch_size, seq_len, d_model]
cancat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)
# [batch_size, seq_len, d_model]
output = self.out(cancat)
return output
前馈层
前馈层接受自注意力子层的输出作为输入,并通过一个带有 ReLU 激活函数的两层全连接网络对输入进行更加复杂的非线性变换。实验证明,这一非线性变换会对模型最终的性能产生十分重要的影响。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 </math>FFN(x)=ReLU(xW1+b1)W2+b2
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 , b 1 , W 2 , b 2 W_1, b_1, W_2, b_2 </math>W1,b1,W2,b2 表示前馈子层的参数。实验结果表明,增大前馈子层隐藏状态的维度有利于提升最终翻译结果的质量,因此,前馈子层隐藏状态的维度一般比自注意力子层要大。
使用 PyTorch 实现的前馈层参考代码如下:
python
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff = 2048, dropout = 0.1):
super().__init__()
# d_ff 默认设置为 2048
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model)
def forward(self, x):
x = self.dropout(F.relu(self.linear_1(x)))
x = self.linear_2(x)
return x
残差连接与层归一化
由于 Transformer 结构组成的网络结构通常都是非常庞大的。编码器和解码器均由很多层基本的 Transformer 块组成,每一层当中包含复杂的非线性映射,这就导致模型的训练比较困难。因此,研究者们在 Transformer 块中进一步引入了残差连接与层归一化技术以进一步提升训练的稳定性。具体来说,残差连接主要是指使用一条直连通道直接将对应子层的输入连接到输出上去,从而避免由于网络过深在优化过程中潜在的梯度消失问题:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x l + 1 = f ( x l ) + x l x^{l+1} = f(x^l) + x^l </math>xl+1=f(xl)+xl
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x l x^l </math>xl 表示第 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l 层的输入, <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( ⋅ ) f(\cdot) </math>f(⋅) 表示一个映射函数。此外,为了使得每一层的输入输出稳定在一个合理的范围内,层归一化技术被进一步引入每个 Transformer 块中:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> LN ( x ) = α ⋅ x − μ σ + b \text{LN}(x) = \alpha \cdot \frac{x - \mu}{\sigma} + b </math>LN(x)=α⋅σx−μ+b
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ \mu </math>μ 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ \sigma </math>σ 分别表示均值和方差,用于将数据平移缩放到均值为 0,方差为 1 的标准分布, <math xmlns="http://www.w3.org/1998/Math/MathML"> α \alpha </math>α 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> b b </math>b 是可学习的参数。层归一化技术可以有效地缓解优化过程中潜在的不稳定、收敛速度慢等问题。
使用 PyTorch 实现的层归一化参考代码如下:
python
class Norm(nn.modules):
def __init__(self, d_model, eps = 1e-6):
super().__init__()
self.size = d_model
# 层归一化包含两个可以学习的参数
self.alpha = nn.Parameter(torch.ones(self.size))
self.bias = nn.Parameter(torch.zeros(self.size))
self.eps = eps
def forward(self, x):
norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
return norm
编码器和解码器结构
基于上述模块,根据图所给出的网络架构,编码器端可以较为容易实现。相比于编码器端, 解码器端要更复杂一些。具体来说,解码器的每个 Transformer 块的第一个自注意力子层额外增加了注意力掩码,对应图中的掩码多头注意力(Masked Multi-Head Attention)部分。这主要是因为在 翻译的过程中,编码器端主要用于编码源语言序列的信息,而这个序列是完全已知的,因而编码 器仅需要考虑如何融合上下文语义信息即可。而解码端则负责生成目标语言序列,这一生成过程 是自回归的,即对于每一个单词的生成过程,仅有当前单词之前的目标语言序列是可以被观测的, 因此这一额外增加的掩码是用来掩盖后续的文本信息,以防模型在训练阶段直接看到后续的文本 序列进而无法得到有效地训练。
此外,解码器端还额外增加了一个多头交叉注意力(Multi-Head Cross-attention)模块,使用交叉注意力(Cross-attention)方法,同时接收来自编码器端的输出以及当前 Transformer 块的前一个掩码注意力层的输出。查询是通过解码器前一层的输出进行投影的,而键和值是使用编码器的输出 进行投影的。它的作用是在翻译的过程当中,为了生成合理的目标语言序列需要观测待翻译的源语言序列是什么。基于上述的编码器和解码器结构,待翻译的源语言文本,首先经过编码器端的 每个Transformer块对其上下文语义的层层抽象,最终输出每一个源语言单词上下文相关的表示。 解码器端以自回归的方式生成目标语言文本,即在每个时间步 t,根据编码器端输出的源语言文本表示,以及前 t − 1 个时刻生成的目标语言文本,生成当前时刻的目标语言单词。
使用 PyTorch 实现的编码器参考代码如下:
python
class EncoderLayer(nn.Module):
def __init__(self, d_model, heads, dropout = 0.1):
super().__init__()
self.norm_1 = Norm(d_model)
self.norm_2 = Norm(d_model)
self.attn = MultiHeadAttention(heads, d_model, dropout = dropout)
self.ff = FeedForward(d_model, dropout = dropout)
self.dropout_1 = nn.Dropout(dropout)
self.dropout_2 = nn.Dropout(dropout)
def forward(self, x, mask):
# x的维度是 [batch_size, seq_len, d_model]
attn_output = self.attn(x, x, x, mask)
attn_output = self.dropout_1(attn_output)
x = x + attn_output
x = self.norm_1(x)
ff_output = self.ff(x)
ff_output = self.dropout_2(ff_output)
x = x + ff_output
x = self.norm_2(x)
return x
class Encoder(nn.Module):
def __init__(self, vocab_size, d_model, N, heads, dropout):
super().__init__()
self.N = N
# 词嵌入
self.embed = nn.Embedding(vocab_size, d_model)
self.pe = PositionalEncoder(d_model, dropout=dropout)
self.layers = nn.ModuleList([EncoderLayer(d_model, heads, dropout) for _ in range(N)])
self.norm = Norm(d_model)
def forward(self, src, mask):
x = self.embed(src)
x = self.pe(x)
for i in range(self.N):
x = self.layers[i](x, mask)
return self.norm(x)
使用 PyTorch 实现的解码器参考代码如下:
python
class DecoderLayer(nn.Module):
def __init__(self, d_model, heads, dropout=0.1):
super().__init__()
self.norm_1 = Norm(d_model)
self.norm_2 = Norm(d_model)
self.norm_3 = Norm(d_model)
self.dropout_1 = nn.Dropout(dropout)
self.dropout_2 = nn.Dropout(dropout)
self.dropout_3 = nn.Dropout(dropout)
self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout)
self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout)
self.ff = FeedForward(d_model, dropout=dropout)
def forward(self, x, e_output, src_mask, trg_mask):
attn_output_1 = self.attn_1(x, x, x, trg_mask)
attn_output_1 = self.dropout_1(attn_output_1)
x = x + attn_output_1
x = self.norm_1(x)
attn_output_2 = self.attn_2(e_output, e_output, x, src_mask)
attn_output_2 = self.dropout_2(attn_output_2)
x = x + attn_output_2
x = self.norm_2(x)
ff_output = self.ff(x)
ff_output = self.dropout_3(ff_output)
x = x + ff_output
x = self.norm_3(x)
return x
src_mask
是用于编码器部分的掩码(mask),它主要用于处理变长序列和填充(padding)问题。
trg_mask
是用于解码器部分的掩码(mask),它主要用于处理输出序列中的填充和自回归性质。与 src_mask 类似,trg_mask 也是一个布尔张量,但它不仅要处理填充,还要确保解码器在生成每个位置时只能看到该位置之前的信息,从而保持自回归性质:
处理变长序列:在批量处理时,不同样本的目标序列长度可能不同。为了便于所有样本具有相同的长度,通常会在较短的序列末尾添加填充(如
<pad>
标记)。掩码可以确保在计算注意力力时忽略这些填充位置。自回归性质:在解码器中,每个位置只能看到之前的序列信息,不能看到未来的信息。这通过一个下三角矩阵来实现,确保模型在生成每个位置时不会看到未来的位置。
整体结构
python
class Transformer(nn.Module):
def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
super().__init__()
self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
self.out = nn.Linear(d_model, trg_vocab)
def forward(self, src, trg, src_mask, trg_mask):
e_outputs = self.encoder(src, src_mask)
d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
output = self.out(d_output)
return output