2026大模型面试圣经(1):Transformer全解析 | 从Self-Attention到Multi-Head,一文通关Transformer面试
定位:本章是整个面试圣经的基石。Transformer是所有大模型的"操作系统",不搞懂它后面全白搭。
目标:看完本章,你能从零推导Attention公式,手写Multi-Head Attention,画出完整架构图,回答任何Transformer相关面试题。
模块一:Self-Attention机制 | 注意力的本质
1.1 核心概念
什么是Attention?
一句话:Attention就是"加权求和"------对输入序列中的每个位置,根据"相关性"分配不同的权重,然后做加权聚合。
传统RNN/LSTM处理序列时,信息必须沿着时间步一步步传递,长距离信息经过多次传递后会衰减甚至消失。Attention机制让任意两个位置之间可以"直接对话",路径长度为O(1)。
Self-Attention vs Cross-Attention
| 特性 | Self-Attention | Cross-Attention |
|---|---|---|
| Q/K/V来源 | 全部来自同一序列 | Q来自一个序列,K/V来自另一个序列 |
| 典型应用 | Encoder内部、Decoder的Masked Self-Attention | Decoder中的Encoder-Decoder Attention |
| 作用 | 序列内部建模 | 序列间信息交互 |
Q、K、V的直觉理解
- Query(查询):当前位置"想找什么信息"
- Key(键):每个位置"有什么信息可以提供"
- Value(值):每个位置"实际提供的信息内容"
工作流程:Q和K做匹配得到注意力权重,权重乘以V得到输出。类似于"搜索引擎"------Q是搜索词,K是网页标题,V是网页内容。
1.2 原理推导
Scaled Dot-Product Attention公式
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
推导过程(面试必考):
- 计算注意力分数 : S = Q K T S = QK^T S=QKT,其中 Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} Q∈Rn×dk, K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} K∈Rn×dk,得到 S ∈ R n × n S \in \mathbb{R}^{n \times n} S∈Rn×n
- 缩放 : S scaled = S d k S_{\text{scaled}} = \frac{S}{\sqrt{d_k}} Sscaled=dk S
- Softmax归一化 : α = softmax ( S scaled ) \alpha = \text{softmax}(S_{\text{scaled}}) α=softmax(Sscaled),使每行和为1
- 加权求和 : Output = α V \text{Output} = \alpha V Output=αV
为什么要除以 d k \sqrt{d_k} dk ?(高频考点)
假设Q和K的每个元素都是独立的均值为0、方差为1的随机变量,那么:
- q ⋅ k = ∑ i = 1 d k q i k i q \cdot k = \sum_{i=1}^{d_k} q_i k_i q⋅k=∑i=1dkqiki
- 每项 q i k i q_i k_i qiki的均值为0,方差为 Var ( q i ) Var ( k i ) = 1 \text{Var}(q_i)\text{Var}(k_i) = 1 Var(qi)Var(ki)=1
- 所以 q ⋅ k q \cdot k q⋅k的方差为 d k d_k dk( d k d_k dk个独立随机变量之和)
- 当 d k d_k dk很大时(如64、128),点积值的量级会很大
- 大数值输入softmax后,输出接近one-hot分布,梯度几乎为0
- 除以 d k \sqrt{d_k} dk 后方差变回1,softmax输出分布更平滑,梯度更健康
1.3 代码实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_k, dropout=0.1):
super().__init__()
self.d_k = d_k
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
# Q, K, V: [batch, heads, seq_len, d_k]
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
output = torch.matmul(attn_weights, V)
return output, attn_weights
代码要点:
mask == 0的位置填充-inf,softmax后变为0,实现遮蔽- dropout在attention权重上做,而不是在scores上
- 返回attention权重方便可视化和分析
1.4 工程实践
padding mask vs causal mask
-
Padding Mask:对batch中不同长度的序列做padding,padding位置的attention权重应为0
python# padding_mask: [batch, 1, 1, seq_len],padding位为0 scores = scores.masked_fill(padding_mask == 0, float('-inf')) -
Causal Mask(因果掩码):Decoder中用,防止看到未来信息
pythoncausal_mask = torch.tril(torch.ones(seq_len, seq_len)) # 上三角为0,对角线及下三角为1 -
组合使用:实际中两种mask通常需要组合
pythonfinal_mask = padding_mask & causal_mask # 逐元素与
1.5 面试考点精讲
Q1:Self-Attention为什么需要Q、K、V三个矩阵?用一个或两个行不行?
秒答:三个矩阵各司其职------Q负责"提问",K负责"被匹配",V负责"提供信息"。
展开 :如果只用一个矩阵( X ⋅ X T X \cdot X^T X⋅XT),那么每个位置和自己的相似度最高,attention退化为近似单位矩阵。用Q和K两个不同的投影,可以让"提问视角"和"被查询视角"不同,大大增加表达能力。V和K分开是因为"匹配标准"和"实际传递的信息"可以是不同的------就像你查字典时,按拼音(K)查找,但实际获取的是释义(V)。
Q2:Attention的时间复杂度和空间复杂度各是多少?
秒答 :时间 O ( n 2 d ) O(n^2 d) O(n2d),空间 O ( n 2 ) O(n^2) O(n2),其中 n n n是序列长度, d d d是特征维度。
展开 : Q K T QK^T QKT的矩阵乘法是 O ( n 2 d ) O(n^2 d) O(n2d),生成的 n × n n \times n n×n注意力矩阵要存储所以空间是 O ( n 2 ) O(n^2) O(n2)。这就是为什么标准Transformer处理长序列( n > 8 K n > 8K n>8K)会很慢------FlashAttention、稀疏Attention等都是在解决这个 O ( n 2 ) O(n^2) O(n2)问题。
Q3:为什么是点积Attention而不是加性Attention?
秒答:效果差不多,但点积可以用矩阵乘法硬件加速,实际速度更快。
展开 :加性Attention用一层MLP计算相关性: score = v T tanh ( W 1 q + W 2 k ) \text{score} = v^T \tanh(W_1 q + W_2 k) score=vTtanh(W1q+W2k),表达能力理论上更强。但点积 q ⋅ k q \cdot k q⋅k可以直接用高度优化的GEMM库(cuBLAS),在GPU上效率高得多。在 d k d_k dk较大时加性Attention略优,但加了 d k \sqrt{d_k} dk 缩放后差距消失。
Q4:Softmax Attention存在"注意力汇聚"(Attention Sink)现象,这是什么?
秒答:模型倾向于把大量注意力分配给第一个token(通常是BOS/CLS),即使它语义上不重要。
展开:这是2023年发现的现象(StreamingLLM论文)。原因是softmax要求所有权重之和为1,当模型"不确定该关注哪里"时,会把权重"倒"到第一个位置作为"垃圾桶"。实际影响:做KV-Cache裁剪时不能删掉最前面的几个token,否则性能暴跌。StreamingLLM的解决方案:保留前几个sink tokens + 滑动窗口。
1.6 【大厂真题 - 字节/DeepSeek高频】
真题 1:字节跳动算法岗------"既然你提到了KV-Cache,请推导一下在极限长文本下,KV Cache的显存占用公式,并解释DeepSeek V2/V3是如何用MLA(Multi-head Latent Attention)解决这个内存刺客问题的?"
痛点剖析:这道题考察对底层算力瓶颈的敏感度。长文本时代,参数量不再是推理最大的瓶颈,KV Cache才是。
极客解法(秒答+公式):
- 标准MHA的KV Cache公式 :每生成一个Token,需缓存其Key和Value向量。
显存占用 = 2(K和V) × 序列长度L × 层数N × 隐层维度d_model × 批次大小B × 精度字节数(如FP16为2字节)。
假设N=32, d_model=4096,FP16下,每1K Token的KV Cache约为0.5MB/生成流。当并发飙到1000,上下文长达128K时,光KV Cache就需要几十GB甚至上百GB显存!- DeepSeek MLA的破局思路 :
核心思想是低秩压缩(Low-Rank Compression) 。MLA不再专门存储庞大的K和V矩阵,而是只存储一个极度降维的隐变量(Latent Vector) c t c_t ct(例如降维到512)。
- 推理时,当需要计算Attention,从这个极小的 c t c_t ct中当场恢复/投影 出K和V( k t = c t W K , v t = c t W V k_t = c_t W_K, v_t = c_t W_V kt=ctWK,vt=ctWV)。
- 配合RoPE位移编码的特殊解耦设计,MLA将KV Cache的存储量压缩到了标准MHA的 5.4%(近20倍压缩率),这意味着单卡可以支撑原本多卡才能搞定的巨大Batch Size,吞吐量实现降维打击。
真题 2:DeepSeek 核心研发岗------"刚才推导了MLA,那你知道在MLA中应用RoPE(旋转位置编码)时会遇到什么数学冲突?DeepSeek是如何优雅化解(Decoupled RoPE)的?"
痛点剖析:这是一道纯纯的数学+工程地狱题。只背过概念的人立刻挂定。
极客解法(硬核推导):
- 数学冲突 :标准MHA中, q ⋅ k q \cdot k q⋅k 计算时会经过RoPE矩阵 R m R_m Rm 旋转( q T R m − n k q^T R_{m-n} k qTRm−nk)。但在MLA里, k k k 并不是存下来的,而是通过低维向量 c c c 实时乘以权重矩阵 W K W_K WK 算出来的( k = c W K k = c W_K k=cWK)。
如果我们要把RoPE融入进去,公式变成了 ... ( c W K R n ) T ... \dots (c W_K R_n)^T \dots ...(cWKRn)T... 。问题来了!RoPE矩阵 R n R_n Rn 是受位置 n n n 影响的动态矩阵 ,它和静态权重 W K W_K WK 乘在一起,没法结合成一个常量提前算好(无法吸收到权重里),导致每次不得不重新施加庞大的RoPE矩阵,这直接打碎了低秩压缩带来的计算加速美梦。- DeepSeek的化解之道 (Decoupled RoPE) :
DeepSeek极其巧妙地将Query和Key撕成两半 :带RoPE的部分和不带RoPE的部分(内容部分)。
- 内容部分(Content):纯粹通过低维隐向量 c c c 压缩和恢复,产生 k c k_c kc(完全不碰RoPE)。
- 旋转部分(RoPE):专门单拉出一小截维度(比如64维)生成 k r k_r kr,单独施加RoPE并缓存。
- 计算时:
Attention Score = (q_c点乘k_c) + (q_r点乘k_r)。
绝杀结果 :绝大部分KV信息被压缩在 c c c 中(无需存巨大的K),只有极小维度的 k r k_r kr 带着位置信息被真实缓存。既完成了KV Cache极限压缩,又完美保留了RoPE的相对位置感知能力。
模块二:Multi-Head Attention | 多头并行的智慧
2.1 核心概念
为什么需要Multi-Head?
单头Attention只有一种"关注模式"。但自然语言中,同一个词可能需要同时关注:
- 语法关系(主语-谓语-宾语)
- 语义关系(近义词、反义词)
- 位置关系(相邻词、句首句尾)
Multi-Head Attention让模型拥有多个"注意力头",每个头可以学习不同的关注模式,最后拼接起来。
Multi-Head公式
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个头:
head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)
参数量分析
- 输入维度 d model d_{\text{model}} dmodel,头数 h h h,每头维度 d k = d model / h d_k = d_{\text{model}} / h dk=dmodel/h
- 每个头的Q/K/V投影: 3 × d model × d k = 3 d model 2 / h 3 \times d_{\text{model}} \times d_k = 3d_{\text{model}}^2/h 3×dmodel×dk=3dmodel2/h
- h h h个头总共: 3 d model 2 3d_{\text{model}}^2 3dmodel2
- 输出投影 W O W^O WO: d model 2 d_{\text{model}}^2 dmodel2
- MHA总参数: 4 d model 2 4d_{\text{model}}^2 4dmodel2(和单头一样!多头并不增加参数量)
2.2 原理推导
为什么每个头要降维?
- 单头: d k = d model d_k = d_{\text{model}} dk=dmodel,计算 Attention \text{Attention} Attention一次
- 多头: d k = d model / h d_k = d_{\text{model}} / h dk=dmodel/h,计算 Attention \text{Attention} Attention共 h h h次
- 总计算量基本相同: h × O ( n 2 ⋅ d model / h ) = O ( n 2 ⋅ d model ) h \times O(n^2 \cdot d_{\text{model}}/h) = O(n^2 \cdot d_{\text{model}}) h×O(n2⋅dmodel/h)=O(n2⋅dmodel)
- 但多头提供了 h h h个不同的子空间表示,信息更丰富
2.3 代码实现
python
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性投影 + reshape成多头: [batch, seq, d_model] -> [batch, heads, seq, d_k]
Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# Scaled Dot-Product Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
context = torch.matmul(attn, V)
# 拼接多头: [batch, heads, seq, d_k] -> [batch, seq, d_model]
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_o(context)
return output
2.4 工程实践
MHA → MQA → GQA 的演进
| 方法 | Q头数 | K头数 | V头数 | KV-Cache大小 | 代表模型 |
|---|---|---|---|---|---|
| MHA | h h h | h h h | h h h | 2 n h d k 2nhd_k 2nhdk | GPT-2, BERT |
| MQA | h h h | 1 | 1 | 2 n d k 2nd_k 2ndk | PaLM, StarCoder |
| GQA | h h h | g g g | g g g | 2 n g d k 2ngd_k 2ngdk | LLaMA-2-70B, Qwen |
其中 g g g是KV的组数, g = 1 g=1 g=1退化为MQA, g = h g=h g=h退化为MHA。
实际影响:LLaMA-2-70B用GQA(8组),KV-Cache从MHA的100%降到约12%,推理速度提升显著,精度损失极小。
2.5 面试考点精讲
Q1:MQA和GQA分别是什么?为什么能加速推理?
秒答:MQA让所有Q头共享一组KV,GQA让每几个Q头共享一组KV。KV-Cache变小了,显存少了,推理自然快了。
展开:推理瓶颈在KV-Cache的显存占用和带宽。一个7B模型32层、32头,序列长度4096,用FP16,MHA的KV-Cache约2GB。GQA(8组)只要0.5GB。显存少了一方面能处理更长序列,另一方面减少了HBM读写量(推理是memory-bound),直接提速。
Q2:Multi-Head Attention相比Single-Head Attention的优势是什么?
秒答:多头让模型同时从不同子空间学习不同的注意力模式,信息更丰富。
展开:实验验证:可视化不同头的attention权重,发现有的头关注局部相邻词,有的头关注句法结构(动词-宾语),有的头关注长距离共指。如果只有一个头,这些模式被压缩到一个空间里会互相干扰。但注意:不是头越多越好,很多研究发现部分头是冗余的,可以剪枝。
Q3:如何把训练好的MHA模型转换为GQA模型?
秒答:把同一组内多个KV头的权重取平均,然后做少量继续训练来恢复精度。
展开 :LLaMA-2论文的做法------把 h h h个KV头分成 g g g组,每组内的KV权重取均值作为共享权重,然后用原始数据的5%做继续预训练。实测70B模型从MHA转GQA(8组),只需约5%训练量就能恢复到接近原始精度。
模块三:位置编码 | 让Transformer知道"顺序"
3.1 核心概念
为什么需要位置编码?
Self-Attention对输入做的是集合操作------如果打乱输入token的顺序,输出不变(置换不变性)。但语言是有顺序的:"我爱你"和"你爱我"含义完全不同。所以必须注入位置信息。
位置编码的三大流派
-
绝对位置编码:给每个位置一个固定或可学习的向量
- Sinusoidal(Transformer原始)
- Learnable(BERT、GPT-2)
-
相对位置编码:编码两个位置之间的距离
- T5 Relative Bias
- ALiBi
-
旋转位置编码(RoPE):巧妙地用绝对编码实现相对位置感知
- LLaMA、Qwen、DeepSeek等主流模型标配
3.2 原理推导
Sinusoidal位置编码
P E ( p o s , 2 i ) = sin ( p o s / 10000 2 i / d model ) PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{\text{model}}}) PE(pos,2i)=sin(pos/100002i/dmodel)
P E ( p o s , 2 i + 1 ) = cos ( p o s / 10000 2 i / d model ) PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)
设计直觉:
- 不同维度使用不同频率的正弦/余弦波
- 低维度变化快(高频),高维度变化慢(低频)
- 任意位置 p o s + k pos+k pos+k的编码可以表示为 p o s pos pos位置编码的线性函数(线性外推性)
RoPE(旋转位置编码)
核心思想:不是把位置编码加到词向量上,而是把Q和K旋转一个与位置相关的角度。
对于位置 m m m处的Q向量 q q q,RoPE做如下变换(以2D为例):
f ( q , m ) = ( q 0 cos m θ − q 1 sin m θ q 0 sin m θ + q 1 cos m θ ) f(q, m) = \begin{pmatrix} q_0 \cos m\theta - q_1 \sin m\theta \\ q_0 \sin m\theta + q_1 \cos m\theta \end{pmatrix} f(q,m)=(q0cosmθ−q1sinmθq0sinmθ+q1cosmθ)
妙处在于: f ( q , m ) T f ( k , n ) = g ( q , k , m − n ) f(q, m)^T f(k, n) = g(q, k, m-n) f(q,m)Tf(k,n)=g(q,k,m−n)------点积只依赖相对位置 ( m − n ) (m-n) (m−n)!
用绝对位置编码实现了相对位置感知,而且对KV-Cache完全兼容(缓存的KV不需要随新token的加入而重新计算)。
ALiBi位置编码
完全不修改Q/K/V,而是在attention score上加一个与距离成正比的偏置:
score i j = q i T k j − m ⋅ ∣ i − j ∣ \text{score}_{ij} = q_i^T k_j - m \cdot |i - j| scoreij=qiTkj−m⋅∣i−j∣
其中 m m m是每个头不同的斜率(几何级数)。越远的token惩罚越大。
3.3 代码实现
python
class RotaryPositionEmbedding(nn.Module):
def __init__(self, d_k, max_len=8192, base=10000):
super().__init__()
# 计算频率: theta_i = 1 / base^(2i/d_k)
freqs = 1.0 / (base ** (torch.arange(0, d_k, 2).float() / d_k))
# 位置序列
positions = torch.arange(max_len).float()
# 外积得到角度矩阵: [max_len, d_k/2]
angles = torch.outer(positions, freqs)
# 计算cos和sin
self.register_buffer('cos_cached', angles.cos())
self.register_buffer('sin_cached', angles.sin())
def forward(self, x, seq_len):
# x: [batch, heads, seq_len, d_k]
cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
# 将x分成两半,应用旋转
x1, x2 = x[..., ::2], x[..., 1::2]
rotated = torch.stack([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1).flatten(-2)
return rotated
3.4 工程实践
位置编码外推问题
模型训练时用4096长度,推理时输入8192会怎样?
| 方法 | 外推性 | 做法 |
|---|---|---|
| Learnable PE | 差(超出训练长度直接失败) | 不推荐 |
| Sinusoidal PE | 理论可外推,实际效果差 | 基本不用了 |
| RoPE 原始 | 较差(超出2x训练长度崩溃) | 需要额外处理 |
| RoPE + NTK-aware | 好 | 修改base参数 |
| RoPE + YaRN | 很好 | NTK + 注意力分布修正 |
| ALiBi | 较好(天然外推) | 线性偏置容易泛化 |
NTK-aware Scaling的核心思想:不是简单地缩放位置索引(Position Interpolation),而是缩放RoPE的基频,让高频信息保持不变、低频信息被"拉伸"。
3.5 面试考点精讲
Q1:RoPE相比Sinusoidal PE的优势是什么?
秒答:RoPE让注意力分数天然只依赖相对位置,和KV-Cache完美兼容,外推性也更好。
展开 :Sinusoidal PE是加到输入上的,Q和K点积后位置信息混在语义信息里,不够"干净"。RoPE直接在Q和K上做旋转,点积结果只依赖 m − n m-n m−n(相对位置),理论上更优雅。而且RoPE不影响V,缓存的KV可以直接复用。这就是为什么LLaMA、GPT-NeoX、Qwen、DeepSeek全都选了RoPE。
Q2:训练长度4K的模型,怎么扩展到32K?
秒答:用Position Interpolation(PI)或NTK-aware缩放修改RoPE参数,再做少量长文本数据继续训练。
展开 :PI的做法是把位置索引缩小到训练范围内( p o s ′ = p o s × L train / L target pos' = pos \times L_{\text{train}}/L_{\text{target}} pos′=pos×Ltrain/Ltarget),但这样高频信息被压缩。NTK-aware更聪明:调大base(如10000→160000),等效于只"拉伸"低频维度。YaRN进一步加了attention scaling因子,效果最好。CodeLlama用PI+16K数据继续训练,成功从4K扩展到100K。
Q3:ALiBi和RoPE各自的优缺点?
秒答:ALiBi外推性天然好但表达力有限,RoPE表达力更强但需要额外处理外推。
展开:ALiBi只是一个线性衰减偏置,不修改模型参数,简单直接。但研究表明ALiBi在超长文本上注意力分布过于集中在局部,远距离信息利用不够。RoPE虽然原生外推性差,但配合NTK/YaRN后效果更好。实际选型:绝大多数2024-2026的主流模型都选了RoPE。
模块四:LayerNorm与归一化 | 稳定训练的关键
4.1 核心概念
为什么需要Normalization?
深层网络中,每一层的输入分布会随着训练不断变化(Internal Covariate Shift),导致训练不稳定。Normalization把每层的输入"拉回"到稳定的分布(均值0、方差1),让梯度流更健康。
Batch Norm vs Layer Norm
| 特性 | BatchNorm | LayerNorm |
|---|---|---|
| 归一化方向 | 跨batch、在特征维度 | 跨特征、在单个样本内 |
| 依赖batch size | 是(batch太小效果差) | 否 |
| 适用场景 | CV、batch大的场景 | NLP、序列模型 |
| 推理时 | 用训练时的running mean/var | 当场计算 |
Transformer用LayerNorm因为:NLP中序列长度不同,batch内padding多,BN统计不稳定。
4.2 原理推导
LayerNorm公式
LayerNorm ( x ) = γ ⊙ x − μ σ 2 + ϵ + β \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γ⊙σ2+ϵ x−μ+β
其中 μ \mu μ和 σ 2 \sigma^2 σ2是对单个样本的特征维度计算的均值和方差, γ \gamma γ和 β \beta β是可学习的缩放和偏移参数。
RMSNorm公式
RMSNorm ( x ) = γ ⊙ x 1 d ∑ i = 1 d x i 2 + ϵ \text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2 + \epsilon}} RMSNorm(x)=γ⊙d1∑i=1dxi2+ϵ x
去掉了"减均值"和"加偏置"两步。计算量减少约10-15%,效果基本一致。
Pre-Norm vs Post-Norm
- Post-Norm (原始Transformer): x + Sublayer ( Norm ( x ) ) x + \text{Sublayer}(\text{Norm}(x)) x+Sublayer(Norm(x))...不对,原始是 Norm ( x + Sublayer ( x ) ) \text{Norm}(x + \text{Sublayer}(x)) Norm(x+Sublayer(x))
- Pre-Norm (现在主流): x + Sublayer ( Norm ( x ) ) x + \text{Sublayer}(\text{Norm}(x)) x+Sublayer(Norm(x))
Pre-Norm的残差连接从输入直接到输出,梯度流更通畅,训练更稳定,大模型几乎都用Pre-Norm。但Post-Norm理论上上限更高(因为梯度传播不会"绕过"子层)。
4.3 代码实现
python
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return x / rms * self.weight
4.4 工程实践
现代大模型的Norm选择
| 模型 | Norm类型 | 位置 |
|---|---|---|
| BERT | LayerNorm | Post-Norm |
| GPT-2 | LayerNorm | Pre-Norm |
| LLaMA | RMSNorm | Pre-Norm |
| Qwen | RMSNorm | Pre-Norm |
| DeepSeek | RMSNorm | Pre-Norm |
| Gemma | RMSNorm + Pre/Post | 入口出口都加 |
4.5 面试考点精讲
Q1:为什么LLaMA等模型选择RMSNorm而不是LayerNorm?
秒答:RMSNorm去掉了减均值和加偏置,计算量减少约10-15%,但效果几乎一样。大模型训练时间长,这点节省累积起来很可观。
展开 :RMSNorm的论文(Zhang & Sennrich, 2019)证明了LayerNorm中"重新居中"(减均值)这一步对最终效果贡献很小,主要是"重新缩放"(除以标准差)在起作用。去掉均值计算后,不仅少一次reduce操作,还减少了一个可学习参数( β \beta β),在数十亿参数的模型上,训练效率的提升是实实在在的。
Q2:Pre-Norm和Post-Norm的优缺点?为什么大模型都用Pre-Norm?
秒答:Pre-Norm训练更稳定容易收敛,Post-Norm理论上限更高但难训练。大模型追求稳定性,选Pre-Norm。
展开:Post-Norm中梯度必须经过Norm层才能到达残差连接,深层网络中梯度可能被Norm"截断"。Pre-Norm中残差是一条"高速公路",梯度可以直通输入层。代价是Pre-Norm的每层输出被Norm"压住"了,表达能力有一定损失。DeepNorm试图结合两者优点------在Post-Norm基础上修改残差系数,使梯度更稳定。
Q3:Batch Normalization为什么不适合Transformer?
秒答:NLP中batch内序列长度不一、padding比例不同,跨batch统计不稳定。而且推理时batch可能很小甚至为1。
展开:BN要在batch维度算均值方差,但NLP的batch通常很小(8/16/32),而且每个样本的有效长度不同,padding位不应参与统计。另外推理时用的running statistics是训练时累计的,如果训练和推理的数据分布差异大就会出问题。LayerNorm在特征维度上计算,每个token独立归一化,完全避免了这些问题。
模块五:Feed-Forward Network与激活函数 | Transformer的"记忆库"
5.1 核心概念
FFN在Transformer中的角色
每个Transformer Block = Attention + FFN。如果说Attention是"信息路由"(决定哪些信息该被关注),那FFN就是"信息加工"(对关注到的信息做非线性变换)。
有研究将FFN比作"键值记忆"------FFN的第一层权重是"键"(匹配输入模式),第二层权重是"值"(对应的输出知识)。这解释了为什么知识编辑(ROME/MEMIT)可以通过修改FFN权重来更新模型知识。
标准FFN结构
FFN ( x ) = W 2 ⋅ Act ( W 1 x + b 1 ) + b 2 \text{FFN}(x) = W_2 \cdot \text{Act}(W_1 x + b_1) + b_2 FFN(x)=W2⋅Act(W1x+b1)+b2
- W 1 : d model → d ff W_1: d_{\text{model}} \to d_{\text{ff}} W1:dmodel→dff(上投影,通常 d ff = 4 d model d_{\text{ff}} = 4d_{\text{model}} dff=4dmodel)
- W 2 : d ff → d model W_2: d_{\text{ff}} \to d_{\text{model}} W2:dff→dmodel(下投影)
5.2 原理推导
SwiGLU的计算
SwiGLU ( x ) = ( Swish ( x W 1 ) ⊙ x W 3 ) W 2 \text{SwiGLU}(x) = (\text{Swish}(xW_1) \odot xW_3) W_2 SwiGLU(x)=(Swish(xW1)⊙xW3)W2
这里有三个权重矩阵!为了保持总参数量和标准FFN( 2 × d model × 4 d model 2 \times d_{\text{model}} \times 4d_{\text{model}} 2×dmodel×4dmodel)一致,SwiGLU把 d ff d_{\text{ff}} dff从 4 d 4d 4d调整为 8 3 d \frac{8}{3}d 38d(约 2.67 d 2.67d 2.67d),三个矩阵的总参数量: 3 × d model × 8 3 d model = 8 d model 2 3 \times d_{\text{model}} \times \frac{8}{3}d_{\text{model}} = 8d_{\text{model}}^2 3×dmodel×38dmodel=8dmodel2,和原来的 8 d model 2 8d_{\text{model}}^2 8dmodel2相同。
Swish激活函数
Swish ( x ) = x ⋅ σ ( x ) \text{Swish}(x) = x \cdot \sigma(x) Swish(x)=x⋅σ(x)
其中 σ \sigma σ是sigmoid函数。Swish是光滑的非单调函数,在 x < 0 x<0 x<0区域不完全截断(不像ReLU),允许少量负值通过,有助于梯度流动。
5.3 代码实现
python
class SwiGLU_FFN(nn.Module):
def __init__(self, d_model, d_ff=None):
super().__init__()
d_ff = d_ff or int(d_model * 8 / 3)
# 实际中通常round到64/128的倍数
d_ff = ((d_ff + 63) // 64) * 64
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w3 = nn.Linear(d_model, d_ff, bias=False) # gate
self.w2 = nn.Linear(d_ff, d_model, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
5.4 工程实践
激活函数演进
| 激活函数 | 公式 | 使用模型 |
|---|---|---|
| ReLU | max ( 0 , x ) \max(0, x) max(0,x) | 原始Transformer |
| GeLU | x ⋅ Φ ( x ) x \cdot \Phi(x) x⋅Φ(x) | BERT, GPT-2 |
| SwiGLU | Swish ( x W 1 ) ⊙ x W 3 \text{Swish}(xW_1) \odot xW_3 Swish(xW1)⊙xW3 | LLaMA, Qwen, DeepSeek |
| GeGLU | GeLU ( x W 1 ) ⊙ x W 3 \text{GeLU}(xW_1) \odot xW_3 GeLU(xW1)⊙xW3 | Gemma |
5.5 面试考点精讲
Q1:FFN的中间维度为什么是 4 d model 4d_{\text{model}} 4dmodel?SwiGLU为什么改成 8 / 3 d 8/3d 8/3d?
秒答 : 4 d 4d 4d是经验值,在精度和计算量之间取得平衡。SwiGLU因为多了一个gate矩阵(三个权重),为了总参数量不变所以缩小到 8 / 3 d 8/3d 8/3d。
Q2:为什么现在主流模型都用SwiGLU而不是GeLU或ReLU?
秒答:SwiGLU的GLU门控机制让模型可以选择性地通过信息,比简单的逐元素激活函数效果更好。PaLM论文实验证明SwiGLU在同参数量下一致性地优于ReLU/GeLU。
Q3:有人说FFN是Transformer的"记忆模块",如何理解?
秒答 :FFN的参数是固定的权重矩阵,存储的是训练数据中学到的"知识"。 W 1 W_1 W1的每一行是一个"模式检测器",对应的 W 2 W_2 W2列是被触发时输出的"知识"。
展开 :Geva等人(2021)的研究显示,FFN的第一层做的是"模式匹配"(哪些输入模式会激活这个神经元),第二层做的是"知识输出"(激活后应该输出什么)。这就是为什么ROME等知识编辑方法可以通过修改FFN权重来精确更新某个事实知识------定位到存储该知识的神经元,修改对应的 W 2 W_2 W2行即可。
模块六:Encoder-Decoder架构与Decoder-Only
6.1 核心概念
三种架构范式:
| 架构 | 注意力类型 | 代表模型 | 适合任务 |
|---|---|---|---|
| Encoder-only | 双向Self-Attention | BERT, RoBERTa | 分类、NER、语义理解 |
| Encoder-Decoder | Encoder双向 + Decoder因果 + Cross-Attention | T5, BART, Flan-T5 | 翻译、摘要、seq2seq |
| Decoder-only | 因果Self-Attention(左到右) | GPT, LLaMA, Qwen | 生成、对话、通用 |
为什么Decoder-only一统天下?
- 训练效率:Causal LM每个位置都能产生训练信号(预测下一个token),而MLM只有15%的mask位置有信号。
- KV Cache友好:自回归生成时已生成token的KV不变可cache。Encoder的双向Attention没法cache。
- Scaling Law的赢家:大规模实验表明Decoder-only的scaling性能最好。
- Zero-shot泛化:因果语言模型天然适合"续写"。
Prefix LM vs Causal LM:
- Causal LM:纯左到右,每个位置只能看前面。GPT系列。
- Prefix LM:prefix部分双向可见(像Encoder),生成部分因果。GLM、U-PaLM。
6.2 原理推导
Decoder-only的"低秩"论证:
苏剑林的分析:双向Attention的注意力矩阵是对称性较强的矩阵,由于softmax归一化约束,实际上是"低秩"的------很多行长得很像。低秩意味着不同token获得的上下文表示趋于相似,限制了表达能力。Decoder-only的因果掩码打破了这种对称性,注意力矩阵更"满秩"。
参数量估算:
对于L层、d维、V词表的Decoder-only模型(SwiGLU FFN,d_ff=8d/3):
- Embedding: V*d
- 每层Attention: 4dd (QKV+O)
- 每层FFN: 3d d_ff = 8dd
- 每层Norm: 4*d
- 总计约 Vd + L (12dd)
LLaMA-7B验证:V=32000, d=4096, L=32,约6.74B参数。
6.3 代码实现
python
class CausalLM(nn.Module):
def __init__(self, vocab_size, d_model, n_layers, n_heads, max_seq_len):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model)
self.layers = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff=int(8*d_model/3))
for _ in range(n_layers)
])
self.norm = RMSNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# 权重共享
self.lm_head.weight = self.token_emb.weight
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer("causal_mask", mask)
def forward(self, input_ids):
B, T = input_ids.shape
x = self.token_emb(input_ids)
mask = self.causal_mask[:T, :T]
for layer in self.layers:
x = layer(x, mask=mask)
x = self.norm(x)
return self.lm_head(x)
6.4 工程实践
Embedding和LM Head权重共享(Weight Tying):输入Embedding矩阵和输出LM Head共享同一个权重。节省V*d参数,且让输入输出在同一语义空间中。几乎所有现代模型都用。
FLOPs估算经验公式:
- 前向FLOPs约等于2倍参数量乘以序列长度
- 训练FLOPs约等于6倍参数量乘以总训练token数(前向+反向+激活重计算)
6.5 面试考点精讲(5题)
Q1. ⭐ [高频] 为什么现在的LLM几乎都采用Decoder-only架构?
一句话秒答:训练效率高(每个token都产生loss信号)、推理KV Cache友好、scaling性能最好、zero-shot泛化能力强。
展开来说:四个核心原因:(1) Causal LM在每个位置都有loss,数据利用率比MLM高6-7倍;(2) KV Cache让自回归生成高效,双向Encoder不能cache;(3) Chinchilla、PaLM的大规模实验验证Decoder-only是scaling最优架构;(4) 所有任务都可以转成"续写"格式,不需要任务专属头。
面试加分:补一个反面------"Encoder-Decoder在特定任务(翻译、改写)上还有优势,Flan-T5在指令跟随上也不差。只是大模型的趋势是Decoder-only。"
Q2. ⭐⭐ [字节] Encoder-only、Encoder-Decoder、Decoder-only三种架构分别适合什么任务?
一句话秒答:Encoder-only适合理解(分类/NER),Encoder-Decoder适合seq2seq(翻译/摘要),Decoder-only适合生成(对话/续写),但现在Decoder-only通过足够大的规模几乎通吃。
展开来说:关键转折点是GPT-3/ChatGPT证明了Decoder-only足够大+指令微调后,在理解任务上也不比BERT差。于是"一个架构通吃"的路线胜出。
面试加分:补一句------"还有Prefix LM(如GLM),prefix部分双向、生成部分因果,试图兼顾两者优点但最终也没成主流。"
Q3. ⭐⭐ [高频] Transformer的Embedding权重和LM Head权重为什么可以共享?
一句话秒答:输入Embedding把token映射到向量空间,LM Head把向量映射回token------两者是反操作,共享权重让输入和输出在同一语义空间。
展开来说:LM Head的logits = h * E^T,即隐藏状态和每个词向量的点积,等价于在词向量空间做最近邻搜索。好处:节省V*d参数、训练信号共享、语义一致。
面试加分:提一句------"权重共享在大词表模型中更有价值。Qwen词表150K,共享能省600M参数。但BLOOM选择不共享,因为不共享训练初期收敛更快。"
Q4. ⭐⭐ [面经] Prefix LM和Causal LM的区别是什么?
一句话秒答:Causal LM全程只能看左边(严格下三角mask),Prefix LM的前缀部分双向可见、后续部分才是因果的。
展开来说:区别只在attention mask------Causal LM是严格下三角,Prefix LM在prefix区域是全1(双向)。GLM就是Prefix LM的代表,结合了BERT的双向理解和GPT的生成能力。
Q5. ⭐⭐⭐ [腾讯] Transformer模型的参数量如何计算?LLaMA-7B的参数量怎么得来的?
一句话秒答 :总参数约等于Vd + 12 Ldd。LLaMA-7B:V=32K, d=4096, L=32,约6.74B。
展开来说:逐项:Embedding 131M + 每层Attention(QKV+O) 67M + 每层FFN(SwiGLU) 135M + 每层Norm 8K。32层总计6.48B + Embedding 0.13B + Final Norm 4K = 约6.74B。
面试加分 :补FLOPs公式------"训练FLOPs约67B 1T=4.2e22。前向FLOPs约2参数量seq_len。"
模块七:Transformer训练技巧
7.1 核心概念
学习率调度:Transformer训练中最关键的超参不是学习率本身,而是学习率的变化曲线。
现代大模型标准做法:Warmup + Cosine Decay
- 前2000步线性warmup到峰值学习率
- 然后余弦衰减到峰值的1/10
为什么需要Warmup? 训练初期参数随机,梯度方向noisy。大学习率会把参数推飞。Warmup让模型先"热身",等梯度方向稳定后再加大步子。
优化器:AdamW是标准选择。关键超参:beta1=0.9, beta2=0.95, weight_decay=0.1。
训练稳定性trick:
- Gradient Clipping:梯度裁剪到max_norm=1.0
- Mixed Precision:BF16前向,FP32累积更新
- Gradient Accumulation:小batch多步累积,等效大batch
7.2 原理推导
BF16 vs FP16:
| 格式 | 指数位 | 尾数位 | 数值范围 | 特点 |
|---|---|---|---|---|
| FP32 | 8 | 23 | 3.4e38 | 标准精度 |
| FP16 | 5 | 10 | 65504 | 范围小,易溢出 |
| BF16 | 8 | 7 | 3.4e38 | 范围大,精度略低 |
BF16的指数位和FP32相同,数值范围一样大,训练时不容易出NaN。FP16指数位只有5位,需要配合loss scaling。A100以后BF16成默认选择。
7.3 代码实现
python
import math
import torch
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# 使用示例
optimizer = torch.optim.AdamW(
model.parameters(), lr=3e-4,
betas=(0.9, 0.95), weight_decay=0.1
)
scheduler = get_cosine_schedule_with_warmup(optimizer, 2000, 100000)
7.4 工程实践
Loss Spike排查与处理:
- 数据质量:检查spike前后的batch是否含异常文本
- 学习率:是否在lr变化拐点
- 梯度:监控每层梯度范数
- 数值:FP16容易Inf/NaN
处理:轻微spike自愈不管;严重spike回滚checkpoint+跳过问题数据。PaLM训练中约遇20次spike,都是回滚解决。
Gradient Checkpointing:用时间换空间------前向时不保存中间激活,反向时重新计算。显存省60-70%,计算增加30%。
7.5 面试考点精讲(5题)
Q1. ⭐ [高频] 为什么Transformer训练要用Learning Rate Warmup?
一句话秒答:训练初期梯度方差大、Adam的二阶动量估计不准,大学习率会让参数发散。Warmup让优化器先积累准确的统计量。
展开来说:两个原因:(1) Adam的二阶动量v_t初始很小,bias-correction放大后有效学习率趋近无穷;(2) Post-Norm在训练初期梯度方差大。Pre-Norm可以不用warmup(GPT-2论文验证)。
面试加分:提一句"LION、Sophia等新优化器号称不需要warmup,但主流实践中大家还是加warmup。"
Q2. ⭐⭐ [字节] BF16和FP16有什么区别?为什么大模型训练优先选BF16?
一句话秒答:BF16指数位多(8位)数值范围大不易溢出,FP16指数位少(5位)易出NaN。大模型中"不爆炸"比"精度高"更重要。
展开来说:FP16上限65504,训练中gradient/activation经常超过这个值,需要loss scaling。BF16范围和FP32一样,几乎不可能溢出,不需要loss scaling。代价是精度略低(尾数7位vs10位),但对模型质量影响很小。
面试加分:提一句"FP8(E4M3/E5M2)是下一代格式,H100的FP8比BF16快2倍。推理已可用,训练还在成熟中。"
Q3. ⭐⭐ [高频] 什么是Gradient Accumulation?解决什么问题?
一句话秒答:显存不够装大batch时,多个小batch梯度累加再更新------等效大batch训练。
展开来说:等效batch_size = micro_batch * accumulation_steps * num_gpus。LN对batch维度独立,所以Transformer用gradient accumulation完全安全(不像BN有统计量问题)。
面试加分:提一句"等效batch太大可能导致收敛到更差的局部最优。LLM的甜点范围是2M到4M tokens。"
Q4. ⭐⭐ [面经] 现代大模型还用Dropout吗?为什么?
一句话秒答:不用了。LLaMA、GPT-3、PaLM的Dropout率都设为0。因为数据量够大不怕过拟合,weight decay就够了。
展开来说:原始Transformer有三个Dropout位置:Attention权重后、FFN中间、残差连接前。但现代大模型的训练数据足够大,过拟合风险小。Dropout在分布式训练中还引入额外通信开销。如果需要正则化,weight_decay=0.1是更好的选择。
Q5. ⭐⭐⭐ [阿里] 训练大模型时遇到Loss Spike怎么排查和处理?
一句话秒答:查数据质量、查学习率、查梯度范数、查数值溢出。处理:回滚checkpoint+跳过问题数据。
展开来说:排查清单:(1) 数据层面------spike前后的batch是否含异常文本;(2) 学习率------是否在warmup结束的拐点;(3) 梯度------某层梯度是否突变;(4) 数值------是否有Inf/NaN。
处理方法:轻微spike(涨了又降)不管;严重spike回滚100步+跳过问题batch。PaLM训练遇20次spike都这么解决的,问题batch主要是代码和数学公式。
面试加分:提一个预防措施------"训练时要实时监控每层的gradient norm和activation norm,设置告警阈值,在spike刚出现时就自动暂停回滚。"
模块八:高效Attention与推理优化
8.1 核心概念
为什么需要高效Attention? 标准Attention的O(n^2)复杂度在长序列场景下是致命瓶颈。n=128K时,注意力矩阵需要60+GB显存。
主流优化方向:
| 方法 | 核心思路 | 复杂度 | 代表 |
|---|---|---|---|
| FlashAttention | IO优化(tiling+kernel fusion) | O(n^2)但实际快2-4x | FA1/FA2/FA3 |
| 稀疏Attention | 只计算部分注意力对 | O(nsqrt(n))到O(nk) | Longformer, BigBird |
| 滑动窗口 | 每个token只看最近w个 | O(n*w) | Mistral, Gemma |
| 线性Attention | kernel trick去掉softmax | O(n*d) | Linear Transformer |
| PagedAttention | KV Cache内存管理优化 | 不改复杂度,减少碎片 | vLLM |
8.2 原理推导
FlashAttention的核心insight:
标准Attention需要把n*n的注意力矩阵从GPU的HBM(高带宽内存)读写多次。FlashAttention通过tiling(分块)把计算分成小块,每块在SRAM(片上高速缓存)中完成所有计算,只把最终结果写回HBM。
关键挑战:softmax需要知道全局最大值来做数值稳定(减去max),但tiling时每块只能看到局部。解决方案:Online Softmax------维护一个running max和running sum,每处理一个新块就更新。
数学上:对于两块S1和S2拼接的softmax:
- m_new = max(m1, m2)
- l_new = l1 * exp(m1 - m_new) + l2 * exp(m2 - m_new)
- output = (o1 * l1 * exp(m1-m_new) + o2 * l2 * exp(m2-m_new)) / l_new
这样就能逐块处理而不需要存储整个n*n矩阵。
Sliding Window Attention:
每个token只attend到最近w个token。注意力矩阵从full nn变成band-diagonal nw。
直觉:对于第L层来说,token_i的有效感受野是wL(每层窗口w,L层堆叠后覆盖wL)。Mistral-7B用w=4096,32层后有效感受野=131072>max_seq_len,理论上信息可以传到任意远。
PagedAttention(vLLM核心):
KV Cache的内存管理问题:不同请求的序列长度不同,固定分配max_len的cache会造成大量碎片。PagedAttention借鉴OS的分页内存------把KV Cache切成固定大小的"页"(如16个token一页),按需分配,不同请求可以共享物理页。
好处:KV Cache利用率从50-60%提升到>95%,同等显存下可服务的并发请求翻倍。
8.3 代码实现
python
# FlashAttention使用(调用而非实现,因为需要CUDA kernel)
from flash_attn import flash_attn_func
# Q, K, V: (batch, seq_len, n_heads, d_head)
# 注意:flash_attn的输入shape和标准实现不同
output = flash_attn_func(
Q, K, V,
dropout_p=0.0,
causal=True, # 因果mask
softmax_scale=1.0 / math.sqrt(d_head)
)
# output: (batch, seq_len, n_heads, d_head)
# Sliding Window Attention(简化版)
def sliding_window_attention(Q, K, V, window_size):
B, H, T, D = Q.shape
# 创建band mask: 只有|i-j| <= window_size的位置为True
i = torch.arange(T).unsqueeze(1) # (T, 1)
j = torch.arange(T).unsqueeze(0) # (1, T)
mask = ((i - j).abs() <= window_size) & (j <= i) # causal + window
mask = mask.float().unsqueeze(0).unsqueeze(0) # (1, 1, T, T)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(D)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
return torch.matmul(attn, V)
8.4 工程实践
FlashAttention的使用注意事项:
- FA2比FA1快约2x,支持更多head dim(如128、256)
- FA3在H100上进一步优化,支持FP8
- 必须安装对应CUDA版本的flash-attn包
- HuggingFace的Transformers已经内置FA2支持:
model = AutoModel.from_pretrained(..., attn_implementation="flash_attention_2")
推理优化的完整栈:
模型层面: GQA/MLA减少KV头 -> 量化(W4A16) -> 稀疏注意力
系统层面: FlashAttention -> PagedAttention -> Continuous Batching
硬件层面: Tensor Core -> NVLink -> InfiniBand
KV Cache显存估算实例:
LLaMA-2-70B(80层GQA-8,d_k=128,BF16):
- 每token KV: 2 * 80 * 8 * 128 * 2 bytes = 320KB
- 4K序列: 320KB * 4096 = 1.28GB
- batch=32: 1.28GB * 32 = 41GB(超过单卡80GB的一半!)
所以长序列+大batch下KV Cache是显存大户,GQA/MLA和量化KV Cache至关重要。
8.5 面试考点精讲(5题)
Q1. ⭐⭐ [高频] FlashAttention的核心思想是什么?它为什么能加速Attention计算?
一句话秒答:FlashAttention通过tiling(分块)把Attention计算搬到GPU的SRAM中完成,大幅减少HBM读写次数------是IO优化而非算法复杂度优化。
展开来说:GPU有两级存储:HBM(高带宽但慢,80GB)和SRAM(极快但小,20MB)。标准Attention需要把n*n的注意力矩阵在HBM和SRAM之间搬运多次(写出S矩阵、读回做softmax、写出P矩阵、读回与V相乘)。
FlashAttention把Q/K/V分成小块(tile),每块在SRAM中完成所有计算(QK^T -> softmax -> *V),只把最终结果写回HBM。用Online Softmax解决了分块softmax的问题。
结果:Attention的FLOPs没变(还是O(n2d)),但HBM访问量从O(n2+nd)降到O(n^2d/M)(M是SRAM大小),实际速度快2-4x。
面试加分:提一句"FlashAttention还有一个bonus:因为不需要存储n*n的注意力矩阵,空间复杂度从O(n^2)降到O(n)------这对长序列至关重要。"
Q2. ⭐⭐ [字节] PagedAttention是什么?它解决了什么问题?
一句话秒答:PagedAttention把KV Cache切成固定大小的"页"按需分配,解决了不同长度请求共享GPU显存时的碎片问题------类似OS的虚拟内存。
展开来说:没有PagedAttention时,每个请求预分配max_seq_len的KV Cache空间。如果max=8K但实际只用了500个token,7500个位置的显存就浪费了。多请求时碎片率可达40-50%。
PagedAttention的做法:把KV Cache切成固定大小的page(如每页16个token),用page table管理。新生成一个token时按需分配新页,序列结束时释放页。不同请求可以共享空同的物理页。
vLLM就是基于PagedAttention构建的推理引擎,KV Cache利用率从50-60%提升到>95%。
面试加分:提一句"PagedAttention还支持copy-on-write------parallel sampling时多个生成分支可以共享prompt的KV Cache页,只在产生分歧时才复制。这对beam search和parallel decoding很重要。"
Q3. ⭐⭐ [腾讯] Sliding Window Attention的原理是什么?Mistral是怎么用的?
一句话秒答:每个token只attend到最近w个token。通过L层堆叠,有效感受野扩展到w*L,理论上覆盖任意长度。
展开来说:Mistral-7B用w=4096的滑动窗口。32层后有效感受野=4096*32=131072。这意味着虽然每层只看4K的局部窗口,但信息可以通过层层传递覆盖到128K+的范围。
好处:KV Cache大小固定为w(不随序列长度增长),推理显存可控。对比标准Attention的KV Cache随n线性增长。
限制:信息传递是"间接的"------远距离信息需要通过多层中转,不如full attention的直接连接准确。所以Mistral在某些需要精确长距离引用的任务上不如full attention的模型。
Q4. ⭐⭐⭐ [高频] 线性Attention的原理是什么?为什么它能把复杂度从O(n^2)降到O(n)?
一句话秒答 :线性Attention用kernel trick把softmax(QKT)V分解为phi(Q)(phi(K)T V)------改变矩阵乘法顺序,从先算nn降到先算dd。
展开来说 :标准Attention计算顺序是(softmax(QK^T))V,先算nn的注意力矩阵。
线性Attention的关键:找一个特征映射phi,使得exp(q*k)约等于phi(q)*phi(k)。然后:
Attention = phi(Q) * (phi(K)^T * V)
先算phi(K)^T * V得到dd矩阵(O(n dd)),再左乘phi(Q)得到n d的输出(O(ndd))。总复杂度O(n*d^2)------对n线性!
代价:phi的近似质量直接决定模型质量。实践中线性Attention的效果一般不如标准Attention,尤其在需要"尖锐注意力"(sharp attention)的任务上。
最新进展:Mamba/State Space Model走了另一条路------用线性RNN+选择机制替代Attention,在长序列建模上效果不错。
Q5. ⭐⭐⭐ [阿里] 比较FlashAttention、稀疏Attention、线性Attention三种优化方式的本质区别。
一句话秒答:FlashAttention是IO优化(不改算法,改计算方式),稀疏Attention是算法优化(只算部分注意力对),线性Attention是数学优化(用核近似降低复杂度)。
展开来说:
| 维度 | FlashAttention | 稀疏Attention | 线性Attention |
|---|---|---|---|
| 算法复杂度 | O(n^2d) 不变 | O(nwd) 或 O(n*sqrt(n)*d) | O(n*d^2) |
| 空间复杂度 | O(n) | O(n*w) | O(n*d) |
| 质量损失 | 无(数值等价) | 有(信息截断) | 有(核近似误差) |
| 适用场景 | 通用加速 | 已知局部性强的任务 | 超长序列、RNN替代 |
| 工程难度 | 高(需要CUDA kernel) | 中(改mask即可) | 低(纯数学变换) |
实际工程中最常用的组合:FlashAttention(必开)+ GQA/MLA(减少KV头)+ 量化KV Cache。稀疏和线性Attention在特定场景有用但不是主流。
面试加分:提一句"Mamba2(2024)尝试统一Attention和SSM,提出了SSD(Structured State Space Duality)框架,证明线性Attention和状态空间模型在数学上是等价的。这可能是未来序列建模的方向。"
Q6. ⭐⭐⭐ [高频] KV Cache显存怎么算?70B模型、batch=8、seq_len=32K需要多少KV Cache显存?
这是大厂面试的高频计算题,考察对Transformer推理机制的理解。
公式:KV Cache显存 = 2 x num_layers x num_kv_heads x head_dim x seq_len x batch_size x dtype_bytes
以LLaMA-70B为例:
- num_layers = 80, num_kv_heads = 8 (GQA), head_dim = 128, dtype = FP16 (2 bytes)
- 单条请求:2 x 80 x 8 x 128 x 32768 x 2 = 10 GB
- batch=8:10 x 8 = 80 GB -- 仅KV Cache就占满一张A100
面试追问与应答:
- "怎么减少KV Cache?" --> GQA/MQA减少kv_heads;量化KV Cache到INT8/FP8(减半);PagedAttention减少碎片浪费
- "KV Cache和模型权重谁占显存多?" --> 长序列+大batch时KV Cache远超权重(70B FP16权重=140GB,但KV Cache可达数百GB)
- "为什么说KV Cache是推理的第一瓶颈?" --> 因为它随seq_len线性增长、随batch线性增长,是限制吞吐量和最大序列长度的直接因素
本章总结
| 模块 | 核心考点 | 必记公式/概念 |
|---|---|---|
| 模块一 | Self-Attention | Attention(Q,K,V)=softmax(QK^T/sqrt(d_k))V |
| 模块二 | MHA/MQA/GQA/MLA | KV Cache大小 = 2L gd_kn*B |
| 模块三 | 位置编码 | RoPE旋转实现相对位置,YaRN外推 |
| 模块四 | FFN/SwiGLU | SwiGLU = Swish(xW1) * xW3, d_ff=8d/3 |
| 模块五 | LayerNorm/RMSNorm | Pre-Norm更稳定,RMSNorm更快 |
| 模块六 | 架构选型 | Decoder-only主流,参数量约12Ld^2 |
| 模块七 | 训练技巧 | Warmup+Cosine, BF16, Gradient Clip |
| 模块八 | 高效Attention | FlashAttention(IO优化), PagedAttention(内存管理), KV Cache显存计算 |
全章共覆盖 41道 面试真题,涵盖概念理解、数学推导、代码实现、工程实践四个层次。