Transformer架构学习笔记:从数学推导到工程实现与主流变体
执行摘要与目录
执行摘要(≤300字)
本文系统梳理Transformer架构,从"为何抛弃RNN/CNN"、到自注意力的矩阵形式与缩放因子推导、多头注意力的等价视角、残差与LayerNorm的训练动力学作用,再到位置编码的频域解释与相对位置扩展,并给出复杂度/参数量估算、训练与优化超参建议、以及可直接复用的PyTorch核心模块实现与逐行关键解释。最后覆盖BERT/GPT/T5、Transformer-XL与Longformer、Sparse/线性注意力等变体,结合调试与性能优化(FlashAttention、混合精度、ZeRO、梯度检查点)与典型应用案例,形成面向研究生/工程师的可落地学习笔记。
目录
本文按八个"主章节"组织,每章末给出小结与关键公式/推导要点:
- 执行摘要与目录
- 背景与动机:序列建模的瓶颈与Transformer设计目标
- 数学核心:自注意力、缩放因子、多头注意力、残差与归一化
- 位置建模:正弦位置编码的频域解释与相对位置方法谱系
- 编码器与解码器:结构、掩码机制与信息流
- 训练与优化:目标函数、学习率调度、正则化、混合精度与分布式
- 变体谱系:BERT、GPT、T5、Transformer-XL、Longformer、Sparse/线性注意力
- 工程实现、调试与性能优化:PyTorch代码、推理加速与常见问题、参考来源
背景与动机:序列建模的瓶颈与Transformer设计目标
在Transformer提出之前,主流序列到序列(seq2seq)模型多依赖RNN或CNN作为编码器/解码器骨干,并辅以"注意力机制"做跨序列对齐;但RNN的本质瓶颈在于时间步串行依赖 (难并行、长依赖梯度路径长),CNN虽可并行但要覆盖长依赖常需堆叠多层或扩张卷积,导致路径仍可能变长。Transformer的核心主张是:仅用注意力机制(尤其是自注意力)构建序列转换网络,彻底去除递归与卷积,从而提升并行性与训练效率,并在机器翻译等任务上获得更好的效果。
一个非常"工程友好"的比较方式是看每层复杂度、最少串行操作数、以及"最大路径长度"(衡量任意两位置间信息传播所需层数/步数,路径越短通常越利于学习长距离依赖)。原始Transformer论文在表格中给出典型结论:
- 自注意力:每层复杂度 (O(n^2 d)),最少串行操作 (O(1)),最大路径长度 (O(1))。
- RNN:每层复杂度 (O(n d^2)),串行操作 (O(n)),最大路径长度 (O(n))。
- CNN:每层复杂度 (O(k n d^2)),串行 (O(1)),最大路径长度约 (O(\log_k n))(扩张卷积情形)。
这解释了Transformer在中等长度序列(如常见子词级文本)上"更易并行且路径更短"的直观优势,也顺带指出其最大弱点:注意力的二次复杂度使得超长序列训练/推理成本迅速爆炸,促成后续大量Long-context与稀疏/线性注意力研究。
本节小结(带关键要点)
- Transformer的动机可以压缩为三点:并行化(串行步数从 (O(n)\to O(1)))、长依赖路径缩短(最大路径 (O(n)\to O(1)))、以及用矩阵乘法高效实现注意力。citeturn14view0turn10view0
- 代价是注意力二次复杂度 (O(n^2 d)),后续Longformer/Sparse/线性注意力的核心就是"在不显著掉点的情况下,把 (n^2) 降下来"。
数学核心:自注意力、缩放因子、多头注意力、残差与归一化
下面进入Transformer的理论与实现"主干"。本章尽量用统一符号把公式写成工程可实现的矩阵形态。
统一符号与张量形状约定
设输入序列长度为 (n),模型维度为 (d_\text{model})。输入嵌入(叠加位置编码后的输入)记为
X \\in \\mathbb{R}\^{n\\times d_\\text{model}}.
对注意力,我们引入三组投影矩阵(以自注意力为例,Q/K/V都来自同一个 (X)):
Q=XW\^Q,\\quad K=XW\^K,\\quad V=XW\^V,
其中 (WQ,WK\in \mathbb{R}^{d_\text{model}\times d_k}),(W^V\in \mathbb{R}^{d_\text{model}\times d_v})。原始论文在矩阵形式里把多查询打包为 (Q) 矩阵,同时把键和值打包为 (K,V)。
自注意力的矩阵形式推导
步骤一:相似度(logits)矩阵
S = QK\^\\top \\in \\mathbb{R}\^{n\\times n},\\qquad S_{ij}=\\langle q_i,k_j\\rangle.
步骤二:缩放 + 掩码 + Softmax 得到注意力权重
Transformer采用"缩放点积注意力"(Scaled Dot-Product Attention):
A=\\mathrm{softmax}!\\left(\\frac{QK\^\\top}{\\sqrt{d_k}} + M\\right),
其中 (M) 是掩码矩阵(无效位置加 (-\infty),实现上可用一个极小负数)。这正是原始论文给出的核心公式(Equation (1))。
步骤三:加权求和输出
O = AV \\in \\mathbb{R}\^{n\\times d_v},\\qquad o_i=\\sum_{j=1}\^{n}A_{ij}v_j.
上面三步直接对应实现:一次矩阵乘法得到 (QK^\top),一次softmax,最后乘上 (V)。
缩放因子 (1/\sqrt{d_k}) 的"为什么":方差与Softmax饱和
原始论文给出直觉:若 (q,k) 各分量独立、均值0方差1,则点积
q\\cdot k=\\sum_{i=1}\^{d_k} q_i k_i
其方差随 (d_k) 线性增长((\mathrm{Var}(q\cdot k)=d_k)),从而当 (d_k) 大时,logits幅值变大,Softmax更容易进入饱和区间导致梯度极小;因此用 (\sqrt{d_k}) 缩放以抑制该效应。citeturn8view0turn10view0
从工程角度,这相当于让 logits 的典型尺度更稳定,使训练更"线性可控",尤其在大维度/多头拼接场景下更关键。
多头注意力的定义与"等价视角"
定义
多头注意力将 (Q,K,V) 通过不同线性投影分成 (h) 个头并行计算:
\\mathrm{head}_i=\\mathrm{Attention}(QW^Q_i,KW^K_i,VW\^V_i),
\\mathrm{MHA}(Q,K,V)=\\mathrm{Concat}(\\mathrm{head}_1,\\ldots,\\mathrm{head}_h)W\^O.
这是原始论文的标准公式。
等价视角一:一次大线性层 + reshape
把所有头的投影矩阵按列拼接起来:
W^Q=\[W^Q_1\|W^Q_2\|\\cdots\|W^Q_h\]\\in \\mathbb{R}\^{d_\\text{model}\\times (h d_k)},
那么
Q_\\text{all}=XW\^Q \\in \\mathbb{R}\^{n\\times (h d_k)}
只需把最后一维 reshape 成 ((h,n,d_k)),就等价于分别计算每个头的 (Q_i)。(K,V) 同理。这个等价性在实现中非常重要:它解释了为何很多库(例如常见的PyTorch实现方式)会用一个线性层一次性生成QKV再reshape,以减少kernel launch并提高吞吐。
等价视角二:为何"总计算量近似不变"
论文在典型设置下取 (d_k=d_v=d_\text{model}/h),并指出由于每头维度变小,整体计算成本与"单头全维注意力"相近。
残差连接与LayerNorm:表达、优化与稳定性的分工
原始Transformer每个子层(注意力子层、FFN子层)都采用
\\mathrm{LayerNorm}(x+\\mathrm{Sublayer}(x)),
即残差相加后做LayerNorm(常称"Post-LN"布局)。
这背后有两层逻辑:
- 残差连接提供"恒等映射通道",缓解深层网络优化困难(这一思想源自残差网络的经验与理论直觉)。
- LayerNorm在每个样本的特征维上做归一化,减小内部协变量偏移并稳定训练;与BatchNorm不同,它不依赖batch统计量,因此更适用于序列模型与变长输入。
更进一步,后续研究指出:LayerNorm摆放位置会显著影响训练稳定性与"是否强依赖warmup"。例如关于Transformer中LayerNorm位置与warmup必要性的理论/实证分析表明,把LayerNorm放入残差分支内部("Pre-LN")能使初始化附近梯度更"良性",从而可以减少甚至去除warmup依赖。
前馈网络(FFN)的数学形式与参数量主来源
每层包含一个位置前馈网络(Position-wise FFN),对序列每个位置独立同分布地应用:
\\mathrm{FFN}(x)=W_2,\\sigma(W_1 x + b_1)+b_2,
原始论文使用两层线性加ReLU。citeturn8view0
这块通常贡献大量参数:若隐藏维 (d_\mathrm{ff})(常取 (4d_\text{model}) 一类比例),则FFN参数量约为
#\\theta_\\mathrm{FFN}\\approx 2 d_\\text{model} d_\\mathrm{ff} \\quad (\\text{不计bias}),
而注意力投影约为
#\\theta_\\mathrm{MHA}\\approx 4 d_\\text{model}\^2 \\quad (\\text{含}W^Q,W^K,W^V,W^O).
复杂度与参数量估算:一套"可手算"的公式
时间复杂度(单层,自注意力为主)
- 线性投影:(XW) 是 (O(n d_\text{model}^2))。
- 注意力矩阵乘:(QK^\top) 是 (O(n^2 d_k)),再乘 (AV) 是 (O(n^2 d_v))。
在常用 (d_k=d_v=d_\text{model}/h) 且 (h) 个头并行时,主项约为O(n\^2 d_\\text{model}) + O(n d_\\text{model}\^2).
对中长序列((n) 上千)时,(n^2 d_\text{model}) 绝对主导;这也解释了为何长上下文研究主要围绕注意力矩阵展开。
空间复杂度(训练时的关键瓶颈)
朴素实现往往显式存储注意力权重 (A\in\mathbb{R}^{h\times n\times n}),空间 (O(h n^2))。FlashAttention类工作指出,标准注意力在长序列上不仅算力二次,且会被显存读写(HBM↔SRAM搬运)拖慢,因此提出IO-aware的exact attention以减少中间矩阵存储与读写。
与RNN/CNN的差异(来自原论文的可操作结论)
原论文给出清晰对照:自注意力层在串行操作数上是 (O(1)),而RNN是 (O(n));最大路径长度自注意力是 (O(1)),RNN是 (O(n));复杂度方面,自注意力是 (O(n^2 d)),RNN是 (O(n d^2)),卷积是 (O(k n d^2))。
本节小结(带关键公式/推导要点)
- 自注意力矩阵形式:(\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(QK^\top/\sqrt{d_k}+M)V)。citeturn10view0
- 缩放因子来自点积方差随 (d_k) 增长导致Softmax饱和;缩放使logits尺度更稳定。citeturn10view0
- 多头注意力等价于"一次大投影 + reshape + 分头注意力 + concat + 输出投影",并在 (d_k=d_\text{model}/h) 下保持总体计算量近似不变。citeturn8view0
- 残差 + LayerNorm负责可训练性与稳定性;LayerNorm摆放位置(Post-LN/Pre-LN)与是否需要强warmup密切相关。
位置建模:正弦位置编码的频域解释与相对位置方法谱系
注意力本身对输入顺序是置换不变 的:若你只把token当作集合,注意力无法区分"第1个词"和"第10个词"。原始Transformer因此必须向输入注入位置信息,并提出两类选择:可学习位置嵌入 或固定正弦位置编码,论文最终选择正弦版本以期具备长度外推能力。citeturn14view0
正弦位置编码公式与工程构造
原论文给出位置编码(Positional Encoding, PE):
PE_{(pos,2i)}=\\sin\\left(pos/10000\^{2i/d_\\text{model}}\\right), \\quad PE_{(pos,2i+1)}=\\cos\\left(pos/10000\^{2i/d_\\text{model}}\\right).
并将其与词嵌入逐元素相加(维度一致便于相加)。citeturn14view0
对应的工程实现通常是:预先生成一个 ((\text{max_len}, d_\text{model})) 的表,训练时按序列长度切片加到embedding上;推理时可通过offset支持KV缓存中的"位置偏移"。
频域解释:PE本质是"多频率Fourier特征"
对每个频率 (\omega_i=10000^{-2i/d_\text{model}}),位置 (pos) 被映射为二维向量:
\[\\sin(\\omega_i pos),\\ \\cos(\\omega_i pos)\].
这就是典型的Fourier特征映射(用不同频率的正弦/余弦基展开输入坐标),可视为把离散位置嵌入到一个"多尺度频率空间"。在更广泛的机器学习文献中,Fourier特征用于缓解网络对高频信号学习缓慢的"谱偏置",并提升对高频结构的表达能力。citeturn6search1turn6search22
在Transformer语境下,这带来两个关键性质:
- 多尺度相对距离敏感性:不同 (\omega_i) 对不同尺度的相对偏移更敏感(低频捕捉粗粒度、 高频捕捉细粒度)。
- 线性可表达的相对位移:正弦/余弦满足相位叠加恒等式,使得"位移 (k)"可通过一个线性变换作用在 ([ \sin(\omega pos),\cos(\omega pos)]) 上,这正是原论文所强调的"固定offset的PE可由线性函数表示"。citeturn14view0
"相对位置可线性表示"的推导步骤:从三角恒等式到线性变换
固定某一频率 (\omega),记
u(pos)=\\begin{bmatrix}\\sin(\\omega pos)\\ \\cos(\\omega pos)\\end{bmatrix}.
利用
\\sin(a+b)=\\sin a\\cos b+\\cos a\\sin b, \\quad \\cos(a+b)=\\cos a\\cos b-\\sin a\\sin b,
可得
u(pos+k)= \\begin{bmatrix}\\sin(\\omega(pos+k))\\ \\cos(\\omega(pos+k))\\end{bmatrix} \\underbrace{\\begin{bmatrix}\\cos(\\omega k)\&\\sin(\\omega k)\\ -\\sin(\\omega k)\&\\cos(\\omega k)\\end{bmatrix}}*{R(\\omega k)} \\begin{bmatrix}\\sin(\\omega pos)\\ \\cos(\\omega pos)\\end{bmatrix}.
也就是说,对每个频率分量,位置平移对应一个二维旋转矩阵 (R(\omega k)) 的线性作用;把所有频率对拼起来,就是一个分块对角的线性变换。这给出原论文"对任何固定offset (k),(PE*{pos+k}) 可表示为 (PE_{pos}) 的线性函数"的具体数学原因。citeturn14view0
从注意力角度看,这意味着模型可以更容易地学出"仅依赖相对距离"的模式:相对距离 (k) 在不同频率子空间体现为不同角度旋转,从而可被线性层组合出来。
相对位置与长上下文:从Shaw到Transformer-XL、RoPE与ALiBi
固定正弦PE是"把绝对位置加到输入表示上"。但大量任务更需要显式建模相对位置(距离、方向、局部邻域等)。经典相对位置做法之一是把相对距离的表示引入注意力机制,使自注意力能够直接考虑"token间距离"的表征。citeturn6search2
面向长上下文/长度外推,若干代表性路线包括:
- Transformer-XL提出"分段复用的记忆机制(segment-level recurrence)+ 新的位置信息处理",以突破固定上下文长度限制,并报告在评测时可显著加速(其摘要中强调了超长依赖与评测速度方面收益)。citeturn0search2
- RoPE(Rotary Position Embedding)用旋转矩阵把位置信息融入Q/K,从形式上把绝对位置编码为旋转,同时在注意力里自然体现相对位移关系,并强调长度灵活性等性质。citeturn4search2
- ALiBi通过在注意力logits中加入随距离线性变化的偏置,实现"train short, test long"的长度外推,并声称几乎不增加参数且改动少。citeturn4search3turn4search7
这些方法与"稀疏/线性注意力"往往是互补的:位置编码解决"如何表达位置/距离",而稀疏/线性注意力解决"如何把 (n^2) 算法降下来"。
本节小结(带关键公式/推导要点)
- 正弦PE的核心公式与工程做法来自原始Transformer;其本质是多频率Fourier特征映射。citeturn14view0turn6search1
- 利用三角恒等式可证明:对固定偏移 (k),每个频率对上的 ([\sin,\cos]) 发生一个二维旋转,从而 (PE_{pos+k}) 是 (PE_{pos}) 的线性变换;这解释了论文对"相对位置易学"的动机表述。citeturn14view0
- 相对位置路线代表:Shaw相对位置表示、Transformer-XL的相对式位置方案、RoPE旋转位置编码、ALiBi线性偏置。citeturn6search2turn0search2turn4search2turn4search3
编码器与解码器:结构、掩码机制与信息流
Transformer在宏观上是标准encoder-decoder堆栈。原论文给出:编码器由 (N=6) 个相同层堆叠,每层包含"多头自注意力 + 位置前馈网络";解码器同样 (N=6),但每层有三个子层:掩码自注意力(防止看见未来)+ 编码器-解码器注意力(跨注意力)+ FFN,且每个子层都配残差连接与LayerNorm。citeturn8view3
下面用示意图把信息流画清楚(图中"Add&Norm"表示残差相加后LayerNorm的Post-LN范式;工程上也常用Pre-LN)。citeturn8view3turn6search21
解码器层
MHA 掩码自注意力
MHA 交叉注意力
FFN
编码器层
MHA 自注意力
FFN
输入tokens + 位置编码
编码器层 × N
编码器输出 Memory
输出tokens右移 + 位置编码
解码器层 × N
线性层 + Softmax
输出分布/生成tokens
掩码机制:Padding mask 与 Causal mask 的区分
在实现中至少有两类掩码:
- Padding mask:屏蔽pad位置,使其不作为"被关注的key/value"。
- Causal mask(上三角mask):用于自回归解码,保证位置 (i) 只能看见 (\le i) 的历史token;原论文明确指出解码器自注意力需要这种mask以避免"看到未来"。citeturn8view3turn10view0
工程上常用的做法是在注意力logits上加一个矩阵 (M):合法位置加0,不合法位置加一个极大负数,从而softmax后近似为0。citeturn10view0
交叉注意力:解码器"查询",编码器"键/值"
交叉注意力(encoder-decoder attention)中:
Q \\leftarrow \\text{decoder hidden}, \\quad K,V \\leftarrow \\text{encoder output}.
这使得解码器每个位置都能"读取"源序列任意位置的编码结果,类比传统seq2seq中的对齐机制。citeturn8view0
本节小结(带关键要点)
- 结构层面:编码器每层2子层、解码器每层3子层,并配残差与LayerNorm;解码器自注意力必须用因果mask。citeturn8view3turn10view0
- 信息流层面:自注意力负责同一序列内部交互,交叉注意力负责源-目标交互。citeturn8view0
训练与优化:目标函数、学习率调度、正则化、混合精度与分布式
要把Transformer"训得又快又稳",训练策略和工程细节几乎与模型结构同等重要。本节把原论文与后续大模型训练常用要点合并成一套可执行建议。
目标函数与正则化:交叉熵、Label Smoothing 与 Dropout
在机器翻译等分类式token预测任务中,标准目标是token级交叉熵。原始论文在训练中明确使用三类正则化,其中最具代表性的是:
- Residual dropout:对每个子层输出做dropout,再与残差相加并归一化;并对embedding与位置编码之和也做dropout(base模型中 (P_\text{drop}=0.1))。citeturn9view0
- Label smoothing:使用 (\epsilon_{ls}=0.1);论文描述其会"伤害困惑度但提升准确率与BLEU"。citeturn9view0
上述两点在今天仍是默认强基线:dropout控制过拟合与共适应,label smoothing让模型输出分布不过度尖锐,通常对翻译/分类等任务更稳。
学习率调度与warmup:Noam schedule 的公式级理解
原始论文使用entity["organization","PyTorch","deep learning framework"]开发者也熟悉的Adam优化器,但关键在于其学习率调度(俗称"Noam schedule"):
\\mathrm{lrate}=d_\\text{model}\^{-0.5}\\cdot \\min(\\mathrm{step}\^{-0.5},\\ \\mathrm{step}\\cdot \\mathrm{warmup}\^{-1.5}),
并使用 (\mathrm{warmup}=4000)。同时Adam超参取 (\beta_1=0.9,\beta_2=0.98,\epsilon=10^{-9})。citeturn10view1turn9view0turn2search1
这个调度可分段解释:
- warmup期:(\mathrm{lrate}\propto \mathrm{step}) 线性升高,避免一开始更新过猛导致训练不稳定;
- 衰减期:(\mathrm{lrate}\propto \mathrm{step}^{-0.5}) 缓慢下降,保持后期收敛。citeturn10view1turn6search21
如果你采用Pre-LN结构,warmup的重要性可能下降,但在大规模训练中warmup仍常作为稳健默认。citeturn6search21turn6search6
常用超参建议:一套"从小模型到大模型"的可迁移经验
面向有深度学习基础的读者,这里给出偏工程可用的建议(你可视作强基线起点,再结合任务调参):
- 结构:(d_\text{model}\in{256,512,768,1024}),头数 (h\in{4,8,12,16}),确保 (d_\text{model}%\ h=0)。
- FFN维度:常取 (d_\mathrm{ff}\approx 4d_\text{model}) 起步(这与原论文典型配置一致的设计哲学相符)。citeturn8view0
- 正则:dropout 0.1 是强基线;label smoothing 在翻译/分类任务常有效。citeturn9view0
- 梯度裁剪:在长序列或大batch训练中可用global norm裁剪(如1.0或0.5)抑制梯度爆炸;尤其RNN式问题更常见,但Transformer在极深/大LR时也可能受益(属于通用稳定性技巧)。
- 权重衰减:对大规模预训练常用AdamW风格weight decay;例如BERT预训练使用L2 weight decay 0.01,并配合warmup与线性衰减。citeturn12view1turn11view2
预训练目标与训练设定:以BERT为代表的"任务设计影响训练细节"
BERT的代表性在于:它用Transformer编码器做双向上下文建模,并采用Masked LM(MLM)与Next Sentence Prediction(NSP)联合预训练。其关键细节包括:随机mask约15%的WordPiece token,并采用"80%替换为[MASK]、10%随机词、10%保持不变"的策略以缓解预训练-微调不一致;NSP中50%为真实下一句、50%为随机句。citeturn12view0turn12view1
这些细节说明:Transformer训练并非只有"结构",数据构造与目标函数会直接改变最优学习率、训练步数、以及是否需要额外稳定化手段。citeturn11view0turn11view2
混合精度与分布式:把训练从"可跑"变为"可扩展"
当模型与序列长度上来后,瓶颈往往不在算力而在显存与通信。
- 混合精度训练(FP16/BF16):经典工作表明可用半精度存储权重/激活/梯度,并配合loss scaling等策略,在不牺牲精度的情况下减少显存与提升吞吐。citeturn3search0turn3search4
- ZeRO与大模型内存优化:ZeRO通过拆分/去冗余优化器状态、梯度与参数分片,显著降低数据并行下的内存冗余,并在论文中强调可扩展到百亿乃至更大规模。citeturn3search1
- 梯度检查点(activation checkpointing):通过丢弃部分中间激活、反传时重算,以计算换显存,论文给出可将训练内存成本降为次线性量级的系统方法。citeturn3search2
本节小结(带关键公式/推导要点)
- 核心"可复用公式":Noam学习率调度 (\mathrm{lrate}=d_\text{model}{-0.5}\min(\mathrm{step}{-0.5},\mathrm{step}\cdot\mathrm{warmup}^{-1.5})),配Adam ((\beta_1,\beta_2,\epsilon)=(0.9,0.98,10^{-9}))。citeturn10view1
- 原始Transformer的强基线正则:dropout 0.1 + label smoothing 0.1。citeturn9view0
- 大规模训练的三件套:混合精度、ZeRO类状态分片、梯度检查点。citeturn3search0turn3search1turn3search2
变体谱系:BERT、GPT、T5、Transformer-XL、Longformer、Sparse/线性注意力
Transformer的"核心积木"稳定后,主流变体主要沿着三条轴演化:
- 架构轴:encoder-only / decoder-only / encoder-decoder
- 上下文轴:固定长度 / 记忆机制 / 稀疏或线性注意力
- 训练目标轴:自回归LM / MLM / 去噪式seq2seq / 多任务统一格式
先给出一个"家谱式"示意,帮助你把模型放进框架里:
Transformer核心模块
Encoder-only
Decoder-only
Encoder-Decoder
长上下文与高效注意力
BERT / RoBERTa风格
GPT-2/3风格
T5风格
Transformer-XL
Longformer/LED
Sparse Transformer
Reformer
Linformer
Performer
三大主线模型的"结构本质"
BERT(Encoder-only)
BERT以Transformer编码器为骨干,用MLM实现深层双向表示,并辅以NSP学习句对关系;其论文图示强调与"仅左到右"的GPT式预训练差异。citeturn12view1turn12view0
GPT系列(Decoder-only)
GPT-2技术报告明确指出其模型是大规模Transformer语言模型,并强调"模型容量对零样本迁移很关键",其最大模型达1.5B参数。citeturn7view2turn1search0
GPT-3进一步展示随着规模扩大,few-shot/one-shot/zero-shot(不做梯度更新,仅靠上下文示例)能力显著增强,并报告其自回归模型规模达到175B参数。citeturn7view3turn1search9
T5(Encoder-Decoder)
T5的核心思想是把所有NLP任务统一为text-to-text格式,在统一框架下系统比较预训练目标、数据与迁移策略,并通过规模化获得强性能。citeturn7view4turn1search2
长上下文与高效注意力:Transformer的"第二战场"
长序列瓶颈来自自注意力的二次复杂度。代表性解决方案包括:
- Transformer-XL:通过分段复用记忆(缓存历史隐藏状态)与新位置方案,突破固定上下文长度限制,并强调评测时速度优势。citeturn0search2
- Longformer:提出线性复杂度注意力模式(局部窗口注意力 + 任务相关的全局注意力),用于处理成千上万token的长文档,并提出LED以支持长文档生成式seq2seq。citeturn0search3turn0search23
- Sparse Transformer:通过稀疏注意力模式让模型能处理"数万步长"的序列。citeturn1search15
- Reformer:用局部敏感哈希(LSH)将注意力复杂度降到 (O(L\log L)) 等级,并用可逆残差层减少激活存储。citeturn5search0
- Linformer:以低秩近似解释注意力矩阵可压缩,从而把时间/空间复杂度降为 (O(n))。citeturn5search1
- Performer:用随机特征近似softmax注意力核(FAVOR+),在不依赖稀疏/低秩先验下实现线性复杂度并给出理论保证。citeturn5search2
变体对比表:特点、优缺点与适用场景
下表侧重"你做系统设计/选型时最关心的差异",表后附主要信息来源。
| 变体/家族 | 结构范式 | 核心改动点 | 复杂度侧重点 | 优点 | 局限 | 常见场景 |
|---|---|---|---|---|---|---|
| BERT | Encoder-only | MLM(+NSP)预训练,双向表示 | 仍为 (O(n^2)) 注意力 | 强理解/表征,微调范式成熟 | 生成能力弱(非自回归) | 分类、抽取式QA、匹配检索 |
| GPT-2/3 | Decoder-only | 自回归LM,因果mask | 推理依赖KV cache,注意力仍二次 | 强生成与in-context学习(随规模提高) | 长上下文成本高,生成推理串行 | 对话生成、写作、代码生成 |
| T5 | Encoder-Decoder | 统一text-to-text,多任务 | 双塔交互(交叉注意力)成本 | 统一范式适合seq2seq任务 | 结构更复杂、推理开销更高 | 翻译、摘要、问答生成 |
| Transformer-XL | Decoder为主(LM) | 记忆缓存+新位置方案 | 通过缓存扩展有效上下文 | 长依赖更强、评测可加速 | 训练/实现更复杂 | 长文本语言建模 |
| Longformer/LED | Encoder或Enc-Dec | 局部窗口+全局注意力 | 注意力近线性 (O(n)) | 长文档可用、替换标准注意力 | 注意力模式受限,需设计global token | 长文档分类、长摘要 |
| Sparse Transformer | 多种 | 稀疏注意力图 | 比全注意力更省 | 可扩到数万token级 | 稀疏模式设计与质量权衡 | 超长序列建模 |
| Reformer | 多种 | LSH注意力+可逆层 | (O(L\log L))+省激活 | 长序列更高效 | 近似带来复杂性与潜在掉点 | 长序列、资源受限训练 |
| Linformer | 多种 | 低秩投影近似 | 线性 (O(n)) | 理论直觉清晰、计算省 | 近似质量依赖秩假设 | 中长序列加速 |
| Performer | 多种 | 随机特征核近似softmax | 线性 (O(n)) | 有理论误差界、兼容Transformer | 随机近似带方差与实现细节 | 超长序列、核化注意力 |
表中有关BERT的MLM/NSP细节、GPT-2/3规模与训练范式、T5统一text-to-text框架、Transformer-XL长上下文机制、Longformer线性注意力设计,以及Sparse/Reformer/Linformer/Performer的复杂度与核心思想,分别综合自对应原论文/技术报告摘要与正文描述。citeturn12view0turn7view2turn7view3turn7view4turn0search2turn0search3turn1search15turn5search0turn5search1turn5search2
本节小结(带关键要点)
- 三大结构范式:Encoder-only偏理解、Decoder-only偏生成、Encoder-Decoder偏条件生成/转换。citeturn12view1turn7view2turn7view4
- 长上下文变体的两大方向:改变注意力图(稀疏/线性/近似)与改变"上下文使用方式"(缓存记忆/相对位置)。citeturn0search2turn0search3turn5search2
工程实现、调试与性能优化:PyTorch代码、推理加速与常见问题、参考来源
本章给出:
- NumPy层面的最小注意力实现(帮助你把矩阵公式落到代码);
- 一段"可直接拷贝用"的PyTorch实现:位置编码、缩放点积注意力、多头自注意力、TransformerBlock;
- 调试与性能优化路线(从小模型到大模型、从训练到推理)。
NumPy关键代码片段:用最少代码验证矩阵公式
python
import numpy as np
def softmax(x, axis=-1):
x = x - np.max(x, axis=axis, keepdims=True) # 数值稳定
e = np.exp(x)
return e / np.sum(e, axis=axis, keepdims=True)
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: (n, d_k), K: (n, d_k), V: (n, d_v)
mask: (n, n), 合法位置为0,不合法位置为 -1e9 或 -np.inf
"""
d_k = Q.shape[-1]
scores = (Q @ K.T) / np.sqrt(d_k) # 对应 QK^T / sqrt(d_k)
if mask is not None:
scores = scores + mask
A = softmax(scores, axis=-1) # softmax 得到注意力权重
O = A @ V # 输出 = A V
return O, A
这段代码与公式 (\mathrm{softmax}(QK^\top/\sqrt{d_k}+M)V) 一一对应,可用于做"单元测试基准"。citeturn10view0
PyTorch完整核心模块实现:位置编码 + 多头自注意力 + TransformerBlock
下面代码以"教学清晰 + 工程可用"为目标:
- 采用Pre-LN(更稳)或Post-LN(复现原论文)均可切换;
- 支持causal mask与padding mask的logits加法掩码;
- 返回注意力权重便于可视化/调试(实际训练可关闭以省显存)。
python
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalPositionalEncoding(nn.Module):
"""
固定正弦位置编码:PE(pos,2i)=sin(pos/10000^(2i/d_model)),
PE(pos,2i+1)=cos(pos/10000^(2i/d_model)).
"""
def __init__(self, d_model: int, max_len: int = 4096):
super().__init__()
pe = torch.zeros(max_len, d_model) # (max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()# (max_len, 1)
# div_term 对应 1 / 10000^(2i/d_model),用 exp/log 更稳定
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
) # (d_model/2,)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维
# register_buffer:随模型搬到GPU,但不作为可训练参数
self.register_buffer("pe", pe, persistent=False)
def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
"""
x: (B, T, d_model)
offset: 用于KV cache/分段推理时的位置偏移
"""
B, T, E = x.shape
return x + self.pe[offset: offset + T].unsqueeze(0) # (1,T,E) 广播到 batch
def make_causal_mask(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
"""
生成加法掩码:上三角为 -inf,其他为 0。
形状: (1, 1, T, T),便于广播到 (B, H, T, T)
"""
mask = torch.full((T, T), float("-inf"), device=device, dtype=dtype)
mask = torch.triu(mask, diagonal=1) # j>i 的位置为 -inf
return mask.view(1, 1, T, T)
class ScaledDotProductAttention(nn.Module):
def __init__(self, dropout: float = 0.0):
super().__init__()
self.dropout = dropout
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
return_attn: bool = False,
):
"""
q,k,v: (B, H, T, D)
attn_mask: 可加到 logits 上的 mask,形状可广播到 (B,H,T,T)
"""
B, H, T, D = q.shape
scale = 1.0 / math.sqrt(D)
# logits: (B,H,T,T)
logits = torch.matmul(q, k.transpose(-2, -1)) * scale
# causal mask
if is_causal:
logits = logits + make_causal_mask(T, logits.device, logits.dtype)
# padding/自定义 mask(例如屏蔽pad token或局部窗口)
if attn_mask is not None:
logits = logits + attn_mask
attn = F.softmax(logits, dim=-1) # (B,H,T,T)
attn = F.dropout(attn, p=self.dropout, training=self.training)
out = torch.matmul(attn, v) # (B,H,T,D)
if return_attn:
return out, attn
return out, None
class MultiHeadAttention(nn.Module):
"""
通用多头注意力:支持 self-attn 与 cross-attn(q 来自 x,k/v 来自 kv)。
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0, bias: bool = True):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_proj = nn.Linear(d_model, d_model, bias=bias)
self.k_proj = nn.Linear(d_model, d_model, bias=bias)
self.v_proj = nn.Linear(d_model, d_model, bias=bias)
self.out_proj = nn.Linear(d_model, d_model, bias=bias)
self.attn_core = ScaledDotProductAttention(dropout=dropout)
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
(B,T,E) -> (B,H,T,D)
"""
B, T, E = x.shape
x = x.view(B, T, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
(B,H,T,D) -> (B,T,E)
"""
B, H, T, D = x.shape
x = x.transpose(1, 2).contiguous().view(B, T, H * D)
return x
def forward(
self,
x: torch.Tensor,
kv: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
return_attn: bool = False,
):
"""
x: (B,T,E) 作为 query 来源
kv: (B,S,E) 作为 key/value 来源(None 表示 self-attn)
"""
if kv is None:
kv = x
q = self._split_heads(self.q_proj(x)) # (B,H,T,D)
k = self._split_heads(self.k_proj(kv)) # (B,H,S,D)
v = self._split_heads(self.v_proj(kv)) # (B,H,S,D)
out, attn = self.attn_core(q, k, v, attn_mask=attn_mask, is_causal=is_causal, return_attn=return_attn)
out = self._merge_heads(out) # (B,T,E)
out = self.out_proj(out) # (B,T,E)
return out, attn
class FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.0, activation: str = "gelu"):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.activation == "relu":
x = F.relu(self.fc1(x))
else:
x = F.gelu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
@dataclass
class TransformerBlockConfig:
d_model: int = 512
num_heads: int = 8
d_ff: int = 2048
dropout: float = 0.1
prenorm: bool = True
class TransformerBlock(nn.Module):
"""
一个标准Transformer块(不区分encoder/decoder),只实现 self-attn + FFN。
decoder 的 cross-attn 可再加一个 MHA 子层。
"""
def __init__(self, cfg: TransformerBlockConfig):
super().__init__()
self.cfg = cfg
self.ln1 = nn.LayerNorm(cfg.d_model)
self.ln2 = nn.LayerNorm(cfg.d_model)
self.attn = MultiHeadAttention(cfg.d_model, cfg.num_heads, dropout=cfg.dropout)
self.ffn = FeedForward(cfg.d_model, cfg.d_ff, dropout=cfg.dropout)
self.dropout = nn.Dropout(cfg.dropout)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
return_attn: bool = False,
):
if self.cfg.prenorm:
# Pre-LN: x <- x + Attn(LN(x))
h = self.ln1(x)
attn_out, attn = self.attn(h, kv=None, attn_mask=attn_mask, is_causal=is_causal, return_attn=return_attn)
x = x + self.dropout(attn_out)
# x <- x + FFN(LN(x))
h = self.ln2(x)
x = x + self.dropout(self.ffn(h))
else:
# Post-LN: LN(x + Sublayer(x))
attn_out, attn = self.attn(x, kv=None, attn_mask=attn_mask, is_causal=is_causal, return_attn=return_attn)
x = self.ln1(x + self.dropout(attn_out))
x = self.ln2(x + self.dropout(self.ffn(x)))
return x, attn
关键代码解释(逐行"抓重点")
- 位置编码:
div_term = exp(arange(0,d_model,2)*(-log(10000)/d_model))对应 (10000^{-2i/d}) 的稳定实现;register_buffer确保PE随设备迁移但不参与训练更新。citeturn14view0 - 缩放注意力:
scale = 1/sqrt(D)与logits = q @ k^T * scale对应 (QK^\top/\sqrt{d_k});make_causal_mask构造上三角 (-\infty) mask,对应解码器"不能看未来"。citeturn10view0turn8view3 - 多头reshape:
_split_heads把 ((B,T,E)) 变为 ((B,H,T,D)),本质就是"等价视角一"中的一次大投影后reshape。 - Pre-LN/Post-LN切换:
prenorm=True时先LN再子层再残差,有更稳定梯度行为的理论/实证支撑;prenorm=False更贴近原始"LayerNorm(x + Sublayer(x))"。citeturn6search21turn8view3
利用框架优化:优先使用库内优化注意力与FlashAttention
若你追求性能而非教学透明度,建议优先使用框架提供的优化实现:
- entity["organization","PyTorch","deep learning framework"]的
nn.MultiheadAttention文档明确提到:在可能时会使用优化过的scaled_dot_product_attention()实现,并具备推理fastpath等优化条件。citeturn3search3 - FlashAttention提出IO-aware的exact attention,通过tiling减少HBM读写、避免存储大 (n\times n) 中间矩阵,从而在长序列上实现显著加速与省显存;FlashAttention-2进一步改进并行与work partition以提高GPU利用率。citeturn2search3turn5search3turn2search27
调试路线:从小模型开始,把"形状/掩码/数值"一步步排干净
在实践中,Transformer最常见bug集中在三类:张量形状、mask语义、数值稳定性。一个高成功率路线是:
- 从极小配置起步:例如 (d_\text{model}=64)、heads=4、层数=2、序列长度=16,先让loss下降且梯度正常。
- 对齐基准实现:用上面的NumPy版本对照PyTorch版本(同权重、关掉dropout、固定输入),检查输出误差是否仅来自浮点误差。
- 显式检查mask:打印/可视化attention矩阵,确认因果mask确实是严格上三角屏蔽;padding mask不会让pad位置获得权重。citeturn10view0turn8view3
- 数值稳定性 :对softmax输入做
x - max(x)(NumPy示例已给),在FP16下尤其关键;混合精度训练往往需要loss scaling以避免下溢。citeturn3search0turn3search4
从小模型迁移到大模型:训练与推理的关键瓶颈变化
当模型增大后,你会发现瓶颈从"理论公式"变成"系统工程":
- 训练侧:显存瓶颈主要来自激活与注意力中间结果;可用梯度检查点减少激活存储,用ZeRO分片优化器/梯度/参数状态,用混合精度减少存储与提升吞吐。citeturn3search2turn3search1turn3search0
- 推理侧:自回归生成会引入KV cache(每步追加缓存,避免重复计算历史K/V);长上下文时KV cache带宽与访存成为瓶颈,FlashAttention类工作及其社区实现常被用于缓解。citeturn2search3turn5search3turn0search2
应用场景与案例研究:从翻译、理解到长文档与生成
- 机器翻译:原始Transformer在WMT 2014英德/英法任务上报告了强结果,并强调在8块entity["company","NVIDIA","gpu company"] P100 GPU上训练成本与时间优势;其Table中也展示了相对以往系统的训练成本对比。citeturn9view0turn10view1
- 语言理解:BERT论文强调其通过MLM实现深度双向表示,并在多任务上达到强性能;其训练细节(15% mask、NSP构造)也说明"数据与目标"对效果关键。citeturn12view0turn11view1
- 大模型生成与in-context学习:GPT-2报告展示零样本迁移与规模效应;GPT-3进一步系统化few-shot/one-shot/zero-shot评测并强调不做梯度更新的in-context范式。citeturn7view2turn7view3
- 长文档任务:Longformer明确指出标准自注意力二次复杂度限制长序列处理,并提出线性注意力模式(局部+全局)与LED用于长文档生成。citeturn0search3turn0search23
常见问题与答疑:围绕"最容易卡住的点"
为什么自注意力一定要除以 (\sqrt{d_k})?
若不缩放,点积方差随 (d_k) 增大导致logits幅值变大,softmax饱和区梯度变小,训练更不稳定;缩放让尺度更可控。citeturn10view0
多头注意力到底在"理论上"做了什么?
在数学上等价于把表示映射到多个子空间并行做注意力,再拼接回去;工程上等价于"一次线性层生成Q/K/V,再reshape成多头"。原论文强调其好处是让模型能在不同位置、不同表征子空间联合关注信息。citeturn8view0
Post-LN和Pre-LN怎么选?
Post-LN更贴近原始"LayerNorm(x + Sublayer(x))";Pre-LN常被认为更稳定、对warmup依赖更弱。若你训练很深的模型或遇到loss spike,优先尝试Pre-LN并重新调学习率/调度。citeturn6search21turn8view3
为什么长上下文这么难?是不是把注意力改成线性就完了?
长上下文难点不仅是算力复杂度,还包括显存读写、KV cache带宽、以及稀疏/近似带来的质量权衡。Longformer用局部+全局模式实现线性;Linformer/Performer用低秩或随机特征近似;FlashAttention强调减少IO而非近似(仍是exact)。这些路线互补且各有适用条件。
如何快速找到高质量中文资料与实现?
《动手学深度学习》中文站覆盖注意力机制、自注意力与位置编码、Transformer结构并配套代码;entity["company","Hugging Face","ai company"]提供Transformers中文文档与课程,方便直接上手预训练模型训练与推理。