摘要
Transformer架构自2017年由谷歌研究团队在论文《Attention Is All You Need》中提出以来,彻底改变了自然语言处理(NLP)领域的发展轨迹,并逐步扩展到计算机视觉、语音识别等多个AI子领域。本文系统梳理了Transformer从RNN演进而来的技术背景,深入剖析了其核心组件------多头自注意力机制、位置编码、前馈神经网络、残差连接与层归一化------的工作原理与设计动机,详细对比了编码器与解码器的结构差异,并探讨了BERT、GPT、T5等基于Transformer的代表性模型变体。最后,本文通过完整的PyTorch代码实现了编码器和解码器模块,并给出一个简化的英译中机器翻译示例,帮助读者从理论走向实践。
关键词:Transformer;自注意力机制;多头注意力;位置编码;BERT;GPT;PyTorch
1. 引言
近年来,深度学习在人工智能领域取得了举世瞩目的成就,而其中最具里程碑意义的突破之一便是Transformer架构的诞生。Transformer摒弃了传统的循环神经网络(RNN)结构,完全基于注意力机制(Attention Mechanism)进行建模,实现了序列数据处理的新范式。
本文将带你深入理解Transformer的每一处设计细节,并通过完整的PyTorch代码将理论付诸实践。无论你是深度学习的初学者,还是希望系统掌握Transformer核心原理的从业者,都能从中获得有价值的收获。
2. Transformer背景:从RNN到注意力机制
2.1 循环神经网络的局限性
在Transformer出现之前,RNN及其变体(如LSTM、GRU)是处理序列数据的主流方法。RNN通过将上一个时刻的隐藏状态传递给下一个时刻,实现了对序列信息的顺序建模。然而,这种设计存在三个根本性的缺陷:
-
串行计算导致的训练效率低下:RNN必须按照时间步顺序依次处理,无法并行化。当序列长度达到数千甚至数万时,训练时间会成为严重的瓶颈。
-
长距离依赖问题(Long-range Dependency):虽然LSTM和GRU通过门控机制一定程度上缓解了梯度消失问题,但当相关信息跨越很长的序列距离时,模型仍然难以有效捕获这种依赖关系。
-
上下文容量限制:RNN将整个上下文信息压缩到一个固定维度的隐藏状态中,对于需要访问分散在长序列中多处信息的人物,固定维度成为了信息瓶颈。
2.2 Self-Attention的并行计算优势
Self-Attention(自注意力)机制的核心思想是:序列中每个位置的输出取决于该位置与序列中所有位置的关联程度,而不仅仅依赖于相邻位置。具体而言,给定一个输入序列,自注意力通过以下三个步骤计算每个位置的表示:
-
Query(查询):当前位置想要查找的信息
-
Key(键):每个位置用于被查询匹配的特征
-
Value(值):每个位置的实际内容
通过Query与所有Key的点积计算注意力权重,再对Value进行加权求和,得到最终的输出。这种设计使得任意两个位置之间的依赖关系可以在常数时间内(O(1))被直接建模,而无需像RNN那样经过O(N)步的序列传递。
更重要的是,Self-Attention的计算过程可以完全矩阵化,利用GPU的并行计算能力,在O(N²·d)的时间内完成整个序列的处理(其中N为序列长度,d为模型维度),这比RNN的O(N·d²·T)串行计算要高效得多。
2.3 《Attention Is All You Need》论文解读
2017年,谷歌团队发表了开创性论文《Attention Is All You Need》,首次提出了完全基于注意力机制的Transformer架构。该论文的核心贡献包括:
-
仅使用注意力机制:完全摒弃RNN、CNN等传统结构,仅用多头自注意力(Multi-Head Attention)和位置前馈网络(Position-wise Feed-Forward Networks)构建模型。
-
并行化训练:通过位置编码(Positional Encoding)注入序列顺序信息,使注意力计算可以完全并行。
-
注意力多样性:通过多头注意力(Multi-Head Attention)从不同子空间捕获多种类型的依赖关系。
-
机器翻译SOTA:在WMT 2014英德翻译任务上达到了28.4 BLEU分数(当时最佳),在英法翻译任务上达到41.8 BLEU。
3. 编码器结构
Transformer的编码器(Encoder)由多个相同的层(Layer)堆叠而成,每一层包含两个子模块:
-
多头自注意力层(Multi-Head Self-Attention)
-
前馈神经网络层(Feed-Forward Network)
每个子模块周围都使用了残差连接(Residual Connection)和层归一化(Layer Normalization)。下面逐一详解各组件。
3.1 多头自注意力层
缩放点积注意力(Scaled Dot-Product Attention)
给定Query矩阵Q、Key矩阵K和Value矩阵V,注意力计算的公式为:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
其中,√d_k是缩放因子(scaling factor),d_k为Key的维度。缩放的目的是防止当d_k较大时,点积结果的方差过大,导致softmax函数进入饱和区(梯度接近于零),影响模型训练。
多头机制(Multi-Head Attention)
单一注意力头只能关注特定的关联模式。多头注意力将Q、K、V分别投影到h个不同的低维空间(子空间),在每个子空间中独立计算注意力,最后将h个头的输出拼接起来再进行一次线性变换:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
在原始论文中,d_model=512,h=8,每个头的维度d_k = d_v = d_model/h = 64。
自注意力的意义
在编码器中,输入序列的每个位置同时作为Query、Key和Value参与计算。这意味着每个位置的表示是由整个序列的加权组合来决定的,能够全面地捕获当前位置与序列中任意其他位置的语义关联。
3.2 前馈神经网络层
每个位置(token)还经过一个逐位置(position-wise)的前馈神经网络:
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
这是一个两层全连接网络,先将维度从d_model扩展到d_ff(通常为2048或3072),经过ReLU激活后,再压缩回d_model。该网络在每个位置上独立地应用相同的变换,因此称为"逐位置"。
3.3 残差连接与层归一化
每个子模块的输出为:
LayerNorm(x + Sublayer(x))
其中Sublayer(x)是子模块本身的输出。这种残差连接确保了即使子模块的映射结果接近于零,梯度也能直接反向传播到更低的层,有效缓解了深层网络的训练难度。
层归一化(Layer Normalization)则对每一层的输入进行归一化,使其均值为0、方差为1,稳定训练过程。
3.4 位置编码
由于Transformer本身不包含循环或卷积结构,无法自然地感知序列中元素的顺序关系,因此需要额外注入位置信息。论文使用了基于正弦和余弦函数的位置编码:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中pos为位置索引,i为维度索引。这种编码方式具有两个重要特性:
-
不同位置可以线性区分:任意两个位置pos₁和pos₂的编码差异可以通过线性变换关联。
-
可以泛化到训练时未见过的更长序列:因为正弦/余弦函数在任意位置都有定义。
位置编码直接加到输入嵌入上,使模型能够区分不同位置的 token。
4. 解码器结构
解码器(Decoder)与编码器结构相似,但包含三个子模块,并引入了关键的mask机制。
4.1 Masked自注意力(防止看到未来)
在解码器中,每个位置只能关注该位置及其之前的所有位置(不能看到"未来"的 token),这是因为在推理(inference)时,输出是一个词一个词生成的,当前词的预测不应该依赖于后续词。
实现方式是在缩放点积注意力中,将被禁止关注的位置(j > i)的注意力分数设置为一个非常大的负数(如-1e9),经过softmax后这些位置的权重接近于零:
Attention(Q, K, V) = softmax(mask(QKᵀ) / √d_k) · V
4.2 交叉注意力层
解码器还有一个独特的交叉注意力层(Cross Attention),其Query来自前一个解码器层的输出,而Key和Value则来自编码器的最终输出:
CrossAttention(Q, K, V) = Attention(Q, K, V)
这个设计使得解码器的每个位置都能够"查询"整个编码器序列的上下文信息,是编码器-解码器之间信息传递的核心桥梁。机器翻译中,这一步对应于在生成目标语言单词时"参考"源语言句子的语义表示。
4.3 输出线性层与Softmax
解码器的最终输出经过一个线性层,将维度映射到词表大小(vocab_size),然后通过Softmax函数转化为每个词的生成概率分布,用于预测下一个token。
4.4 解码器的完整结构
综上所述,解码器的每一层包含:
-
Masked Multi-Head Self-Attention:对目标序列(已生成部分)进行自注意力,避免泄露未来信息。
-
Cross Attention(交叉注意力):Query来自解码器自身,K/V来自编码器输出,实现跨模态信息交互。
-
Feed-Forward Network:与编码器相同的逐位置前馈网络。
每个子模块同样采用残差连接和层归一化。
5. 完整Transformer
5.1 编码器-解码器交互
在机器翻译等序列到序列(Seq2Seq)任务中,编码器首先处理完整的源序列,输出一个包含全局上下文信息的表示序列。解码器在生成目标序列的每一个词时,通过交叉注意力层从编码器的输出中检索相关信息。
具体流程如下:
-
源序列经过N层编码器,得到上下文增强的表示。
-
解码器在第1层使用已生成目标序列的嵌入和位置编码(通过Masked注意力)。
-
解码器在交叉注意力层,根据当前已生成的表示"查询"编码器输出,决定应该重点关注源序列的哪些部分。
-
经过N层解码器后,通过线性层和Softmax输出下一个词的概率分布。
5.2 训练与推理的区别
训练阶段(Training):
-
编码器一次性接收完整的源序列。
-
解码器一次性接收完整的目标序列(通常右移一位,即在输入中加入特殊的起始符号
<BOS>,在输出中包含结束符号<EOS>)。 -
所有位置并行计算,训练效率高。
推理阶段(Inference):
-
编码器仍然一次性处理完整的源序列。
-
解码器自回归(Autoregressive)生成:先生成第1个词,然后将该词作为输入生成第2个词,如此循环,直到生成
<EOS>结束符号。 -
每一步都需要重新计算整个解码器的所有层,无法像训练时那样并行。
5.3 Teacher Forcing
在训练解码器时,如果完全使用模型自己的上一轮预测作为输入,错误会累积并放大(级联误差),导致训练困难。Teacher Forcing的核心思想是:在训练时,以一定概率(通常是固定使用)使用目标序列中上一位置的实际token(ground truth)作为解码器的输入,而不是使用模型自身的预测结果。这显著加速了训练收敛。
6. BERT vs GPT:两大主流架构对比
基于Transformer的预训练模型主要分为两类:BERT的双向编码器和GPT的单向自回归生成器。
6.1 BERT:双向上下文编码器
BERT(Bidirectional Encoder Representations from Transformers)由谷歌于2018年提出,仅使用Transformer的编码器部分。
核心设计:
-
双向性:通过掩码语言模型(Masked Language Model, MLM),随机遮盖输入中约15%的词,让模型基于双向上下文来预测被遮盖的词。
-
任务无关:预训练完成后,通过添加简单的输出层,可以微调(Fine-tuning)用于各种下游任务,如文本分类、问答、命名实体识别等。
-
代表性模型:BERT-Base(12层,768维,12头),BERT-Large(24层,1024维,16头)。
适用场景:理解任务,如文本分类、情感分析、问答系统(抽取式问答)、自然语言推理。
6.2 GPT:单向自回归生成器
GPT(Generative Pre-trained Transformer)由OpenAI提出,仅使用Transformer的解码器部分。
核心设计:
-
单向性(自回归):从左到右逐词生成,每个词的预测只能依赖于其左侧的已生成内容。
-
预训练目标:标准的语言模型目标------最大化P(x_t | x_{<t})的概率。
-
涌现能力:随着模型规模(参数量)的急剧增大,GPT系列(如GPT-3、GPT-4)展现出惊人的少样本学习(Few-shot)和推理能力。
-
代表性模型:GPT-2(15亿参数)、GPT-3(1750亿参数)、GPT-4(多模态)。
适用场景:生成任务,如文本续写、对话系统、代码生成、文档摘要。
6.3 T5、BART等变体
除了BERT和GPT之外,还有许多重要的Transformer变体:
-
T5(Text-to-Text Transfer Transformer):谷歌提出,将所有NLP任务统一建模为"文本到文本"的转换问题。编码器-解码器架构,灵活且通用。
-
BART(Bidirectional and Auto-Regressive Transformers):Facebook提出,使用类似降噪自动编码器的预训练目标:随机破坏输入文本(如随机遮盖、删除、旋转),让解码器恢复原始文本。结合了BERT的双向上下文和GPT的自回归生成优势。
-
RoBERTa:BERT的优化版本,去掉了下一句预测(NSP)任务,增加训练数据和批量大小,延长训练时间,在多项基准上超越BERT。
-
XLNet:通过排列语言模型(Permutation Language Model)解决BERT中MASK token在微调时不存在的问题,实现某种程度上的"双向上下文+自回归"统一。
7. 使用场景
Transformer架构的卓越性能使其在NLP乃至更广泛的AI领域得到了全面应用。
7.1 机器翻译
这是Transformer的"出生地"和最直接的应用场景。谷歌翻译、DeepL等商用翻译系统均已基于Transformer构建。编码器理解源语言语义,解码器自回归生成目标语言译文。相较于传统的基于短语的统计机器翻译(SMT),Transformer将BLEU分数提升了数个百分点。
7.2 文本生成
-
GPT系列:用于文章写作、故事创作、代码补全、游戏对话等。
-
T5/BART:用于文本摘要、文本风格转换、句子纠错等。
-
控制生成:通过Prompt工程或微调,可以控制生成文本的主题、风格、长度等属性。
7.3 问答系统
-
抽取式问答:BERT在此任务上取得了超越人类的表现,通过理解问题与文本段落的关系,定位答案span。
-
生成式问答:如ChatGPT,能够基于海量知识生成连贯的答案。
-
检索增强生成(RAG):结合检索系统和生成模型,提高答案的准确性和可追溯性。
7.4 其他应用场景
-
语音识别:Whisper、Speech-to-Text等模型使用Transformer编码器处理声学特征。
-
图像分类与目标检测:ViT(Vision Transformer)将图像分割为patch序列,用标准Transformer编码器进行分类,性能超越CNN。
-
多模态:CLIP、DALL-E、GPT-4V等多模态模型利用Transformer统一处理文本、图像等多种模态。
-
蛋白质结构预测:AlphaFold2使用Transformer架构预测蛋白质的三维结构,被誉为科学领域的重大突破。
8. PyTorch实现代码
本节通过完整的PyTorch代码,从零实现一个简化但可运行的Transformer模型,并给出一个简单的机器翻译示例。所有代码包含详细中文注释,可直接运行。
8.1 完整Transformer模块实现
"""
完整Transformer架构的PyTorch实现
基于论文《Attention Is All You Need》(Vaswani et al., 2017)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')
# ============================================================
# 1. 位置编码(Positional Encoding)
# ============================================================
class PositionalEncoding(nn.Module):
"""
使用正弦/余弦函数生成位置编码,
将序列中每个位置的信息注入到输入嵌入中。
公式:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
这种编码方式的优势是:任意两个位置可以通过线性变换区分,
并且能够泛化到训练时未见过的更长序列。
"""
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 创建位置编码矩阵:[max_len, d_model]
pe = torch.zeros(max_len, d_model)
# 位置索引:[max_len, 1]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 频率项:用于控制不同维度的波长
div_term = torch.exp(
torch.arange(0, d_model, 2, dtype=torch.float) *
(-math.log(10000.0) / d_model)
)
# 偶数维度用sin,奇数维度用cos
pe[:, 0::2] = torch.sin(position * div_term) # 维度 0, 2, 4, ...
pe[:, 1::2] = torch.cos(position * div_term) # 维度 1, 3, 5, ...
# 添加batch维度:[1, max_len, d_model],便于后续与输入相加
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: [batch_size, seq_len, d_model]
将位置编码加到输入上,并应用dropout
"""
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
# ============================================================
# 2. 缩放点积注意力(Scaled Dot-Product Attention)
# ============================================================
def scaled_dot_product_attention(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
mask: torch.Tensor = None
) -> torch.Tensor:
"""
缩放点积注意力公式:
Attention(Q,K,V) = softmax(QKᵀ / √d_k) · V
参数:
Q: [batch_size, num_heads, seq_len, d_k]
K: [batch_size, num_heads, seq_len, d_k]
V: [batch_size, num_heads, seq_len, d_v]
mask: [batch_size, 1, seq_len, seq_len] 或 [batch_size, 1, 1, seq_len]
返回:
输出:[batch_size, num_heads, seq_len, d_v]
注意力权重:[batch_size, num_heads, seq_len, seq_len]
"""
d_k = Q.size(-1) # Key的维度
# 计算注意力分数:[batch_size, num_heads, seq_len, seq_len]
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# 应用mask(如果有):将需要遮蔽的位置设为极小值
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# softmax归一化:[batch_size, num_heads, seq_len, seq_len]
attn_weights = F.softmax(scores, dim=-1)
# 加权求和:[batch_size, num_heads, seq_len, d_v]
output = torch.matmul(attn_weights, V)
return output, attn_weights
# ============================================================
# 3. 多头注意力层(Multi-Head Attention)
# ============================================================
class MultiHeadAttention(nn.Module):
"""
多头注意力机制:
1. 将Q、K、V分别投影到h个子空间
2. 在每个子空间独立计算注意力
3. 拼接所有头的输出并进行线性变换
核心思想:不同的头可以关注不同类型的依赖关系,
例如有的头关注语法关系,有的头关注语义关联。
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 四个线性变换层:Q, K, V 以及输出投影
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None
) -> tuple:
"""
参数:
query/key/value: [batch_size, seq_len, d_model]
mask: [batch_size, seq_len, seq_len] 或类似形状
返回:
output: [batch_size, seq_len, d_model]
attn_weights: [batch_size, num_heads, seq_len, seq_len]
"""
batch_size = query.size(0)
# 线性变换并分头:[batch_size, seq_len, num_heads, d_k]
# transpose后变为:[batch_size, num_heads, seq_len, d_k]
Q = self.W_Q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_K(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_V(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 如果有mask,需要调整形状以适配注意力计算
if mask is not None:
# mask: [batch_size, 1, seq_len, seq_len] 或 [batch_size, seq_len, seq_len]
mask = mask.unsqueeze(1) # 扩展维度以匹配多头:[batch_size, 1, seq_len, seq_len]
# 计算缩放点积注意力
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 合并多头:[batch_size, seq_len, num_heads, d_k] -> [batch_size, seq_len, d_model]
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
# 最终线性投影
output = self.W_O(attn_output)
return output, attn_weights
# ============================================================
# 4. 前馈神经网络层(Position-wise Feed-Forward Network)
# ============================================================
class PositionwiseFeedForward(nn.Module):
"""
逐位置前馈神经网络:
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
两个线性变换之间使用ReLU激活,
维度从 d_model -> d_ff -> d_model。
注意:每个位置独立应用相同的变换,因此称为"逐位置"。
"""
def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 先扩展维度,经过ReLU,再压缩回来
return self.linear2(self.dropout(F.relu(self.linear1(x))))
# ============================================================
# 5. 编码器层(Encoder Layer)
# ============================================================
class EncoderLayer(nn.Module):
"""
Transformer编码器的一层:
1. Multi-Head Self-Attention(自注意力)
2. 残差连接 + 层归一化
3. Position-wise FFN(前馈网络)
4. 残差连接 + 层归一化
"""
def __init__(self, d_model: int, num_heads: int, d_ff: int = 2048, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
# --- 自注意力子层 ---
# Q=K=V=x:每个位置通过关注序列中所有位置来更新自己的表示
attn_output, _ = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_output))
# --- 前馈网络子层 ---
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout2(ffn_output))
return x
# ============================================================
# 6. 编码器(Encoder)
# ============================================================
class Encoder(nn.Module):
"""
完整的Transformer编码器:
- 输入嵌入 + 位置编码
- N个编码器层堆叠
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_layers: int = 6,
d_ff: int = 2048,
dropout: float = 0.1,
max_len: int = 5000
):
super().__init__()
self.d_model = d_model
# 源语言的嵌入层和位置编码
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
# N个编码器层
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
x: [batch_size, src_seq_len] — 源语言序列(token IDs)
返回: [batch_size, src_seq_len, d_model] — 编码后的序列表示
"""
# 词嵌入 + 位置编码
x = self.pos_encoding(self.dropout(self.embedding(x) * math.sqrt(self.d_model)))
# 依次通过每一层编码器
for layer in self.layers:
x = layer(x, mask)
return x
# ============================================================
# 7. 解码器层(Decoder Layer)
# ============================================================
class DecoderLayer(nn.Module):
"""
Transformer解码器的一层:
1. Masked Multi-Head Self-Attention(遮蔽自注意力)
2. 残差连接 + 层归一化
3. Cross Attention(交叉注意力,Q来自解码器,K/V来自编码器)
4. 残差连接 + 层归一化
5. Position-wise FFN
6. 残差连接 + 层归一化
"""
def __init__(self, d_model: int, num_heads: int, d_ff: int = 2048, dropout: float = 0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor, # 解码器输入
encoder_output: torch.Tensor, # 编码器输出
src_mask: torch.Tensor = None, # 源序列mask
tgt_mask: torch.Tensor = None # 目标序列mask(用于遮蔽未来位置)
) -> torch.Tensor:
# --- 第一子层:Masked自注意力 ---
# Q=K=V=x,tgt_mask确保每个位置只能看到自己和之前的词
_attn_out, _ = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout1(_attn_out))
# --- 第二子层:交叉注意力 ---
# Q来自解码器(x),K/V来自编码器(encoder_output)
# 这使得解码器能够"看到"源序列的全局信息
_attn_out, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout2(_attn_out))
# --- 第三子层:前馈网络 ---
ffn_out = self.ffn(x)
x = self.norm3(x + self.dropout3(ffn_out))
return x
# ============================================================
# 8. 解码器(Decoder)
# ============================================================
class Decoder(nn.Module):
"""
完整的Transformer解码器:
- 输出嵌入 + 位置编码
- N个解码器层堆叠
- 最终输出线性层 + Softmax
"""
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_layers: int = 6,
d_ff: int = 2048,
dropout: float = 0.1,
max_len: int = 5000
):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.fc_out = nn.Linear(d_model, vocab_size) # 最终投影到词表维度
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor, # 目标序列
encoder_output: torch.Tensor, # 编码器输出
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None
) -> torch.Tensor:
"""
x: [batch_size, tgt_seq_len] — 目标序列(token IDs)
返回: [batch_size, tgt_seq_len, vocab_size] — 每个位置的词概率分布
"""
x = self.pos_encoding(self.dropout(self.embedding(x) * math.sqrt(self.d_model)))
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
# 投影到词表维度并返回logits(由调用方决定是否通过Softmax)
return self.fc_out(x)
# ============================================================
# 9. 完整Transformer模型
# ============================================================
class Transformer(nn.Module):
"""
完整的Transformer模型(编码器-解码器架构)
"""
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
d_ff: int = 2048,
dropout: float = 0.1,
max_len: int = 5000
):
super().__init__()
self.encoder = Encoder(
src_vocab_size, d_model, num_heads,
num_encoder_layers, d_ff, dropout, max_len
)
self.decoder = Decoder(
tgt_vocab_size, d_model, num_heads,
num_decoder_layers, d_ff, dropout, max_len
)
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
src_mask: torch.Tensor = None,
tgt_mask: torch.Tensor = None
) -> torch.Tensor:
"""
src: [batch_size, src_seq_len]
tgt: [batch_size, tgt_seq_len]
返回: [batch_size, tgt_seq_len, tgt_vocab_size]
"""
# 编码源序列
encoder_output = self.encoder(src, src_mask)
# 解码目标序列(使用编码器输出作为交叉注意力的K/V来源)
decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
return decoder_output
# ============================================================
# 辅助函数:生成mask
# ============================================================
def create_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
"""
创建padding mask:将序列中值为pad_idx的位置标记为0(用于遮蔽)
返回:[batch_size, 1, 1, seq_len]
"""
return (seq != pad_idx).unsqueeze(1).unsqueeze(2)
def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
"""
创建因果mask(causal mask):遮蔽未来位置
返回:[seq_len, seq_len] — 上三角为0(遮蔽),下三角为1(可见)
"""
# 创建一个上三角矩阵(不含对角线),值为负无穷
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
return ~mask # 转换为True表示可见,False表示遮蔽
def create_masks(
src: torch.Tensor,
tgt: torch.Tensor,
pad_idx: int = 0
) -> tuple:
"""
创建所有必要的mask:
- src_mask: 遮蔽源序列中的padding位置
- tgt_mask: 同时遮蔽padding和未来位置
"""
# Padding mask for source
src_mask = create_padding_mask(src, pad_idx) # [B, 1, 1, src_len]
# Padding + causal mask for target
tgt_pad_mask = create_padding_mask(tgt, pad_idx) # [B, 1, 1, tgt_len]
tgt_len = tgt.size(1)
causal_mask = create_causal_mask(tgt_len, tgt.device) # [tgt_len, tgt_len]
# 扩展到batch维度:[1, 1, tgt_len, tgt_len]
tgt_mask = tgt_pad_mask & causal_mask.unsqueeze(0).unsqueeze(1)
return src_mask, tgt_mask
print("=" * 60)
print("基础模块定义完成!")
print("=" * 60)
8.2 简单机器翻译示例
"""
简单机器翻译示例:英译中(极简版本)
演示如何使用上述Transformer模块进行训练和推理
注意:这是一个极简示例,仅用于演示完整流程。
实际翻译系统需要大规模语料、更多epoch训练和更大模型。
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
# ============================================================
# 1. 超参数配置
# ============================================================
BATCH_SIZE = 32
D_MODEL = 256 # 模型维度(论文原版为512,此处简化)
NUM_HEADS = 8 # 注意力头数
NUM_ENCODER_LAYERS = 3 # 编码器层数
NUM_DECODER_LAYERS = 3 # 解码器层数
D_FF = 512 # 前馈网络维度(论文原版为2048)
DROPOUT = 0.1
EPOCHS = 20
LEARNING_RATE = 0.0001
MAX_LEN = 20 # 最大序列长度
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备:{DEVICE}")
# ============================================================
# 2. 极简词表和数据集
# ============================================================
# 为了演示方便,我们使用一个极小的合成中英对照数据集
# 实际应用中应使用WMT、IWSLT等标准翻译数据集
# 英文词表(包含特殊符号)
en_tokens = {
'<PAD>': 0, '<BOS>': 1, '<EOS>': 2,
'hello': 3, 'world': 4, 'i': 5, 'love': 6,
'you': 7, 'the': 8, 'cat': 9, 'dog': 10,
'is': 11, 'a': 12, 'good': 13, 'friend': 14,
'how': 15, 'are': 16, 'today': 17, 'fine': 18,
'thank': 19, 'my': 20, 'name': 21, 'is': 22,
}
en_itos = {v: k for k, v in en_tokens.items()}
EN_VOCAB_SIZE = len(en_tokens)
# 中文词表(按字符分词,包含特殊符号)
zh_tokens = {
'<PAD>': 0, '<BOS>': 1, '<EOS>': 2,
'你': 3, '好': 4, '世': 5, '界': 6, '我': 7,
'爱': 8, '你': 9, '的': 10, '猫': 11, '是': 12,
'一': 13, '只': 14, '好': 15, '朋': 16, '友': 17,
'怎': 18, '么': 19, '样': 20, '今': 21, '天': 22,
'谢': 23, '谢': 24, '你': 25, '叫': 26, '什': 27,
'么': 28
}
zh_itos = {v: k for k, v in zh_tokens.items()}
ZH_VOCAB_SIZE = len(zh_tokens)
# 简单的英中对照训练数据
training_pairs = [
("hello world", "你好世界"),
("i love you", "我爱你"),
("the cat is a good friend", "猫是一只好朋友"),
("how are you today", "你今天怎么样"),
("thank you my friend", "谢谢我的朋友"),
("my name is cat", "我叫猫"),
("i am fine", "我很好"),
("the dog is a friend", "狗是朋友"),
("hello i love the cat", "你好我爱猫"),
("world is good", "世界是好的"),
]
def tokenize_en(sentence: str, max_len: int = MAX_LEN) -> list:
"""英文分词(简单的空格分词 + padding/truncation)"""
words = sentence.lower().split()
tokens = [en_tokens.get(w, en_tokens['<PAD>']) for w in words]
if len(tokens) < max_len:
tokens += [en_tokens['<PAD>']] * (max_len - len(tokens))
return tokens[:max_len]
def tokenize_zh(sentence: str, max_len: int = MAX_LEN) -> list:
"""中文分词(字符级 + padding/truncation)"""
chars = list(sentence)
tokens = [zh_tokens.get(c, zh_tokens['<PAD>']) for c in chars]
if len(tokens) < max_len:
tokens += [zh_tokens['<PAD>']] * (max_len - len(tokens))
return tokens[:max_len]
class TranslationDataset(Dataset):
"""翻译数据集"""
def __init__(self, pairs, max_len=MAX_LEN):
self.pairs = pairs
self.max_len = max_len
def __len__(self):
return len(self.pairs)
def __getitem__(self, idx):
en_sent, zh_sent = self.pairs[idx]
# 编码源序列(英文)
src_tokens = tokenize_en(en_sent, self.max_len)
# 编码目标序列(中文):输入加<BOS>,输出加<EOS>
tgt_tokens = tokenize_zh('<BOS>' + zh_sent, self.max_len)
# 标签序列:加<EOS>
tgt_labels = tokenize_zh(zh_sent + '<EOS>', self.max_len)
return (
torch.tensor(src_tokens, dtype=torch.long),
torch.tensor(tgt_tokens, dtype=torch.long),
torch.tensor(tgt_labels, dtype=torch.long)
)
def collate_fn(batch):
"""自定义batch整理函数"""
src_batch = torch.stack([item[0] for item in batch])
tgt_batch = torch.stack([item[1] for item in batch])
tgt_labels_batch = torch.stack([item[2] for item in batch])
return src_batch, tgt_batch, tgt_labels_batch
# ============================================================
# 3. 实例化模型
# ============================================================
model = Transformer(
src_vocab_size=EN_VOCAB_SIZE,
tgt_vocab_size=ZH_VOCAB_SIZE,
d_model=D_MODEL,
num_heads=NUM_HEADS,
num_encoder_layers=NUM_ENCODER_LAYERS,
num_decoder_layers=NUM_DECODER_LAYERS,
d_ff=D_FF,
dropout=DROPOUT
).to(DEVICE)
print(f"模型参数量:{sum(p.numel() for p in model.parameters()):,}")
# ============================================================
# 4. 损失函数和优化器
# ============================================================
PAD_IDX = 0
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX) # 忽略PAD位置的损失
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# ============================================================
# 5. 训练函数
# ============================================================
def train_epoch(model, dataloader, optimizer, criterion, device):
model.train()
total_loss = 0
for src, tgt, tgt_labels in dataloader:
src = src.to(device)
tgt = tgt.to(device)
tgt_labels = tgt_labels.to(device)
# 创建mask
src_mask, tgt_mask = create_masks(src, tgt, PAD_IDX)
# 前向传播
# 解码器输入:tgt(不含最后一个token)
# 标签:tgt_labels(含<EOS>)
tgt_input = tgt[:, :-1] # 去掉最后一个token作为输入
tgt_labels_slice = tgt_labels[:, 1:] # 去掉<BOS>作为标签
# 创建对应的mask
_, tgt_mask = create_masks(src, tgt_input, PAD_IDX)
logits = model(src, tgt_input, src_mask, tgt_mask) # [B, tgt_len-1, vocab]
# 计算损失(跨词表维度)
loss = criterion(
logits.reshape(-1, logits.size(-1)),
tgt_labels_slice.reshape(-1)
)
# 反向传播
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # 梯度裁剪,防止梯度爆炸
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
# ============================================================
# 6. 翻译函数(贪心解码)
# ============================================================
def translate(
model,
sentence: str,
en_tokens_dict: dict,
zh_tokens_dict: dict,
zh_itos_dict: dict,
max_len: int = MAX_LEN,
device: torch.device = DEVICE
) -> str:
"""
使用贪心解码(Greedy Decoding)进行翻译:
每一步选择概率最高的词,直到生成<EOS>或达到最大长度。
"""
model.eval()
# 编码源序列
src_tokens = tokenize_en(sentence, max_len)
src_tensor = torch.tensor([src_tokens], dtype=torch.long).to(device)
src_mask, _ = create_masks(src_tensor, src_tensor, PAD_IDX)
# 编码器前向传播
encoder_output = model.encoder(src_tensor, src_mask)
# 解码:从<BOS>开始自回归生成
tgt_tokens = [zh_tokens_dict['<BOS>']]
for _ in range(max_len):
tgt_tensor = torch.tensor([tgt_tokens], dtype=torch.long).to(device)
_, tgt_mask = create_masks(src_tensor, tgt_tensor, PAD_IDX)
# 解码器前向传播
logits = model.decoder(tgt_tensor, encoder_output, src_mask, tgt_mask)
# 取最后一个时间步的预测(下一个词)
next_token_logits = logits[:, -1, :] # [1, vocab_size]
next_token_id = next_token_logits.argmax(dim=-1).item() # 贪心选择概率最高的词
# 如果遇到<EOS>,停止生成
if next_token_id == zh_tokens_dict['<EOS>']:
break
tgt_tokens.append(next_token_id)
# 将token IDs转回中文文本
zh_chars = [zh_itos_dict.get(tid, '<UNK>') for tid in tgt_tokens[1:]] # 去掉<BOS>
return ''.join(zh_chars)
# ============================================================
# 7. 开始训练
# ============================================================
print("\n开始训练...")
dataset = TranslationDataset(training_pairs)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
for epoch in range(EPOCHS):
avg_loss = train_epoch(model, dataloader, optimizer, criterion, DEVICE)
print(f"Epoch {epoch + 1:02d}/{EPOCHS} | 平均损失:{avg_loss:.4f}")
# 每5个epoch演示一次翻译效果
if (epoch + 1) % 5 == 0:
print("\n翻译示例:")
for en_sent, _ in training_pairs[:3]:
zh_translation = translate(
model, en_sent, en_tokens, zh_tokens, zh_itos, MAX_LEN, DEVICE
)
print(f" 英文:{en_sent}")
print(f" 中文:{zh_translation}")
print()
# ============================================================
# 8. 最终翻译效果展示
# ============================================================
print("\n" + "=" * 60)
print("最终翻译效果:")
print("=" * 60)
for en_sent, expected_zh in training_pairs:
predicted_zh = translate(
model, en_sent, en_tokens, zh_tokens, zh_itos, MAX_LEN, DEVICE
)
print(f"输入:{en_sent}")
print(f"期望:{expected_zh}")
print(f"预测:{predicted_zh}")
print("-" * 40)
print("\n模型训练和推理完成!")
print("注意:由于使用了极小的合成数据集,模型泛化能力有限。")
print("实际应用中请使用WMT、IWSLT等大规模翻译数据集。")
运行以上代码,你将看到一个完整的Transformer从零实现的全过程,包括:
-
位置编码:通过正弦/余弦函数为序列注入位置信息
-
多头注意力:从不同子空间并行捕获多种依赖关系
-
编码器:堆叠自注意力和前馈网络,全面理解源序列语义
-
解码器:通过掩码自注意力防止信息泄露,通过交叉注意力查询源序列
-
机器翻译示例:在极小数据集上验证模型的前向传播和训练流程
9. 总结与展望
Transformer架构以其优雅的设计和卓越的性能,成为深度学习领域最成功的模型之一。本文系统地梳理了Transformer从RNN困境到注意力机制突破的技术演进,详细剖析了编码器和解码器的每个核心组件,并通过PyTorch代码实现了完整可运行的模型。
回顾Transformer的成功,我们可以总结出几个核心设计原则:
-
全局建模能力:自注意力机制使每个位置能够直接关注序列中的任意其他位置,突破了RNN的局部感受野限制。
-
并行化训练:通过位置编码而非序列顺序传递信息,实现了真正的并行计算。
-
模块化与可扩展性:编码器/解码器层可以堆叠,多头注意力可以灵活调整,为模型规模的扩大提供了可行路径。
-
表示的丰富性:多头机制使模型能够在不同子空间中同时捕获多种类型的语义关系。
展望未来,Transformer架构仍在持续演进:大语言模型(LLM)将模型规模推向千亿甚至万亿参数;混合专家(MoE)架构探索了高效扩展的可能性;FlashAttention等算法持续优化注意力计算的成本;多模态Transformer正在统一视觉、语言、音频等多种模态的建模范式。理解Transformer的原理,不仅是掌握当前AI技术的钥匙,更是迎接未来更多突破的基础。