【2026大模型面试圣经】(1)Transformer全解析 | 从Self-Attention到Multi-Head,一文通关Transformer面试

2026大模型面试圣经(1):Transformer全解析 | 从Self-Attention到Multi-Head,一文通关Transformer面试

定位:本章是整个面试圣经的基石。Transformer是所有大模型的"操作系统",不搞懂它后面全白搭。

目标:看完本章,你能从零推导Attention公式,手写Multi-Head Attention,画出完整架构图,回答任何Transformer相关面试题。


模块一:Self-Attention机制 | 注意力的本质

1.1 核心概念

什么是Attention?

一句话:Attention就是"加权求和"------对输入序列中的每个位置,根据"相关性"分配不同的权重,然后做加权聚合。

传统RNN/LSTM处理序列时,信息必须沿着时间步一步步传递,长距离信息经过多次传递后会衰减甚至消失。Attention机制让任意两个位置之间可以"直接对话",路径长度为O(1)。

Self-Attention vs Cross-Attention

特性 Self-Attention Cross-Attention
Q/K/V来源 全部来自同一序列 Q来自一个序列,K/V来自另一个序列
典型应用 Encoder内部、Decoder的Masked Self-Attention Decoder中的Encoder-Decoder Attention
作用 序列内部建模 序列间信息交互

Q、K、V的直觉理解

  • Query(查询):当前位置"想找什么信息"
  • Key(键):每个位置"有什么信息可以提供"
  • Value(值):每个位置"实际提供的信息内容"

工作流程:Q和K做匹配得到注意力权重,权重乘以V得到输出。类似于"搜索引擎"------Q是搜索词,K是网页标题,V是网页内容。

1.2 原理推导

Scaled Dot-Product Attention公式

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

推导过程(面试必考)

  1. 计算注意力分数 : S = Q K T S = QK^T S=QKT,其中 Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} Q∈Rn×dk, K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} K∈Rn×dk,得到 S ∈ R n × n S \in \mathbb{R}^{n \times n} S∈Rn×n
  2. 缩放 : S scaled = S d k S_{\text{scaled}} = \frac{S}{\sqrt{d_k}} Sscaled=dk S
  3. Softmax归一化 : α = softmax ( S scaled ) \alpha = \text{softmax}(S_{\text{scaled}}) α=softmax(Sscaled),使每行和为1
  4. 加权求和 : Output = α V \text{Output} = \alpha V Output=αV

为什么要除以 d k \sqrt{d_k} dk ?(高频考点)

假设Q和K的每个元素都是独立的均值为0、方差为1的随机变量,那么:

  • q ⋅ k = ∑ i = 1 d k q i k i q \cdot k = \sum_{i=1}^{d_k} q_i k_i q⋅k=∑i=1dkqiki
  • 每项 q i k i q_i k_i qiki的均值为0,方差为 Var ( q i ) Var ( k i ) = 1 \text{Var}(q_i)\text{Var}(k_i) = 1 Var(qi)Var(ki)=1
  • 所以 q ⋅ k q \cdot k q⋅k的方差为 d k d_k dk( d k d_k dk个独立随机变量之和)
  • 当 d k d_k dk很大时(如64、128),点积值的量级会很大
  • 大数值输入softmax后,输出接近one-hot分布,梯度几乎为0
  • 除以 d k \sqrt{d_k} dk 后方差变回1,softmax输出分布更平滑,梯度更健康

1.3 代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, dropout=0.1):
        super().__init__()
        self.d_k = d_k
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: [batch, heads, seq_len, d_k]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

代码要点

  • mask == 0的位置填充-inf,softmax后变为0,实现遮蔽
  • dropout在attention权重上做,而不是在scores上
  • 返回attention权重方便可视化和分析

1.4 工程实践

padding mask vs causal mask

  1. Padding Mask:对batch中不同长度的序列做padding,padding位置的attention权重应为0

    python 复制代码
    # padding_mask: [batch, 1, 1, seq_len],padding位为0
    scores = scores.masked_fill(padding_mask == 0, float('-inf'))
  2. Causal Mask(因果掩码):Decoder中用,防止看到未来信息

    python 复制代码
    causal_mask = torch.tril(torch.ones(seq_len, seq_len))
    # 上三角为0,对角线及下三角为1
  3. 组合使用:实际中两种mask通常需要组合

    python 复制代码
    final_mask = padding_mask & causal_mask  # 逐元素与

1.5 面试考点精讲

Q1:Self-Attention为什么需要Q、K、V三个矩阵?用一个或两个行不行?

秒答:三个矩阵各司其职------Q负责"提问",K负责"被匹配",V负责"提供信息"。

展开 :如果只用一个矩阵( X ⋅ X T X \cdot X^T X⋅XT),那么每个位置和自己的相似度最高,attention退化为近似单位矩阵。用Q和K两个不同的投影,可以让"提问视角"和"被查询视角"不同,大大增加表达能力。V和K分开是因为"匹配标准"和"实际传递的信息"可以是不同的------就像你查字典时,按拼音(K)查找,但实际获取的是释义(V)。

Q2:Attention的时间复杂度和空间复杂度各是多少?

秒答 :时间 O ( n 2 d ) O(n^2 d) O(n2d),空间 O ( n 2 ) O(n^2) O(n2),其中 n n n是序列长度, d d d是特征维度。

展开 : Q K T QK^T QKT的矩阵乘法是 O ( n 2 d ) O(n^2 d) O(n2d),生成的 n × n n \times n n×n注意力矩阵要存储所以空间是 O ( n 2 ) O(n^2) O(n2)。这就是为什么标准Transformer处理长序列( n > 8 K n > 8K n>8K)会很慢------FlashAttention、稀疏Attention等都是在解决这个 O ( n 2 ) O(n^2) O(n2)问题。

Q3:为什么是点积Attention而不是加性Attention?

秒答:效果差不多,但点积可以用矩阵乘法硬件加速,实际速度更快。

展开 :加性Attention用一层MLP计算相关性: score = v T tanh ⁡ ( W 1 q + W 2 k ) \text{score} = v^T \tanh(W_1 q + W_2 k) score=vTtanh(W1q+W2k),表达能力理论上更强。但点积 q ⋅ k q \cdot k q⋅k可以直接用高度优化的GEMM库(cuBLAS),在GPU上效率高得多。在 d k d_k dk较大时加性Attention略优,但加了 d k \sqrt{d_k} dk 缩放后差距消失。

Q4:Softmax Attention存在"注意力汇聚"(Attention Sink)现象,这是什么?

秒答:模型倾向于把大量注意力分配给第一个token(通常是BOS/CLS),即使它语义上不重要。

展开:这是2023年发现的现象(StreamingLLM论文)。原因是softmax要求所有权重之和为1,当模型"不确定该关注哪里"时,会把权重"倒"到第一个位置作为"垃圾桶"。实际影响:做KV-Cache裁剪时不能删掉最前面的几个token,否则性能暴跌。StreamingLLM的解决方案:保留前几个sink tokens + 滑动窗口。

1.6 【大厂真题 - 字节/DeepSeek高频】

真题 1:字节跳动算法岗------"既然你提到了KV-Cache,请推导一下在极限长文本下,KV Cache的显存占用公式,并解释DeepSeek V2/V3是如何用MLA(Multi-head Latent Attention)解决这个内存刺客问题的?"

痛点剖析:这道题考察对底层算力瓶颈的敏感度。长文本时代,参数量不再是推理最大的瓶颈,KV Cache才是。

极客解法(秒答+公式)

  1. 标准MHA的KV Cache公式 :每生成一个Token,需缓存其Key和Value向量。
    显存占用 = 2(K和V) × 序列长度L × 层数N × 隐层维度d_model × 批次大小B × 精度字节数(如FP16为2字节)。
    假设 N=32, d_model=4096,FP16下,每1K Token的KV Cache约为 0.5MB/生成流。当并发飙到1000,上下文长达128K时,光KV Cache就需要几十GB甚至上百GB显存!
  2. DeepSeek MLA的破局思路
    核心思想是低秩压缩(Low-Rank Compression) 。MLA不再专门存储庞大的K和V矩阵,而是只存储一个极度降维的隐变量(Latent Vector) c t c_t ct(例如降维到512)。
    • 推理时,当需要计算Attention,从这个极小的 c t c_t ct中当场恢复/投影 出K和V( k t = c t W K , v t = c t W V k_t = c_t W_K, v_t = c_t W_V kt=ctWK,vt=ctWV)。
    • 配合RoPE位移编码的特殊解耦设计,MLA将KV Cache的存储量压缩到了标准MHA的 5.4%(近20倍压缩率),这意味着单卡可以支撑原本多卡才能搞定的巨大Batch Size,吞吐量实现降维打击。

真题 2:DeepSeek 核心研发岗------"刚才推导了MLA,那你知道在MLA中应用RoPE(旋转位置编码)时会遇到什么数学冲突?DeepSeek是如何优雅化解(Decoupled RoPE)的?"

痛点剖析:这是一道纯纯的数学+工程地狱题。只背过概念的人立刻挂定。

极客解法(硬核推导)

  1. 数学冲突 :标准MHA中, q ⋅ k q \cdot k q⋅k 计算时会经过RoPE矩阵 R m R_m Rm 旋转( q T R m − n k q^T R_{m-n} k qTRm−nk)。但在MLA里, k k k 并不是存下来的,而是通过低维向量 c c c 实时乘以权重矩阵 W K W_K WK 算出来的( k = c W K k = c W_K k=cWK)。
    如果我们要把RoPE融入进去,公式变成了 ... ( c W K R n ) T ... \dots (c W_K R_n)^T \dots ...(cWKRn)T... 。问题来了!RoPE矩阵 R n R_n Rn 是受位置 n n n 影响的动态矩阵 ,它和静态权重 W K W_K WK 乘在一起,没法结合成一个常量提前算好(无法吸收到权重里),导致每次不得不重新施加庞大的RoPE矩阵,这直接打碎了低秩压缩带来的计算加速美梦
  2. DeepSeek的化解之道 (Decoupled RoPE)
    DeepSeek极其巧妙地将Query和Key撕成两半 :带RoPE的部分和不带RoPE的部分(内容部分)。
    • 内容部分(Content):纯粹通过低维隐向量 c c c 压缩和恢复,产生 k c k_c kc(完全不碰RoPE)。
    • 旋转部分(RoPE):专门单拉出一小截维度(比如64维)生成 k r k_r kr,单独施加RoPE并缓存。
    • 计算时:Attention Score = (q_c点乘k_c) + (q_r点乘k_r)
      绝杀结果 :绝大部分KV信息被压缩在 c c c 中(无需存巨大的K),只有极小维度的 k r k_r kr 带着位置信息被真实缓存。既完成了KV Cache极限压缩,又完美保留了RoPE的相对位置感知能力。

模块二:Multi-Head Attention | 多头并行的智慧

2.1 核心概念

为什么需要Multi-Head?

单头Attention只有一种"关注模式"。但自然语言中,同一个词可能需要同时关注:

  • 语法关系(主语-谓语-宾语)
  • 语义关系(近义词、反义词)
  • 位置关系(相邻词、句首句尾)

Multi-Head Attention让模型拥有多个"注意力头",每个头可以学习不同的关注模式,最后拼接起来。

Multi-Head公式

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中每个头:
head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)

参数量分析

  • 输入维度 d model d_{\text{model}} dmodel,头数 h h h,每头维度 d k = d model / h d_k = d_{\text{model}} / h dk=dmodel/h
  • 每个头的Q/K/V投影: 3 × d model × d k = 3 d model 2 / h 3 \times d_{\text{model}} \times d_k = 3d_{\text{model}}^2/h 3×dmodel×dk=3dmodel2/h
  • h h h个头总共: 3 d model 2 3d_{\text{model}}^2 3dmodel2
  • 输出投影 W O W^O WO: d model 2 d_{\text{model}}^2 dmodel2
  • MHA总参数: 4 d model 2 4d_{\text{model}}^2 4dmodel2(和单头一样!多头并不增加参数量)

2.2 原理推导

为什么每个头要降维?

  • 单头: d k = d model d_k = d_{\text{model}} dk=dmodel,计算 Attention \text{Attention} Attention一次
  • 多头: d k = d model / h d_k = d_{\text{model}} / h dk=dmodel/h,计算 Attention \text{Attention} Attention共 h h h次
  • 总计算量基本相同: h × O ( n 2 ⋅ d model / h ) = O ( n 2 ⋅ d model ) h \times O(n^2 \cdot d_{\text{model}}/h) = O(n^2 \cdot d_{\text{model}}) h×O(n2⋅dmodel/h)=O(n2⋅dmodel)
  • 但多头提供了 h h h个不同的子空间表示,信息更丰富

2.3 代码实现

python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 线性投影 + reshape成多头: [batch, seq, d_model] -> [batch, heads, seq, d_k]
        Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, V)

        # 拼接多头: [batch, heads, seq, d_k] -> [batch, seq, d_model]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(context)
        return output

2.4 工程实践

MHA → MQA → GQA 的演进

方法 Q头数 K头数 V头数 KV-Cache大小 代表模型
MHA h h h h h h h h h 2 n h d k 2nhd_k 2nhdk GPT-2, BERT
MQA h h h 1 1 2 n d k 2nd_k 2ndk PaLM, StarCoder
GQA h h h g g g g g g 2 n g d k 2ngd_k 2ngdk LLaMA-2-70B, Qwen

其中 g g g是KV的组数, g = 1 g=1 g=1退化为MQA, g = h g=h g=h退化为MHA。

实际影响:LLaMA-2-70B用GQA(8组),KV-Cache从MHA的100%降到约12%,推理速度提升显著,精度损失极小。

2.5 面试考点精讲

Q1:MQA和GQA分别是什么?为什么能加速推理?

秒答:MQA让所有Q头共享一组KV,GQA让每几个Q头共享一组KV。KV-Cache变小了,显存少了,推理自然快了。

展开:推理瓶颈在KV-Cache的显存占用和带宽。一个7B模型32层、32头,序列长度4096,用FP16,MHA的KV-Cache约2GB。GQA(8组)只要0.5GB。显存少了一方面能处理更长序列,另一方面减少了HBM读写量(推理是memory-bound),直接提速。

Q2:Multi-Head Attention相比Single-Head Attention的优势是什么?

秒答:多头让模型同时从不同子空间学习不同的注意力模式,信息更丰富。

展开:实验验证:可视化不同头的attention权重,发现有的头关注局部相邻词,有的头关注句法结构(动词-宾语),有的头关注长距离共指。如果只有一个头,这些模式被压缩到一个空间里会互相干扰。但注意:不是头越多越好,很多研究发现部分头是冗余的,可以剪枝。

Q3:如何把训练好的MHA模型转换为GQA模型?

秒答:把同一组内多个KV头的权重取平均,然后做少量继续训练来恢复精度。

展开 :LLaMA-2论文的做法------把 h h h个KV头分成 g g g组,每组内的KV权重取均值作为共享权重,然后用原始数据的5%做继续预训练。实测70B模型从MHA转GQA(8组),只需约5%训练量就能恢复到接近原始精度。


模块三:位置编码 | 让Transformer知道"顺序"

3.1 核心概念

为什么需要位置编码?

Self-Attention对输入做的是集合操作------如果打乱输入token的顺序,输出不变(置换不变性)。但语言是有顺序的:"我爱你"和"你爱我"含义完全不同。所以必须注入位置信息。

位置编码的三大流派

  1. 绝对位置编码:给每个位置一个固定或可学习的向量

    • Sinusoidal(Transformer原始)
    • Learnable(BERT、GPT-2)
  2. 相对位置编码:编码两个位置之间的距离

    • T5 Relative Bias
    • ALiBi
  3. 旋转位置编码(RoPE):巧妙地用绝对编码实现相对位置感知

    • LLaMA、Qwen、DeepSeek等主流模型标配

3.2 原理推导

Sinusoidal位置编码

P E ( p o s , 2 i ) = sin ⁡ ( p o s / 10000 2 i / d model ) PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{\text{model}}}) PE(pos,2i)=sin(pos/100002i/dmodel)
P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 10000 2 i / d model ) PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)

设计直觉:

  • 不同维度使用不同频率的正弦/余弦波
  • 低维度变化快(高频),高维度变化慢(低频)
  • 任意位置 p o s + k pos+k pos+k的编码可以表示为 p o s pos pos位置编码的线性函数(线性外推性)

RoPE(旋转位置编码)

核心思想:不是把位置编码加到词向量上,而是把Q和K旋转一个与位置相关的角度。

对于位置 m m m处的Q向量 q q q,RoPE做如下变换(以2D为例):

f ( q , m ) = ( q 0 cos ⁡ m θ − q 1 sin ⁡ m θ q 0 sin ⁡ m θ + q 1 cos ⁡ m θ ) f(q, m) = \begin{pmatrix} q_0 \cos m\theta - q_1 \sin m\theta \\ q_0 \sin m\theta + q_1 \cos m\theta \end{pmatrix} f(q,m)=(q0cosmθ−q1sinmθq0sinmθ+q1cosmθ)

妙处在于: f ( q , m ) T f ( k , n ) = g ( q , k , m − n ) f(q, m)^T f(k, n) = g(q, k, m-n) f(q,m)Tf(k,n)=g(q,k,m−n)------点积只依赖相对位置 ( m − n ) (m-n) (m−n)!

用绝对位置编码实现了相对位置感知,而且对KV-Cache完全兼容(缓存的KV不需要随新token的加入而重新计算)。

ALiBi位置编码

完全不修改Q/K/V,而是在attention score上加一个与距离成正比的偏置:

score i j = q i T k j − m ⋅ ∣ i − j ∣ \text{score}_{ij} = q_i^T k_j - m \cdot |i - j| scoreij=qiTkj−m⋅∣i−j∣

其中 m m m是每个头不同的斜率(几何级数)。越远的token惩罚越大。

3.3 代码实现

python 复制代码
class RotaryPositionEmbedding(nn.Module):
    def __init__(self, d_k, max_len=8192, base=10000):
        super().__init__()
        # 计算频率: theta_i = 1 / base^(2i/d_k)
        freqs = 1.0 / (base ** (torch.arange(0, d_k, 2).float() / d_k))
        # 位置序列
        positions = torch.arange(max_len).float()
        # 外积得到角度矩阵: [max_len, d_k/2]
        angles = torch.outer(positions, freqs)
        # 计算cos和sin
        self.register_buffer('cos_cached', angles.cos())
        self.register_buffer('sin_cached', angles.sin())

    def forward(self, x, seq_len):
        # x: [batch, heads, seq_len, d_k]
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
        # 将x分成两半,应用旋转
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1).flatten(-2)
        return rotated

3.4 工程实践

位置编码外推问题

模型训练时用4096长度,推理时输入8192会怎样?

方法 外推性 做法
Learnable PE 差(超出训练长度直接失败) 不推荐
Sinusoidal PE 理论可外推,实际效果差 基本不用了
RoPE 原始 较差(超出2x训练长度崩溃) 需要额外处理
RoPE + NTK-aware 修改base参数
RoPE + YaRN 很好 NTK + 注意力分布修正
ALiBi 较好(天然外推) 线性偏置容易泛化

NTK-aware Scaling的核心思想:不是简单地缩放位置索引(Position Interpolation),而是缩放RoPE的基频,让高频信息保持不变、低频信息被"拉伸"。

3.5 面试考点精讲

Q1:RoPE相比Sinusoidal PE的优势是什么?

秒答:RoPE让注意力分数天然只依赖相对位置,和KV-Cache完美兼容,外推性也更好。

展开 :Sinusoidal PE是加到输入上的,Q和K点积后位置信息混在语义信息里,不够"干净"。RoPE直接在Q和K上做旋转,点积结果只依赖 m − n m-n m−n(相对位置),理论上更优雅。而且RoPE不影响V,缓存的KV可以直接复用。这就是为什么LLaMA、GPT-NeoX、Qwen、DeepSeek全都选了RoPE。

Q2:训练长度4K的模型,怎么扩展到32K?

秒答:用Position Interpolation(PI)或NTK-aware缩放修改RoPE参数,再做少量长文本数据继续训练。

展开 :PI的做法是把位置索引缩小到训练范围内( p o s ′ = p o s × L train / L target pos' = pos \times L_{\text{train}}/L_{\text{target}} pos′=pos×Ltrain/Ltarget),但这样高频信息被压缩。NTK-aware更聪明:调大base(如10000→160000),等效于只"拉伸"低频维度。YaRN进一步加了attention scaling因子,效果最好。CodeLlama用PI+16K数据继续训练,成功从4K扩展到100K。

Q3:ALiBi和RoPE各自的优缺点?

秒答:ALiBi外推性天然好但表达力有限,RoPE表达力更强但需要额外处理外推。

展开:ALiBi只是一个线性衰减偏置,不修改模型参数,简单直接。但研究表明ALiBi在超长文本上注意力分布过于集中在局部,远距离信息利用不够。RoPE虽然原生外推性差,但配合NTK/YaRN后效果更好。实际选型:绝大多数2024-2026的主流模型都选了RoPE。


模块四:LayerNorm与归一化 | 稳定训练的关键

4.1 核心概念

为什么需要Normalization?

深层网络中,每一层的输入分布会随着训练不断变化(Internal Covariate Shift),导致训练不稳定。Normalization把每层的输入"拉回"到稳定的分布(均值0、方差1),让梯度流更健康。

Batch Norm vs Layer Norm

特性 BatchNorm LayerNorm
归一化方向 跨batch、在特征维度 跨特征、在单个样本内
依赖batch size 是(batch太小效果差)
适用场景 CV、batch大的场景 NLP、序列模型
推理时 用训练时的running mean/var 当场计算

Transformer用LayerNorm因为:NLP中序列长度不同,batch内padding多,BN统计不稳定。

4.2 原理推导

LayerNorm公式

LayerNorm ( x ) = γ ⊙ x − μ σ 2 + ϵ + β \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γ⊙σ2+ϵ x−μ+β

其中 μ \mu μ和 σ 2 \sigma^2 σ2是对单个样本的特征维度计算的均值和方差, γ \gamma γ和 β \beta β是可学习的缩放和偏移参数。

RMSNorm公式

RMSNorm ( x ) = γ ⊙ x 1 d ∑ i = 1 d x i 2 + ϵ \text{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2 + \epsilon}} RMSNorm(x)=γ⊙d1∑i=1dxi2+ϵ x

去掉了"减均值"和"加偏置"两步。计算量减少约10-15%,效果基本一致。

Pre-Norm vs Post-Norm

  • Post-Norm (原始Transformer): x + Sublayer ( Norm ( x ) ) x + \text{Sublayer}(\text{Norm}(x)) x+Sublayer(Norm(x))...不对,原始是 Norm ( x + Sublayer ( x ) ) \text{Norm}(x + \text{Sublayer}(x)) Norm(x+Sublayer(x))
  • Pre-Norm (现在主流): x + Sublayer ( Norm ( x ) ) x + \text{Sublayer}(\text{Norm}(x)) x+Sublayer(Norm(x))

Pre-Norm的残差连接从输入直接到输出,梯度流更通畅,训练更稳定,大模型几乎都用Pre-Norm。但Post-Norm理论上上限更高(因为梯度传播不会"绕过"子层)。

4.3 代码实现

python 复制代码
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

4.4 工程实践

现代大模型的Norm选择

模型 Norm类型 位置
BERT LayerNorm Post-Norm
GPT-2 LayerNorm Pre-Norm
LLaMA RMSNorm Pre-Norm
Qwen RMSNorm Pre-Norm
DeepSeek RMSNorm Pre-Norm
Gemma RMSNorm + Pre/Post 入口出口都加

4.5 面试考点精讲

Q1:为什么LLaMA等模型选择RMSNorm而不是LayerNorm?

秒答:RMSNorm去掉了减均值和加偏置,计算量减少约10-15%,但效果几乎一样。大模型训练时间长,这点节省累积起来很可观。

展开 :RMSNorm的论文(Zhang & Sennrich, 2019)证明了LayerNorm中"重新居中"(减均值)这一步对最终效果贡献很小,主要是"重新缩放"(除以标准差)在起作用。去掉均值计算后,不仅少一次reduce操作,还减少了一个可学习参数( β \beta β),在数十亿参数的模型上,训练效率的提升是实实在在的。

Q2:Pre-Norm和Post-Norm的优缺点?为什么大模型都用Pre-Norm?

秒答:Pre-Norm训练更稳定容易收敛,Post-Norm理论上限更高但难训练。大模型追求稳定性,选Pre-Norm。

展开:Post-Norm中梯度必须经过Norm层才能到达残差连接,深层网络中梯度可能被Norm"截断"。Pre-Norm中残差是一条"高速公路",梯度可以直通输入层。代价是Pre-Norm的每层输出被Norm"压住"了,表达能力有一定损失。DeepNorm试图结合两者优点------在Post-Norm基础上修改残差系数,使梯度更稳定。

Q3:Batch Normalization为什么不适合Transformer?

秒答:NLP中batch内序列长度不一、padding比例不同,跨batch统计不稳定。而且推理时batch可能很小甚至为1。

展开:BN要在batch维度算均值方差,但NLP的batch通常很小(8/16/32),而且每个样本的有效长度不同,padding位不应参与统计。另外推理时用的running statistics是训练时累计的,如果训练和推理的数据分布差异大就会出问题。LayerNorm在特征维度上计算,每个token独立归一化,完全避免了这些问题。


模块五:Feed-Forward Network与激活函数 | Transformer的"记忆库"

5.1 核心概念

FFN在Transformer中的角色

每个Transformer Block = Attention + FFN。如果说Attention是"信息路由"(决定哪些信息该被关注),那FFN就是"信息加工"(对关注到的信息做非线性变换)。

有研究将FFN比作"键值记忆"------FFN的第一层权重是"键"(匹配输入模式),第二层权重是"值"(对应的输出知识)。这解释了为什么知识编辑(ROME/MEMIT)可以通过修改FFN权重来更新模型知识。

标准FFN结构

FFN ( x ) = W 2 ⋅ Act ( W 1 x + b 1 ) + b 2 \text{FFN}(x) = W_2 \cdot \text{Act}(W_1 x + b_1) + b_2 FFN(x)=W2⋅Act(W1x+b1)+b2

  • W 1 : d model → d ff W_1: d_{\text{model}} \to d_{\text{ff}} W1:dmodel→dff(上投影,通常 d ff = 4 d model d_{\text{ff}} = 4d_{\text{model}} dff=4dmodel)
  • W 2 : d ff → d model W_2: d_{\text{ff}} \to d_{\text{model}} W2:dff→dmodel(下投影)

5.2 原理推导

SwiGLU的计算

SwiGLU ( x ) = ( Swish ( x W 1 ) ⊙ x W 3 ) W 2 \text{SwiGLU}(x) = (\text{Swish}(xW_1) \odot xW_3) W_2 SwiGLU(x)=(Swish(xW1)⊙xW3)W2

这里有三个权重矩阵!为了保持总参数量和标准FFN( 2 × d model × 4 d model 2 \times d_{\text{model}} \times 4d_{\text{model}} 2×dmodel×4dmodel)一致,SwiGLU把 d ff d_{\text{ff}} dff从 4 d 4d 4d调整为 8 3 d \frac{8}{3}d 38d(约 2.67 d 2.67d 2.67d),三个矩阵的总参数量: 3 × d model × 8 3 d model = 8 d model 2 3 \times d_{\text{model}} \times \frac{8}{3}d_{\text{model}} = 8d_{\text{model}}^2 3×dmodel×38dmodel=8dmodel2,和原来的 8 d model 2 8d_{\text{model}}^2 8dmodel2相同。

Swish激活函数

Swish ( x ) = x ⋅ σ ( x ) \text{Swish}(x) = x \cdot \sigma(x) Swish(x)=x⋅σ(x)

其中 σ \sigma σ是sigmoid函数。Swish是光滑的非单调函数,在 x < 0 x<0 x<0区域不完全截断(不像ReLU),允许少量负值通过,有助于梯度流动。

5.3 代码实现

python 复制代码
class SwiGLU_FFN(nn.Module):
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or int(d_model * 8 / 3)
        # 实际中通常round到64/128的倍数
        d_ff = ((d_ff + 63) // 64) * 64
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # gate
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

5.4 工程实践

激活函数演进

激活函数 公式 使用模型
ReLU max ⁡ ( 0 , x ) \max(0, x) max(0,x) 原始Transformer
GeLU x ⋅ Φ ( x ) x \cdot \Phi(x) x⋅Φ(x) BERT, GPT-2
SwiGLU Swish ( x W 1 ) ⊙ x W 3 \text{Swish}(xW_1) \odot xW_3 Swish(xW1)⊙xW3 LLaMA, Qwen, DeepSeek
GeGLU GeLU ( x W 1 ) ⊙ x W 3 \text{GeLU}(xW_1) \odot xW_3 GeLU(xW1)⊙xW3 Gemma

5.5 面试考点精讲

Q1:FFN的中间维度为什么是 4 d model 4d_{\text{model}} 4dmodel?SwiGLU为什么改成 8 / 3 d 8/3d 8/3d?

秒答 : 4 d 4d 4d是经验值,在精度和计算量之间取得平衡。SwiGLU因为多了一个gate矩阵(三个权重),为了总参数量不变所以缩小到 8 / 3 d 8/3d 8/3d。

Q2:为什么现在主流模型都用SwiGLU而不是GeLU或ReLU?

秒答:SwiGLU的GLU门控机制让模型可以选择性地通过信息,比简单的逐元素激活函数效果更好。PaLM论文实验证明SwiGLU在同参数量下一致性地优于ReLU/GeLU。

Q3:有人说FFN是Transformer的"记忆模块",如何理解?

秒答 :FFN的参数是固定的权重矩阵,存储的是训练数据中学到的"知识"。 W 1 W_1 W1的每一行是一个"模式检测器",对应的 W 2 W_2 W2列是被触发时输出的"知识"。

展开 :Geva等人(2021)的研究显示,FFN的第一层做的是"模式匹配"(哪些输入模式会激活这个神经元),第二层做的是"知识输出"(激活后应该输出什么)。这就是为什么ROME等知识编辑方法可以通过修改FFN权重来精确更新某个事实知识------定位到存储该知识的神经元,修改对应的 W 2 W_2 W2行即可。


模块六:Encoder-Decoder架构与Decoder-Only

6.1 核心概念

三种架构范式

架构 注意力类型 代表模型 适合任务
Encoder-only 双向Self-Attention BERT, RoBERTa 分类、NER、语义理解
Encoder-Decoder Encoder双向 + Decoder因果 + Cross-Attention T5, BART, Flan-T5 翻译、摘要、seq2seq
Decoder-only 因果Self-Attention(左到右) GPT, LLaMA, Qwen 生成、对话、通用

为什么Decoder-only一统天下?

  1. 训练效率:Causal LM每个位置都能产生训练信号(预测下一个token),而MLM只有15%的mask位置有信号。
  2. KV Cache友好:自回归生成时已生成token的KV不变可cache。Encoder的双向Attention没法cache。
  3. Scaling Law的赢家:大规模实验表明Decoder-only的scaling性能最好。
  4. Zero-shot泛化:因果语言模型天然适合"续写"。

Prefix LM vs Causal LM

  • Causal LM:纯左到右,每个位置只能看前面。GPT系列。
  • Prefix LM:prefix部分双向可见(像Encoder),生成部分因果。GLM、U-PaLM。

6.2 原理推导

Decoder-only的"低秩"论证

苏剑林的分析:双向Attention的注意力矩阵是对称性较强的矩阵,由于softmax归一化约束,实际上是"低秩"的------很多行长得很像。低秩意味着不同token获得的上下文表示趋于相似,限制了表达能力。Decoder-only的因果掩码打破了这种对称性,注意力矩阵更"满秩"。

参数量估算

对于L层、d维、V词表的Decoder-only模型(SwiGLU FFN,d_ff=8d/3):

  • Embedding: V*d
  • 每层Attention: 4dd (QKV+O)
  • 每层FFN: 3d d_ff = 8dd
  • 每层Norm: 4*d
  • 总计约 Vd + L (12dd)

LLaMA-7B验证:V=32000, d=4096, L=32,约6.74B参数。

6.3 代码实现

python 复制代码
class CausalLM(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads, max_seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff=int(8*d_model/3))
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # 权重共享
        self.lm_head.weight = self.token_emb.weight
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        self.register_buffer("causal_mask", mask)

    def forward(self, input_ids):
        B, T = input_ids.shape
        x = self.token_emb(input_ids)
        mask = self.causal_mask[:T, :T]
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.norm(x)
        return self.lm_head(x)

6.4 工程实践

Embedding和LM Head权重共享(Weight Tying):输入Embedding矩阵和输出LM Head共享同一个权重。节省V*d参数,且让输入输出在同一语义空间中。几乎所有现代模型都用。

FLOPs估算经验公式

  • 前向FLOPs约等于2倍参数量乘以序列长度
  • 训练FLOPs约等于6倍参数量乘以总训练token数(前向+反向+激活重计算)

6.5 面试考点精讲(5题)

Q1. ⭐ [高频] 为什么现在的LLM几乎都采用Decoder-only架构?

一句话秒答:训练效率高(每个token都产生loss信号)、推理KV Cache友好、scaling性能最好、zero-shot泛化能力强。

展开来说:四个核心原因:(1) Causal LM在每个位置都有loss,数据利用率比MLM高6-7倍;(2) KV Cache让自回归生成高效,双向Encoder不能cache;(3) Chinchilla、PaLM的大规模实验验证Decoder-only是scaling最优架构;(4) 所有任务都可以转成"续写"格式,不需要任务专属头。

面试加分:补一个反面------"Encoder-Decoder在特定任务(翻译、改写)上还有优势,Flan-T5在指令跟随上也不差。只是大模型的趋势是Decoder-only。"


Q2. ⭐⭐ [字节] Encoder-only、Encoder-Decoder、Decoder-only三种架构分别适合什么任务?

一句话秒答:Encoder-only适合理解(分类/NER),Encoder-Decoder适合seq2seq(翻译/摘要),Decoder-only适合生成(对话/续写),但现在Decoder-only通过足够大的规模几乎通吃。

展开来说:关键转折点是GPT-3/ChatGPT证明了Decoder-only足够大+指令微调后,在理解任务上也不比BERT差。于是"一个架构通吃"的路线胜出。

面试加分:补一句------"还有Prefix LM(如GLM),prefix部分双向、生成部分因果,试图兼顾两者优点但最终也没成主流。"


Q3. ⭐⭐ [高频] Transformer的Embedding权重和LM Head权重为什么可以共享?

一句话秒答:输入Embedding把token映射到向量空间,LM Head把向量映射回token------两者是反操作,共享权重让输入和输出在同一语义空间。

展开来说:LM Head的logits = h * E^T,即隐藏状态和每个词向量的点积,等价于在词向量空间做最近邻搜索。好处:节省V*d参数、训练信号共享、语义一致。

面试加分:提一句------"权重共享在大词表模型中更有价值。Qwen词表150K,共享能省600M参数。但BLOOM选择不共享,因为不共享训练初期收敛更快。"


Q4. ⭐⭐ [面经] Prefix LM和Causal LM的区别是什么?

一句话秒答:Causal LM全程只能看左边(严格下三角mask),Prefix LM的前缀部分双向可见、后续部分才是因果的。

展开来说:区别只在attention mask------Causal LM是严格下三角,Prefix LM在prefix区域是全1(双向)。GLM就是Prefix LM的代表,结合了BERT的双向理解和GPT的生成能力。


Q5. ⭐⭐⭐ [腾讯] Transformer模型的参数量如何计算?LLaMA-7B的参数量怎么得来的?

一句话秒答 :总参数约等于Vd + 12 Ldd。LLaMA-7B:V=32K, d=4096, L=32,约6.74B。

展开来说:逐项:Embedding 131M + 每层Attention(QKV+O) 67M + 每层FFN(SwiGLU) 135M + 每层Norm 8K。32层总计6.48B + Embedding 0.13B + Final Norm 4K = 约6.74B。

面试加分 :补FLOPs公式------"训练FLOPs约67B 1T=4.2e22。前向FLOPs约2参数量seq_len。"


模块七:Transformer训练技巧

7.1 核心概念

学习率调度:Transformer训练中最关键的超参不是学习率本身,而是学习率的变化曲线。

现代大模型标准做法:Warmup + Cosine Decay

  • 前2000步线性warmup到峰值学习率
  • 然后余弦衰减到峰值的1/10

为什么需要Warmup? 训练初期参数随机,梯度方向noisy。大学习率会把参数推飞。Warmup让模型先"热身",等梯度方向稳定后再加大步子。

优化器:AdamW是标准选择。关键超参:beta1=0.9, beta2=0.95, weight_decay=0.1。

训练稳定性trick

  • Gradient Clipping:梯度裁剪到max_norm=1.0
  • Mixed Precision:BF16前向,FP32累积更新
  • Gradient Accumulation:小batch多步累积,等效大batch

7.2 原理推导

BF16 vs FP16

格式 指数位 尾数位 数值范围 特点
FP32 8 23 3.4e38 标准精度
FP16 5 10 65504 范围小,易溢出
BF16 8 7 3.4e38 范围大,精度略低

BF16的指数位和FP32相同,数值范围一样大,训练时不容易出NaN。FP16指数位只有5位,需要配合loss scaling。A100以后BF16成默认选择。

7.3 代码实现

python 复制代码
import math
import torch

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# 使用示例
optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4,
    betas=(0.9, 0.95), weight_decay=0.1
)
scheduler = get_cosine_schedule_with_warmup(optimizer, 2000, 100000)

7.4 工程实践

Loss Spike排查与处理

  1. 数据质量:检查spike前后的batch是否含异常文本
  2. 学习率:是否在lr变化拐点
  3. 梯度:监控每层梯度范数
  4. 数值:FP16容易Inf/NaN

处理:轻微spike自愈不管;严重spike回滚checkpoint+跳过问题数据。PaLM训练中约遇20次spike,都是回滚解决。

Gradient Checkpointing:用时间换空间------前向时不保存中间激活,反向时重新计算。显存省60-70%,计算增加30%。

7.5 面试考点精讲(5题)

Q1. ⭐ [高频] 为什么Transformer训练要用Learning Rate Warmup?

一句话秒答:训练初期梯度方差大、Adam的二阶动量估计不准,大学习率会让参数发散。Warmup让优化器先积累准确的统计量。

展开来说:两个原因:(1) Adam的二阶动量v_t初始很小,bias-correction放大后有效学习率趋近无穷;(2) Post-Norm在训练初期梯度方差大。Pre-Norm可以不用warmup(GPT-2论文验证)。

面试加分:提一句"LION、Sophia等新优化器号称不需要warmup,但主流实践中大家还是加warmup。"


Q2. ⭐⭐ [字节] BF16和FP16有什么区别?为什么大模型训练优先选BF16?

一句话秒答:BF16指数位多(8位)数值范围大不易溢出,FP16指数位少(5位)易出NaN。大模型中"不爆炸"比"精度高"更重要。

展开来说:FP16上限65504,训练中gradient/activation经常超过这个值,需要loss scaling。BF16范围和FP32一样,几乎不可能溢出,不需要loss scaling。代价是精度略低(尾数7位vs10位),但对模型质量影响很小。

面试加分:提一句"FP8(E4M3/E5M2)是下一代格式,H100的FP8比BF16快2倍。推理已可用,训练还在成熟中。"


Q3. ⭐⭐ [高频] 什么是Gradient Accumulation?解决什么问题?

一句话秒答:显存不够装大batch时,多个小batch梯度累加再更新------等效大batch训练。

展开来说:等效batch_size = micro_batch * accumulation_steps * num_gpus。LN对batch维度独立,所以Transformer用gradient accumulation完全安全(不像BN有统计量问题)。

面试加分:提一句"等效batch太大可能导致收敛到更差的局部最优。LLM的甜点范围是2M到4M tokens。"


Q4. ⭐⭐ [面经] 现代大模型还用Dropout吗?为什么?

一句话秒答:不用了。LLaMA、GPT-3、PaLM的Dropout率都设为0。因为数据量够大不怕过拟合,weight decay就够了。

展开来说:原始Transformer有三个Dropout位置:Attention权重后、FFN中间、残差连接前。但现代大模型的训练数据足够大,过拟合风险小。Dropout在分布式训练中还引入额外通信开销。如果需要正则化,weight_decay=0.1是更好的选择。


Q5. ⭐⭐⭐ [阿里] 训练大模型时遇到Loss Spike怎么排查和处理?

一句话秒答:查数据质量、查学习率、查梯度范数、查数值溢出。处理:回滚checkpoint+跳过问题数据。

展开来说:排查清单:(1) 数据层面------spike前后的batch是否含异常文本;(2) 学习率------是否在warmup结束的拐点;(3) 梯度------某层梯度是否突变;(4) 数值------是否有Inf/NaN。

处理方法:轻微spike(涨了又降)不管;严重spike回滚100步+跳过问题batch。PaLM训练遇20次spike都这么解决的,问题batch主要是代码和数学公式。

面试加分:提一个预防措施------"训练时要实时监控每层的gradient norm和activation norm,设置告警阈值,在spike刚出现时就自动暂停回滚。"


模块八:高效Attention与推理优化

8.1 核心概念

为什么需要高效Attention? 标准Attention的O(n^2)复杂度在长序列场景下是致命瓶颈。n=128K时,注意力矩阵需要60+GB显存。

主流优化方向

方法 核心思路 复杂度 代表
FlashAttention IO优化(tiling+kernel fusion) O(n^2)但实际快2-4x FA1/FA2/FA3
稀疏Attention 只计算部分注意力对 O(nsqrt(n))到O(nk) Longformer, BigBird
滑动窗口 每个token只看最近w个 O(n*w) Mistral, Gemma
线性Attention kernel trick去掉softmax O(n*d) Linear Transformer
PagedAttention KV Cache内存管理优化 不改复杂度,减少碎片 vLLM

8.2 原理推导

FlashAttention的核心insight

标准Attention需要把n*n的注意力矩阵从GPU的HBM(高带宽内存)读写多次。FlashAttention通过tiling(分块)把计算分成小块,每块在SRAM(片上高速缓存)中完成所有计算,只把最终结果写回HBM。

关键挑战:softmax需要知道全局最大值来做数值稳定(减去max),但tiling时每块只能看到局部。解决方案:Online Softmax------维护一个running max和running sum,每处理一个新块就更新。

数学上:对于两块S1和S2拼接的softmax:

  • m_new = max(m1, m2)
  • l_new = l1 * exp(m1 - m_new) + l2 * exp(m2 - m_new)
  • output = (o1 * l1 * exp(m1-m_new) + o2 * l2 * exp(m2-m_new)) / l_new

这样就能逐块处理而不需要存储整个n*n矩阵。

Sliding Window Attention

每个token只attend到最近w个token。注意力矩阵从full nn变成band-diagonal nw。

直觉:对于第L层来说,token_i的有效感受野是wL(每层窗口w,L层堆叠后覆盖wL)。Mistral-7B用w=4096,32层后有效感受野=131072>max_seq_len,理论上信息可以传到任意远。

PagedAttention(vLLM核心)

KV Cache的内存管理问题:不同请求的序列长度不同,固定分配max_len的cache会造成大量碎片。PagedAttention借鉴OS的分页内存------把KV Cache切成固定大小的"页"(如16个token一页),按需分配,不同请求可以共享物理页。

好处:KV Cache利用率从50-60%提升到>95%,同等显存下可服务的并发请求翻倍。

8.3 代码实现

python 复制代码
# FlashAttention使用(调用而非实现,因为需要CUDA kernel)
from flash_attn import flash_attn_func

# Q, K, V: (batch, seq_len, n_heads, d_head)
# 注意:flash_attn的输入shape和标准实现不同
output = flash_attn_func(
    Q, K, V,
    dropout_p=0.0,
    causal=True,  # 因果mask
    softmax_scale=1.0 / math.sqrt(d_head)
)
# output: (batch, seq_len, n_heads, d_head)

# Sliding Window Attention(简化版)
def sliding_window_attention(Q, K, V, window_size):
    B, H, T, D = Q.shape
    # 创建band mask: 只有|i-j| <= window_size的位置为True
    i = torch.arange(T).unsqueeze(1)  # (T, 1)
    j = torch.arange(T).unsqueeze(0)  # (1, T)
    mask = ((i - j).abs() <= window_size) & (j <= i)  # causal + window
    mask = mask.float().unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
    
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(D)
    scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, V)

8.4 工程实践

FlashAttention的使用注意事项

  • FA2比FA1快约2x,支持更多head dim(如128、256)
  • FA3在H100上进一步优化,支持FP8
  • 必须安装对应CUDA版本的flash-attn包
  • HuggingFace的Transformers已经内置FA2支持:model = AutoModel.from_pretrained(..., attn_implementation="flash_attention_2")

推理优化的完整栈

复制代码
模型层面: GQA/MLA减少KV头 -> 量化(W4A16) -> 稀疏注意力
系统层面: FlashAttention -> PagedAttention -> Continuous Batching
硬件层面: Tensor Core -> NVLink -> InfiniBand

KV Cache显存估算实例

LLaMA-2-70B(80层GQA-8,d_k=128,BF16):

  • 每token KV: 2 * 80 * 8 * 128 * 2 bytes = 320KB
  • 4K序列: 320KB * 4096 = 1.28GB
  • batch=32: 1.28GB * 32 = 41GB(超过单卡80GB的一半!)

所以长序列+大batch下KV Cache是显存大户,GQA/MLA和量化KV Cache至关重要。

8.5 面试考点精讲(5题)

Q1. ⭐⭐ [高频] FlashAttention的核心思想是什么?它为什么能加速Attention计算?

一句话秒答:FlashAttention通过tiling(分块)把Attention计算搬到GPU的SRAM中完成,大幅减少HBM读写次数------是IO优化而非算法复杂度优化。

展开来说:GPU有两级存储:HBM(高带宽但慢,80GB)和SRAM(极快但小,20MB)。标准Attention需要把n*n的注意力矩阵在HBM和SRAM之间搬运多次(写出S矩阵、读回做softmax、写出P矩阵、读回与V相乘)。

FlashAttention把Q/K/V分成小块(tile),每块在SRAM中完成所有计算(QK^T -> softmax -> *V),只把最终结果写回HBM。用Online Softmax解决了分块softmax的问题。

结果:Attention的FLOPs没变(还是O(n2d)),但HBM访问量从O(n2+nd)降到O(n^2d/M)(M是SRAM大小),实际速度快2-4x。

面试加分:提一句"FlashAttention还有一个bonus:因为不需要存储n*n的注意力矩阵,空间复杂度从O(n^2)降到O(n)------这对长序列至关重要。"


Q2. ⭐⭐ [字节] PagedAttention是什么?它解决了什么问题?

一句话秒答:PagedAttention把KV Cache切成固定大小的"页"按需分配,解决了不同长度请求共享GPU显存时的碎片问题------类似OS的虚拟内存。

展开来说:没有PagedAttention时,每个请求预分配max_seq_len的KV Cache空间。如果max=8K但实际只用了500个token,7500个位置的显存就浪费了。多请求时碎片率可达40-50%。

PagedAttention的做法:把KV Cache切成固定大小的page(如每页16个token),用page table管理。新生成一个token时按需分配新页,序列结束时释放页。不同请求可以共享空同的物理页。

vLLM就是基于PagedAttention构建的推理引擎,KV Cache利用率从50-60%提升到>95%。

面试加分:提一句"PagedAttention还支持copy-on-write------parallel sampling时多个生成分支可以共享prompt的KV Cache页,只在产生分歧时才复制。这对beam search和parallel decoding很重要。"


Q3. ⭐⭐ [腾讯] Sliding Window Attention的原理是什么?Mistral是怎么用的?

一句话秒答:每个token只attend到最近w个token。通过L层堆叠,有效感受野扩展到w*L,理论上覆盖任意长度。

展开来说:Mistral-7B用w=4096的滑动窗口。32层后有效感受野=4096*32=131072。这意味着虽然每层只看4K的局部窗口,但信息可以通过层层传递覆盖到128K+的范围。

好处:KV Cache大小固定为w(不随序列长度增长),推理显存可控。对比标准Attention的KV Cache随n线性增长。

限制:信息传递是"间接的"------远距离信息需要通过多层中转,不如full attention的直接连接准确。所以Mistral在某些需要精确长距离引用的任务上不如full attention的模型。


Q4. ⭐⭐⭐ [高频] 线性Attention的原理是什么?为什么它能把复杂度从O(n^2)降到O(n)?

一句话秒答 :线性Attention用kernel trick把softmax(QKT)V分解为phi(Q)(phi(K)T V)------改变矩阵乘法顺序,从先算nn降到先算dd。

展开来说 :标准Attention计算顺序是(softmax(QK^T))V,先算nn的注意力矩阵。

线性Attention的关键:找一个特征映射phi,使得exp(q*k)约等于phi(q)*phi(k)。然后:

Attention = phi(Q) * (phi(K)^T * V)

先算phi(K)^T * V得到dd矩阵(O(n dd)),再左乘phi(Q)得到n d的输出(O(ndd))。总复杂度O(n*d^2)------对n线性!

代价:phi的近似质量直接决定模型质量。实践中线性Attention的效果一般不如标准Attention,尤其在需要"尖锐注意力"(sharp attention)的任务上。

最新进展:Mamba/State Space Model走了另一条路------用线性RNN+选择机制替代Attention,在长序列建模上效果不错。


Q5. ⭐⭐⭐ [阿里] 比较FlashAttention、稀疏Attention、线性Attention三种优化方式的本质区别。

一句话秒答:FlashAttention是IO优化(不改算法,改计算方式),稀疏Attention是算法优化(只算部分注意力对),线性Attention是数学优化(用核近似降低复杂度)。

展开来说

维度 FlashAttention 稀疏Attention 线性Attention
算法复杂度 O(n^2d) 不变 O(nwd) 或 O(n*sqrt(n)*d) O(n*d^2)
空间复杂度 O(n) O(n*w) O(n*d)
质量损失 无(数值等价) 有(信息截断) 有(核近似误差)
适用场景 通用加速 已知局部性强的任务 超长序列、RNN替代
工程难度 高(需要CUDA kernel) 中(改mask即可) 低(纯数学变换)

实际工程中最常用的组合:FlashAttention(必开)+ GQA/MLA(减少KV头)+ 量化KV Cache。稀疏和线性Attention在特定场景有用但不是主流。

面试加分:提一句"Mamba2(2024)尝试统一Attention和SSM,提出了SSD(Structured State Space Duality)框架,证明线性Attention和状态空间模型在数学上是等价的。这可能是未来序列建模的方向。"

Q6. ⭐⭐⭐ [高频] KV Cache显存怎么算?70B模型、batch=8、seq_len=32K需要多少KV Cache显存?

这是大厂面试的高频计算题,考察对Transformer推理机制的理解。

公式:KV Cache显存 = 2 x num_layers x num_kv_heads x head_dim x seq_len x batch_size x dtype_bytes

以LLaMA-70B为例

  • num_layers = 80, num_kv_heads = 8 (GQA), head_dim = 128, dtype = FP16 (2 bytes)
  • 单条请求:2 x 80 x 8 x 128 x 32768 x 2 = 10 GB
  • batch=8:10 x 8 = 80 GB -- 仅KV Cache就占满一张A100

面试追问与应答

  • "怎么减少KV Cache?" --> GQA/MQA减少kv_heads;量化KV Cache到INT8/FP8(减半);PagedAttention减少碎片浪费
  • "KV Cache和模型权重谁占显存多?" --> 长序列+大batch时KV Cache远超权重(70B FP16权重=140GB,但KV Cache可达数百GB)
  • "为什么说KV Cache是推理的第一瓶颈?" --> 因为它随seq_len线性增长、随batch线性增长,是限制吞吐量和最大序列长度的直接因素

本章总结

模块 核心考点 必记公式/概念
模块一 Self-Attention Attention(Q,K,V)=softmax(QK^T/sqrt(d_k))V
模块二 MHA/MQA/GQA/MLA KV Cache大小 = 2L gd_kn*B
模块三 位置编码 RoPE旋转实现相对位置,YaRN外推
模块四 FFN/SwiGLU SwiGLU = Swish(xW1) * xW3, d_ff=8d/3
模块五 LayerNorm/RMSNorm Pre-Norm更稳定,RMSNorm更快
模块六 架构选型 Decoder-only主流,参数量约12Ld^2
模块七 训练技巧 Warmup+Cosine, BF16, Gradient Clip
模块八 高效Attention FlashAttention(IO优化), PagedAttention(内存管理), KV Cache显存计算

全章共覆盖 41道 面试真题,涵盖概念理解、数学推导、代码实现、工程实践四个层次。

相关推荐
浅念-2 小时前
C++ STL stack、queue 与容器适配器详解
开发语言·c++·经验分享·笔记·学习·面试
nudt_qxx2 小时前
讲透Transformer(五):Self-Attention与KV Cache的深度解析——从原理到实现
人工智能·深度学习·transformer
TracyCoder1233 小时前
LeetCode Hot100(57/100)——5. 最长回文子串
算法·leetcode·职场和发展
香芋Yu3 小时前
【2026大模型面试圣经】(2)主流大模型架构全景 | GPT/LLaMA/DeepSeek/Qwen深度对比
gpt·面试·架构
我命由我123453 小时前
Photoshop - Photoshop 工具栏(68)内容感知移动工具
学习·ui·职场和发展·求职招聘·职场发展·学习方法·photoshop
一个努力编程人4 小时前
NLP领域————Transformer
人工智能·自然语言处理·transformer
indexsunny4 小时前
互联网大厂Java面试实战:Spring Boot与微服务在电商场景的应用
java·spring boot·微服务·面试·kafka·prometheus·电商
重生之后端学习4 小时前
39. 组合总和
java·数据结构·算法·职场和发展·深度优先
UrbanJazzerati4 小时前
Python Logging库完全指南:从小白到熟练
后端·面试