深度学习之Transformer架构详解

摘要

Transformer架构自2017年由谷歌研究团队在论文《Attention Is All You Need》中提出以来,彻底改变了自然语言处理(NLP)领域的发展轨迹,并逐步扩展到计算机视觉、语音识别等多个AI子领域。本文系统梳理了Transformer从RNN演进而来的技术背景,深入剖析了其核心组件------多头自注意力机制、位置编码、前馈神经网络、残差连接与层归一化------的工作原理与设计动机,详细对比了编码器与解码器的结构差异,并探讨了BERT、GPT、T5等基于Transformer的代表性模型变体。最后,本文通过完整的PyTorch代码实现了编码器和解码器模块,并给出一个简化的英译中机器翻译示例,帮助读者从理论走向实践。

关键词:Transformer;自注意力机制;多头注意力;位置编码;BERT;GPT;PyTorch


1. 引言

近年来,深度学习在人工智能领域取得了举世瞩目的成就,而其中最具里程碑意义的突破之一便是Transformer架构的诞生。Transformer摒弃了传统的循环神经网络(RNN)结构,完全基于注意力机制(Attention Mechanism)进行建模,实现了序列数据处理的新范式。

本文将带你深入理解Transformer的每一处设计细节,并通过完整的PyTorch代码将理论付诸实践。无论你是深度学习的初学者,还是希望系统掌握Transformer核心原理的从业者,都能从中获得有价值的收获。


2. Transformer背景:从RNN到注意力机制

2.1 循环神经网络的局限性

在Transformer出现之前,RNN及其变体(如LSTM、GRU)是处理序列数据的主流方法。RNN通过将上一个时刻的隐藏状态传递给下一个时刻,实现了对序列信息的顺序建模。然而,这种设计存在三个根本性的缺陷:

  1. 串行计算导致的训练效率低下:RNN必须按照时间步顺序依次处理,无法并行化。当序列长度达到数千甚至数万时,训练时间会成为严重的瓶颈。

  2. 长距离依赖问题(Long-range Dependency):虽然LSTM和GRU通过门控机制一定程度上缓解了梯度消失问题,但当相关信息跨越很长的序列距离时,模型仍然难以有效捕获这种依赖关系。

  3. 上下文容量限制:RNN将整个上下文信息压缩到一个固定维度的隐藏状态中,对于需要访问分散在长序列中多处信息的人物,固定维度成为了信息瓶颈。

2.2 Self-Attention的并行计算优势

Self-Attention(自注意力)机制的核心思想是:序列中每个位置的输出取决于该位置与序列中所有位置的关联程度,而不仅仅依赖于相邻位置。具体而言,给定一个输入序列,自注意力通过以下三个步骤计算每个位置的表示:

  • Query(查询):当前位置想要查找的信息

  • Key(键):每个位置用于被查询匹配的特征

  • Value(值):每个位置的实际内容

通过Query与所有Key的点积计算注意力权重,再对Value进行加权求和,得到最终的输出。这种设计使得任意两个位置之间的依赖关系可以在常数时间内(O(1))被直接建模,而无需像RNN那样经过O(N)步的序列传递。

更重要的是,Self-Attention的计算过程可以完全矩阵化,利用GPU的并行计算能力,在O(N²·d)的时间内完成整个序列的处理(其中N为序列长度,d为模型维度),这比RNN的O(N·d²·T)串行计算要高效得多。

2.3 《Attention Is All You Need》论文解读

2017年,谷歌团队发表了开创性论文《Attention Is All You Need》,首次提出了完全基于注意力机制的Transformer架构。该论文的核心贡献包括:

  • 仅使用注意力机制:完全摒弃RNN、CNN等传统结构,仅用多头自注意力(Multi-Head Attention)和位置前馈网络(Position-wise Feed-Forward Networks)构建模型。

  • 并行化训练:通过位置编码(Positional Encoding)注入序列顺序信息,使注意力计算可以完全并行。

  • 注意力多样性:通过多头注意力(Multi-Head Attention)从不同子空间捕获多种类型的依赖关系。

  • 机器翻译SOTA:在WMT 2014英德翻译任务上达到了28.4 BLEU分数(当时最佳),在英法翻译任务上达到41.8 BLEU。


3. 编码器结构

Transformer的编码器(Encoder)由多个相同的层(Layer)堆叠而成,每一层包含两个子模块:

  1. 多头自注意力层(Multi-Head Self-Attention)

  2. 前馈神经网络层(Feed-Forward Network)

每个子模块周围都使用了残差连接(Residual Connection)和层归一化(Layer Normalization)。下面逐一详解各组件。

3.1 多头自注意力层

缩放点积注意力(Scaled Dot-Product Attention)

给定Query矩阵Q、Key矩阵K和Value矩阵V,注意力计算的公式为:

复制代码
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

其中,√d_k是缩放因子(scaling factor),d_k为Key的维度。缩放的目的是防止当d_k较大时,点积结果的方差过大,导致softmax函数进入饱和区(梯度接近于零),影响模型训练。

多头机制(Multi-Head Attention)

单一注意力头只能关注特定的关联模式。多头注意力将Q、K、V分别投影到h个不同的低维空间(子空间),在每个子空间中独立计算注意力,最后将h个头的输出拼接起来再进行一次线性变换:

复制代码
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

在原始论文中,d_model=512,h=8,每个头的维度d_k = d_v = d_model/h = 64。

自注意力的意义

在编码器中,输入序列的每个位置同时作为Query、Key和Value参与计算。这意味着每个位置的表示是由整个序列的加权组合来决定的,能够全面地捕获当前位置与序列中任意其他位置的语义关联。

3.2 前馈神经网络层

每个位置(token)还经过一个逐位置(position-wise)的前馈神经网络:

复制代码
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2

这是一个两层全连接网络,先将维度从d_model扩展到d_ff(通常为2048或3072),经过ReLU激活后,再压缩回d_model。该网络在每个位置上独立地应用相同的变换,因此称为"逐位置"。

3.3 残差连接与层归一化

每个子模块的输出为:

复制代码
LayerNorm(x + Sublayer(x))

其中Sublayer(x)是子模块本身的输出。这种残差连接确保了即使子模块的映射结果接近于零,梯度也能直接反向传播到更低的层,有效缓解了深层网络的训练难度。

层归一化(Layer Normalization)则对每一层的输入进行归一化,使其均值为0、方差为1,稳定训练过程。

3.4 位置编码

由于Transformer本身不包含循环或卷积结构,无法自然地感知序列中元素的顺序关系,因此需要额外注入位置信息。论文使用了基于正弦和余弦函数的位置编码:

复制代码
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中pos为位置索引,i为维度索引。这种编码方式具有两个重要特性:

  • 不同位置可以线性区分:任意两个位置pos₁和pos₂的编码差异可以通过线性变换关联。

  • 可以泛化到训练时未见过的更长序列:因为正弦/余弦函数在任意位置都有定义。

位置编码直接加到输入嵌入上,使模型能够区分不同位置的 token。


4. 解码器结构

解码器(Decoder)与编码器结构相似,但包含三个子模块,并引入了关键的mask机制。

4.1 Masked自注意力(防止看到未来)

在解码器中,每个位置只能关注该位置及其之前的所有位置(不能看到"未来"的 token),这是因为在推理(inference)时,输出是一个词一个词生成的,当前词的预测不应该依赖于后续词。

实现方式是在缩放点积注意力中,将被禁止关注的位置(j > i)的注意力分数设置为一个非常大的负数(如-1e9),经过softmax后这些位置的权重接近于零:

复制代码
Attention(Q, K, V) = softmax(mask(QKᵀ) / √d_k) · V

4.2 交叉注意力层

解码器还有一个独特的交叉注意力层(Cross Attention),其Query来自前一个解码器层的输出,而Key和Value则来自编码器的最终输出:

复制代码
CrossAttention(Q, K, V) = Attention(Q, K, V)

这个设计使得解码器的每个位置都能够"查询"整个编码器序列的上下文信息,是编码器-解码器之间信息传递的核心桥梁。机器翻译中,这一步对应于在生成目标语言单词时"参考"源语言句子的语义表示。

4.3 输出线性层与Softmax

解码器的最终输出经过一个线性层,将维度映射到词表大小(vocab_size),然后通过Softmax函数转化为每个词的生成概率分布,用于预测下一个token。

4.4 解码器的完整结构

综上所述,解码器的每一层包含:

  1. Masked Multi-Head Self-Attention:对目标序列(已生成部分)进行自注意力,避免泄露未来信息。

  2. Cross Attention(交叉注意力):Query来自解码器自身,K/V来自编码器输出,实现跨模态信息交互。

  3. Feed-Forward Network:与编码器相同的逐位置前馈网络。

每个子模块同样采用残差连接和层归一化。


5. 完整Transformer

5.1 编码器-解码器交互

在机器翻译等序列到序列(Seq2Seq)任务中,编码器首先处理完整的源序列,输出一个包含全局上下文信息的表示序列。解码器在生成目标序列的每一个词时,通过交叉注意力层从编码器的输出中检索相关信息。

具体流程如下:

  1. 源序列经过N层编码器,得到上下文增强的表示。

  2. 解码器在第1层使用已生成目标序列的嵌入和位置编码(通过Masked注意力)。

  3. 解码器在交叉注意力层,根据当前已生成的表示"查询"编码器输出,决定应该重点关注源序列的哪些部分。

  4. 经过N层解码器后,通过线性层和Softmax输出下一个词的概率分布。

5.2 训练与推理的区别

训练阶段(Training)

  • 编码器一次性接收完整的源序列。

  • 解码器一次性接收完整的目标序列(通常右移一位,即在输入中加入特殊的起始符号 <BOS>,在输出中包含结束符号 <EOS>)。

  • 所有位置并行计算,训练效率高。

推理阶段(Inference)

  • 编码器仍然一次性处理完整的源序列。

  • 解码器自回归(Autoregressive)生成:先生成第1个词,然后将该词作为输入生成第2个词,如此循环,直到生成 <EOS> 结束符号。

  • 每一步都需要重新计算整个解码器的所有层,无法像训练时那样并行。

5.3 Teacher Forcing

在训练解码器时,如果完全使用模型自己的上一轮预测作为输入,错误会累积并放大(级联误差),导致训练困难。Teacher Forcing的核心思想是:在训练时,以一定概率(通常是固定使用)使用目标序列中上一位置的实际token(ground truth)作为解码器的输入,而不是使用模型自身的预测结果。这显著加速了训练收敛。


6. BERT vs GPT:两大主流架构对比

基于Transformer的预训练模型主要分为两类:BERT的双向编码器和GPT的单向自回归生成器。

6.1 BERT:双向上下文编码器

BERT(Bidirectional Encoder Representations from Transformers)由谷歌于2018年提出,仅使用Transformer的编码器部分。

核心设计

  • 双向性:通过掩码语言模型(Masked Language Model, MLM),随机遮盖输入中约15%的词,让模型基于双向上下文来预测被遮盖的词。

  • 任务无关:预训练完成后,通过添加简单的输出层,可以微调(Fine-tuning)用于各种下游任务,如文本分类、问答、命名实体识别等。

  • 代表性模型:BERT-Base(12层,768维,12头),BERT-Large(24层,1024维,16头)。

适用场景:理解任务,如文本分类、情感分析、问答系统(抽取式问答)、自然语言推理。

6.2 GPT:单向自回归生成器

GPT(Generative Pre-trained Transformer)由OpenAI提出,仅使用Transformer的解码器部分。

核心设计

  • 单向性(自回归):从左到右逐词生成,每个词的预测只能依赖于其左侧的已生成内容。

  • 预训练目标:标准的语言模型目标------最大化P(x_t | x_{<t})的概率。

  • 涌现能力:随着模型规模(参数量)的急剧增大,GPT系列(如GPT-3、GPT-4)展现出惊人的少样本学习(Few-shot)和推理能力。

  • 代表性模型:GPT-2(15亿参数)、GPT-3(1750亿参数)、GPT-4(多模态)。

适用场景:生成任务,如文本续写、对话系统、代码生成、文档摘要。

6.3 T5、BART等变体

除了BERT和GPT之外,还有许多重要的Transformer变体:

  • T5(Text-to-Text Transfer Transformer):谷歌提出,将所有NLP任务统一建模为"文本到文本"的转换问题。编码器-解码器架构,灵活且通用。

  • BART(Bidirectional and Auto-Regressive Transformers):Facebook提出,使用类似降噪自动编码器的预训练目标:随机破坏输入文本(如随机遮盖、删除、旋转),让解码器恢复原始文本。结合了BERT的双向上下文和GPT的自回归生成优势。

  • RoBERTa:BERT的优化版本,去掉了下一句预测(NSP)任务,增加训练数据和批量大小,延长训练时间,在多项基准上超越BERT。

  • XLNet:通过排列语言模型(Permutation Language Model)解决BERT中MASK token在微调时不存在的问题,实现某种程度上的"双向上下文+自回归"统一。


7. 使用场景

Transformer架构的卓越性能使其在NLP乃至更广泛的AI领域得到了全面应用。

7.1 机器翻译

这是Transformer的"出生地"和最直接的应用场景。谷歌翻译、DeepL等商用翻译系统均已基于Transformer构建。编码器理解源语言语义,解码器自回归生成目标语言译文。相较于传统的基于短语的统计机器翻译(SMT),Transformer将BLEU分数提升了数个百分点。

7.2 文本生成

  • GPT系列:用于文章写作、故事创作、代码补全、游戏对话等。

  • T5/BART:用于文本摘要、文本风格转换、句子纠错等。

  • 控制生成:通过Prompt工程或微调,可以控制生成文本的主题、风格、长度等属性。

7.3 问答系统

  • 抽取式问答:BERT在此任务上取得了超越人类的表现,通过理解问题与文本段落的关系,定位答案span。

  • 生成式问答:如ChatGPT,能够基于海量知识生成连贯的答案。

  • 检索增强生成(RAG):结合检索系统和生成模型,提高答案的准确性和可追溯性。

7.4 其他应用场景

  • 语音识别:Whisper、Speech-to-Text等模型使用Transformer编码器处理声学特征。

  • 图像分类与目标检测:ViT(Vision Transformer)将图像分割为patch序列,用标准Transformer编码器进行分类,性能超越CNN。

  • 多模态:CLIP、DALL-E、GPT-4V等多模态模型利用Transformer统一处理文本、图像等多种模态。

  • 蛋白质结构预测:AlphaFold2使用Transformer架构预测蛋白质的三维结构,被誉为科学领域的重大突破。


8. PyTorch实现代码

本节通过完整的PyTorch代码,从零实现一个简化但可运行的Transformer模型,并给出一个简单的机器翻译示例。所有代码包含详细中文注释,可直接运行。

8.1 完整Transformer模块实现

复制代码
"""
完整Transformer架构的PyTorch实现
基于论文《Attention Is All You Need》(Vaswani et al., 2017)
"""
​
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')
​
​
# ============================================================
# 1. 位置编码(Positional Encoding)
# ============================================================
class PositionalEncoding(nn.Module):
    """
    使用正弦/余弦函数生成位置编码,
    将序列中每个位置的信息注入到输入嵌入中。
    
    公式:
        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    
    这种编码方式的优势是:任意两个位置可以通过线性变换区分,
    并且能够泛化到训练时未见过的更长序列。
    """
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # 创建位置编码矩阵:[max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        # 位置索引:[max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 频率项:用于控制不同维度的波长
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) *
            (-math.log(10000.0) / d_model)
        )
        # 偶数维度用sin,奇数维度用cos
        pe[:, 0::2] = torch.sin(position * div_term)  # 维度 0, 2, 4, ...
        pe[:, 1::2] = torch.cos(position * div_term)  # 维度 1, 3, 5, ...
        # 添加batch维度:[1, max_len, d_model],便于后续与输入相加
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [batch_size, seq_len, d_model]
        将位置编码加到输入上,并应用dropout
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
​
​
# ============================================================
# 2. 缩放点积注意力(Scaled Dot-Product Attention)
# ============================================================
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor = None
) -> torch.Tensor:
    """
    缩放点积注意力公式:
        Attention(Q,K,V) = softmax(QKᵀ / √d_k) · V
    
    参数:
        Q: [batch_size, num_heads, seq_len, d_k]
        K: [batch_size, num_heads, seq_len, d_k]
        V: [batch_size, num_heads, seq_len, d_v]
        mask: [batch_size, 1, seq_len, seq_len] 或 [batch_size, 1, 1, seq_len]
    
    返回:
        输出:[batch_size, num_heads, seq_len, d_v]
        注意力权重:[batch_size, num_heads, seq_len, seq_len]
    """
    d_k = Q.size(-1)  # Key的维度
    # 计算注意力分数:[batch_size, num_heads, seq_len, seq_len]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 应用mask(如果有):将需要遮蔽的位置设为极小值
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # softmax归一化:[batch_size, num_heads, seq_len, seq_len]
    attn_weights = F.softmax(scores, dim=-1)
    # 加权求和:[batch_size, num_heads, seq_len, d_v]
    output = torch.matmul(attn_weights, V)
    
    return output, attn_weights
​
​
# ============================================================
# 3. 多头注意力层(Multi-Head Attention)
# ============================================================
class MultiHeadAttention(nn.Module):
    """
    多头注意力机制:
        1. 将Q、K、V分别投影到h个子空间
        2. 在每个子空间独立计算注意力
        3. 拼接所有头的输出并进行线性变换
    
    核心思想:不同的头可以关注不同类型的依赖关系,
    例如有的头关注语法关系,有的头关注语义关联。
    """
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度
        
        # 四个线性变换层:Q, K, V 以及输出投影
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None
    ) -> tuple:
        """
        参数:
            query/key/value: [batch_size, seq_len, d_model]
            mask: [batch_size, seq_len, seq_len] 或类似形状
        
        返回:
            output: [batch_size, seq_len, d_model]
            attn_weights: [batch_size, num_heads, seq_len, seq_len]
        """
        batch_size = query.size(0)
        
        # 线性变换并分头:[batch_size, seq_len, num_heads, d_k]
        # transpose后变为:[batch_size, num_heads, seq_len, d_k]
        Q = self.W_Q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 如果有mask,需要调整形状以适配注意力计算
        if mask is not None:
            # mask: [batch_size, 1, seq_len, seq_len] 或 [batch_size, seq_len, seq_len]
            mask = mask.unsqueeze(1)  # 扩展维度以匹配多头:[batch_size, 1, seq_len, seq_len]
        
        # 计算缩放点积注意力
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # 合并多头:[batch_size, seq_len, num_heads, d_k] -> [batch_size, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        # 最终线性投影
        output = self.W_O(attn_output)
        
        return output, attn_weights
​
​
# ============================================================
# 4. 前馈神经网络层(Position-wise Feed-Forward Network)
# ============================================================
class PositionwiseFeedForward(nn.Module):
    """
    逐位置前馈神经网络:
        FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
    
    两个线性变换之间使用ReLU激活,
    维度从 d_model -> d_ff -> d_model。
    
    注意:每个位置独立应用相同的变换,因此称为"逐位置"。
    """
    def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 先扩展维度,经过ReLU,再压缩回来
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
​
​
# ============================================================
# 5. 编码器层(Encoder Layer)
# ============================================================
class EncoderLayer(nn.Module):
    """
    Transformer编码器的一层:
        1. Multi-Head Self-Attention(自注意力)
        2. 残差连接 + 层归一化
        3. Position-wise FFN(前馈网络)
        4. 残差连接 + 层归一化
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # --- 自注意力子层 ---
        # Q=K=V=x:每个位置通过关注序列中所有位置来更新自己的表示
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        
        # --- 前馈网络子层 ---
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))
        
        return x
​
​
# ============================================================
# 6. 编码器(Encoder)
# ============================================================
class Encoder(nn.Module):
    """
    完整的Transformer编码器:
        - 输入嵌入 + 位置编码
        - N个编码器层堆叠
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        self.d_model = d_model
        
        # 源语言的嵌入层和位置编码
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # N个编码器层
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        x: [batch_size, src_seq_len] — 源语言序列(token IDs)
        返回: [batch_size, src_seq_len, d_model] — 编码后的序列表示
        """
        # 词嵌入 + 位置编码
        x = self.pos_encoding(self.dropout(self.embedding(x) * math.sqrt(self.d_model)))
        
        # 依次通过每一层编码器
        for layer in self.layers:
            x = layer(x, mask)
        
        return x
​
​
# ============================================================
# 7. 解码器层(Decoder Layer)
# ============================================================
class DecoderLayer(nn.Module):
    """
    Transformer解码器的一层:
        1. Masked Multi-Head Self-Attention(遮蔽自注意力)
        2. 残差连接 + 层归一化
        3. Cross Attention(交叉注意力,Q来自解码器,K/V来自编码器)
        4. 残差连接 + 层归一化
        5. Position-wise FFN
        6. 残差连接 + 层归一化
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
    
    def forward(
        self,
        x: torch.Tensor,           # 解码器输入
        encoder_output: torch.Tensor,  # 编码器输出
        src_mask: torch.Tensor = None, # 源序列mask
        tgt_mask: torch.Tensor = None  # 目标序列mask(用于遮蔽未来位置)
    ) -> torch.Tensor:
        # --- 第一子层:Masked自注意力 ---
        # Q=K=V=x,tgt_mask确保每个位置只能看到自己和之前的词
        _attn_out, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(_attn_out))
        
        # --- 第二子层:交叉注意力 ---
        # Q来自解码器(x),K/V来自编码器(encoder_output)
        # 这使得解码器能够"看到"源序列的全局信息
        _attn_out, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout2(_attn_out))
        
        # --- 第三子层:前馈网络 ---
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout3(ffn_out))
        
        return x
​
​
# ============================================================
# 8. 解码器(Decoder)
# ============================================================
class Decoder(nn.Module):
    """
    完整的Transformer解码器:
        - 输出嵌入 + 位置编码
        - N个解码器层堆叠
        - 最终输出线性层 + Softmax
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        self.d_model = d_model
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.fc_out = nn.Linear(d_model, vocab_size)  # 最终投影到词表维度
        self.dropout = nn.Dropout(dropout)
    
    def forward(
        self,
        x: torch.Tensor,              # 目标序列
        encoder_output: torch.Tensor, # 编码器输出
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        x: [batch_size, tgt_seq_len] — 目标序列(token IDs)
        返回: [batch_size, tgt_seq_len, vocab_size] — 每个位置的词概率分布
        """
        x = self.pos_encoding(self.dropout(self.embedding(x) * math.sqrt(self.d_model)))
        
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        # 投影到词表维度并返回logits(由调用方决定是否通过Softmax)
        return self.fc_out(x)
​
​
# ============================================================
# 9. 完整Transformer模型
# ============================================================
class Transformer(nn.Module):
    """
    完整的Transformer模型(编码器-解码器架构)
    """
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 5000
    ):
        super().__init__()
        
        self.encoder = Encoder(
            src_vocab_size, d_model, num_heads,
            num_encoder_layers, d_ff, dropout, max_len
        )
        self.decoder = Decoder(
            tgt_vocab_size, d_model, num_heads,
            num_decoder_layers, d_ff, dropout, max_len
        )
    
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        src: [batch_size, src_seq_len]
        tgt: [batch_size, tgt_seq_len]
        返回: [batch_size, tgt_seq_len, tgt_vocab_size]
        """
        # 编码源序列
        encoder_output = self.encoder(src, src_mask)
        # 解码目标序列(使用编码器输出作为交叉注意力的K/V来源)
        decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
        return decoder_output
​
​
# ============================================================
# 辅助函数:生成mask
# ============================================================
def create_padding_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """
    创建padding mask:将序列中值为pad_idx的位置标记为0(用于遮蔽)
    返回:[batch_size, 1, 1, seq_len]
    """
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)
​
​
def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """
    创建因果mask(causal mask):遮蔽未来位置
    返回:[seq_len, seq_len] — 上三角为0(遮蔽),下三角为1(可见)
    """
    # 创建一个上三角矩阵(不含对角线),值为负无穷
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
    return ~mask  # 转换为True表示可见,False表示遮蔽
​
​
def create_masks(
    src: torch.Tensor,
    tgt: torch.Tensor,
    pad_idx: int = 0
) -> tuple:
    """
    创建所有必要的mask:
    - src_mask: 遮蔽源序列中的padding位置
    - tgt_mask: 同时遮蔽padding和未来位置
    """
    # Padding mask for source
    src_mask = create_padding_mask(src, pad_idx)  # [B, 1, 1, src_len]
    
    # Padding + causal mask for target
    tgt_pad_mask = create_padding_mask(tgt, pad_idx)  # [B, 1, 1, tgt_len]
    tgt_len = tgt.size(1)
    causal_mask = create_causal_mask(tgt_len, tgt.device)  # [tgt_len, tgt_len]
    # 扩展到batch维度:[1, 1, tgt_len, tgt_len]
    tgt_mask = tgt_pad_mask & causal_mask.unsqueeze(0).unsqueeze(1)
    
    return src_mask, tgt_mask
​
​
print("=" * 60)
print("基础模块定义完成!")
print("=" * 60)

8.2 简单机器翻译示例

复制代码
"""
简单机器翻译示例:英译中(极简版本)
演示如何使用上述Transformer模块进行训练和推理
​
注意:这是一个极简示例,仅用于演示完整流程。
实际翻译系统需要大规模语料、更多epoch训练和更大模型。
"""
​
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
​
​
# ============================================================
# 1. 超参数配置
# ============================================================
BATCH_SIZE = 32
D_MODEL = 256          # 模型维度(论文原版为512,此处简化)
NUM_HEADS = 8          # 注意力头数
NUM_ENCODER_LAYERS = 3 # 编码器层数
NUM_DECODER_LAYERS = 3 # 解码器层数
D_FF = 512             # 前馈网络维度(论文原版为2048)
DROPOUT = 0.1
EPOCHS = 20
LEARNING_RATE = 0.0001
MAX_LEN = 20           # 最大序列长度
​
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备:{DEVICE}")
​
​
# ============================================================
# 2. 极简词表和数据集
# ============================================================
# 为了演示方便,我们使用一个极小的合成中英对照数据集
# 实际应用中应使用WMT、IWSLT等标准翻译数据集
​
# 英文词表(包含特殊符号)
en_tokens = {
    '<PAD>': 0, '<BOS>': 1, '<EOS>': 2,
    'hello': 3, 'world': 4, 'i': 5, 'love': 6,
    'you': 7, 'the': 8, 'cat': 9, 'dog': 10,
    'is': 11, 'a': 12, 'good': 13, 'friend': 14,
    'how': 15, 'are': 16, 'today': 17, 'fine': 18,
    'thank': 19, 'my': 20, 'name': 21, 'is': 22,
}
en_itos = {v: k for k, v in en_tokens.items()}
EN_VOCAB_SIZE = len(en_tokens)
​
# 中文词表(按字符分词,包含特殊符号)
zh_tokens = {
    '<PAD>': 0, '<BOS>': 1, '<EOS>': 2,
    '你': 3, '好': 4, '世': 5, '界': 6, '我': 7,
    '爱': 8, '你': 9, '的': 10, '猫': 11, '是': 12,
    '一': 13, '只': 14, '好': 15, '朋': 16, '友': 17,
    '怎': 18, '么': 19, '样': 20, '今': 21, '天': 22,
    '谢': 23, '谢': 24, '你': 25, '叫': 26, '什': 27,
    '么': 28
}
zh_itos = {v: k for k, v in zh_tokens.items()}
ZH_VOCAB_SIZE = len(zh_tokens)
​
​
# 简单的英中对照训练数据
training_pairs = [
    ("hello world", "你好世界"),
    ("i love you", "我爱你"),
    ("the cat is a good friend", "猫是一只好朋友"),
    ("how are you today", "你今天怎么样"),
    ("thank you my friend", "谢谢我的朋友"),
    ("my name is cat", "我叫猫"),
    ("i am fine", "我很好"),
    ("the dog is a friend", "狗是朋友"),
    ("hello i love the cat", "你好我爱猫"),
    ("world is good", "世界是好的"),
]
​
​
def tokenize_en(sentence: str, max_len: int = MAX_LEN) -> list:
    """英文分词(简单的空格分词 + padding/truncation)"""
    words = sentence.lower().split()
    tokens = [en_tokens.get(w, en_tokens['<PAD>']) for w in words]
    if len(tokens) < max_len:
        tokens += [en_tokens['<PAD>']] * (max_len - len(tokens))
    return tokens[:max_len]
​
​
def tokenize_zh(sentence: str, max_len: int = MAX_LEN) -> list:
    """中文分词(字符级 + padding/truncation)"""
    chars = list(sentence)
    tokens = [zh_tokens.get(c, zh_tokens['<PAD>']) for c in chars]
    if len(tokens) < max_len:
        tokens += [zh_tokens['<PAD>']] * (max_len - len(tokens))
    return tokens[:max_len]
​
​
class TranslationDataset(Dataset):
    """翻译数据集"""
    def __init__(self, pairs, max_len=MAX_LEN):
        self.pairs = pairs
        self.max_len = max_len
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        en_sent, zh_sent = self.pairs[idx]
        # 编码源序列(英文)
        src_tokens = tokenize_en(en_sent, self.max_len)
        # 编码目标序列(中文):输入加<BOS>,输出加<EOS>
        tgt_tokens = tokenize_zh('<BOS>' + zh_sent, self.max_len)
        # 标签序列:加<EOS>
        tgt_labels = tokenize_zh(zh_sent + '<EOS>', self.max_len)
        
        return (
            torch.tensor(src_tokens, dtype=torch.long),
            torch.tensor(tgt_tokens, dtype=torch.long),
            torch.tensor(tgt_labels, dtype=torch.long)
        )
​
​
def collate_fn(batch):
    """自定义batch整理函数"""
    src_batch = torch.stack([item[0] for item in batch])
    tgt_batch = torch.stack([item[1] for item in batch])
    tgt_labels_batch = torch.stack([item[2] for item in batch])
    return src_batch, tgt_batch, tgt_labels_batch
​
​
# ============================================================
# 3. 实例化模型
# ============================================================
model = Transformer(
    src_vocab_size=EN_VOCAB_SIZE,
    tgt_vocab_size=ZH_VOCAB_SIZE,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    d_ff=D_FF,
    dropout=DROPOUT
).to(DEVICE)
​
print(f"模型参数量:{sum(p.numel() for p in model.parameters()):,}")
​
​
# ============================================================
# 4. 损失函数和优化器
# ============================================================
PAD_IDX = 0
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # 忽略PAD位置的损失
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
​
​
# ============================================================
# 5. 训练函数
# ============================================================
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for src, tgt, tgt_labels in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)
        tgt_labels = tgt_labels.to(device)
        
        # 创建mask
        src_mask, tgt_mask = create_masks(src, tgt, PAD_IDX)
        
        # 前向传播
        # 解码器输入:tgt(不含最后一个token)
        # 标签:tgt_labels(含<EOS>)
        tgt_input = tgt[:, :-1]   # 去掉最后一个token作为输入
        tgt_labels_slice = tgt_labels[:, 1:]  # 去掉<BOS>作为标签
        
        # 创建对应的mask
        _, tgt_mask = create_masks(src, tgt_input, PAD_IDX)
        
        logits = model(src, tgt_input, src_mask, tgt_mask)  # [B, tgt_len-1, vocab]
        
        # 计算损失(跨词表维度)
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt_labels_slice.reshape(-1)
        )
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 梯度裁剪,防止梯度爆炸
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)
​
​
# ============================================================
# 6. 翻译函数(贪心解码)
# ============================================================
def translate(
    model,
    sentence: str,
    en_tokens_dict: dict,
    zh_tokens_dict: dict,
    zh_itos_dict: dict,
    max_len: int = MAX_LEN,
    device: torch.device = DEVICE
) -> str:
    """
    使用贪心解码(Greedy Decoding)进行翻译:
    每一步选择概率最高的词,直到生成<EOS>或达到最大长度。
    """
    model.eval()
    
    # 编码源序列
    src_tokens = tokenize_en(sentence, max_len)
    src_tensor = torch.tensor([src_tokens], dtype=torch.long).to(device)
    src_mask, _ = create_masks(src_tensor, src_tensor, PAD_IDX)
    
    # 编码器前向传播
    encoder_output = model.encoder(src_tensor, src_mask)
    
    # 解码:从<BOS>开始自回归生成
    tgt_tokens = [zh_tokens_dict['<BOS>']]
    
    for _ in range(max_len):
        tgt_tensor = torch.tensor([tgt_tokens], dtype=torch.long).to(device)
        _, tgt_mask = create_masks(src_tensor, tgt_tensor, PAD_IDX)
        
        # 解码器前向传播
        logits = model.decoder(tgt_tensor, encoder_output, src_mask, tgt_mask)
        
        # 取最后一个时间步的预测(下一个词)
        next_token_logits = logits[:, -1, :]  # [1, vocab_size]
        next_token_id = next_token_logits.argmax(dim=-1).item()  # 贪心选择概率最高的词
        
        # 如果遇到<EOS>,停止生成
        if next_token_id == zh_tokens_dict['<EOS>']:
            break
        
        tgt_tokens.append(next_token_id)
    
    # 将token IDs转回中文文本
    zh_chars = [zh_itos_dict.get(tid, '<UNK>') for tid in tgt_tokens[1:]]  # 去掉<BOS>
    return ''.join(zh_chars)
​
​
# ============================================================
# 7. 开始训练
# ============================================================
print("\n开始训练...")
dataset = TranslationDataset(training_pairs)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
​
for epoch in range(EPOCHS):
    avg_loss = train_epoch(model, dataloader, optimizer, criterion, DEVICE)
    print(f"Epoch {epoch + 1:02d}/{EPOCHS} | 平均损失:{avg_loss:.4f}")
    
    # 每5个epoch演示一次翻译效果
    if (epoch + 1) % 5 == 0:
        print("\n翻译示例:")
        for en_sent, _ in training_pairs[:3]:
            zh_translation = translate(
                model, en_sent, en_tokens, zh_tokens, zh_itos, MAX_LEN, DEVICE
            )
            print(f"  英文:{en_sent}")
            print(f"  中文:{zh_translation}")
            print()
​
​
# ============================================================
# 8. 最终翻译效果展示
# ============================================================
print("\n" + "=" * 60)
print("最终翻译效果:")
print("=" * 60)
for en_sent, expected_zh in training_pairs:
    predicted_zh = translate(
        model, en_sent, en_tokens, zh_tokens, zh_itos, MAX_LEN, DEVICE
    )
    print(f"输入:{en_sent}")
    print(f"期望:{expected_zh}")
    print(f"预测:{predicted_zh}")
    print("-" * 40)
​
print("\n模型训练和推理完成!")
print("注意:由于使用了极小的合成数据集,模型泛化能力有限。")
print("实际应用中请使用WMT、IWSLT等大规模翻译数据集。")

运行以上代码,你将看到一个完整的Transformer从零实现的全过程,包括:

  1. 位置编码:通过正弦/余弦函数为序列注入位置信息

  2. 多头注意力:从不同子空间并行捕获多种依赖关系

  3. 编码器:堆叠自注意力和前馈网络,全面理解源序列语义

  4. 解码器:通过掩码自注意力防止信息泄露,通过交叉注意力查询源序列

  5. 机器翻译示例:在极小数据集上验证模型的前向传播和训练流程


9. 总结与展望

Transformer架构以其优雅的设计和卓越的性能,成为深度学习领域最成功的模型之一。本文系统地梳理了Transformer从RNN困境到注意力机制突破的技术演进,详细剖析了编码器和解码器的每个核心组件,并通过PyTorch代码实现了完整可运行的模型。

回顾Transformer的成功,我们可以总结出几个核心设计原则:

  1. 全局建模能力:自注意力机制使每个位置能够直接关注序列中的任意其他位置,突破了RNN的局部感受野限制。

  2. 并行化训练:通过位置编码而非序列顺序传递信息,实现了真正的并行计算。

  3. 模块化与可扩展性:编码器/解码器层可以堆叠,多头注意力可以灵活调整,为模型规模的扩大提供了可行路径。

  4. 表示的丰富性:多头机制使模型能够在不同子空间中同时捕获多种类型的语义关系。

展望未来,Transformer架构仍在持续演进:大语言模型(LLM)将模型规模推向千亿甚至万亿参数;混合专家(MoE)架构探索了高效扩展的可能性;FlashAttention等算法持续优化注意力计算的成本;多模态Transformer正在统一视觉、语言、音频等多种模态的建模范式。理解Transformer的原理,不仅是掌握当前AI技术的钥匙,更是迎接未来更多突破的基础。

相关推荐
拾年2754 小时前
一个项目教你玩转Claude Code 常用命令
人工智能
阿里云大数据AI技术4 小时前
PAI-FA|突破 TMEM 瓶颈:FlashAttention-4 大 Head Dimension (256) 高性能算子实现与优化
人工智能
Mr数据杨4 小时前
【CanMV K210】传感器实验 MPU6050 六轴数据与四元数姿态融合
人工智能·硬件开发·canmv k210
Das14 小时前
MCP Is Dead
人工智能
测试员周周4 小时前
【Appium 系列】第13节-混合测试执行器 — API + UI 的协同执行
开发语言·人工智能·python·功能测试·ui·appium·pytest
风落无尘4 小时前
第九章《语言与理解》 完整学习资料
gpt·rnn·语言模型·transformer
莽夫搞战术4 小时前
【Google Stitch】AI原生画布重新定义设计,让想法变成可交互界面
前端·人工智能·ui
malog_4 小时前
大语言模型后训练全解析
人工智能·深度学习·机器学习·ai·语言模型
Soari4 小时前
AI Engineering from Scratch:从数学基础到智能体工程,一套 435 课的 AI 工程实战路线图
人工智能