解密Prompt系列70. 从 MLA 到 CSA,聊聊大模型 Attention 的“瘦身”与“闪送”

最近做了太多应用层的事,底层原理反而搁置了。趁着放假,围绕 Attention 机制,聊聊这两年注意力架构的技术演化路线。全文覆盖三个方向,每个方向都是上一个方向的"接力棒":

  • KV Cache 压缩类:MQA → GQA → MLA(每个 token 的 KV 变细)
  • 推理效率优化类:Flash Attention、Paged Attention(让 GPU 跑得更满)
  • 长文本优化类:NSA → DSA → CSA + HCA(需要 attend 的 token 变少)

一、MHA:从古老的Multi Head Attention说起

哈哈,LLM 技术一天,往往人间一年。说是"古老",掐指一算,Transformer 论文正式发表也不过才 10 年不到。但就在这短短几年里,MHA 已经从"标准答案"变成了"需要优化的起点"。

标准 MHA(Multi-Head Attention)的逻辑很直接:为每个注意力头独立计算 Q、K、V 投影,每个 token 都和所有历史 token 建立关联。

问题也很直接:

问题 说明
KV Cache内存爆炸 长上下文(128K+)或较大batch时,KV Cache迅速成为GPU显存瓶颈
推理带宽受限 每生成一个新token,需要从HBM加载整个历史KV Cache到SRAM进行attention计算,带宽需求随heads和seq_len线性增长
长上下文计算代价高 估算在解码 64k 长度上下文时,基于 softmax 的 attention 计算占总延迟的 70--80%,凸显出对更高效 attention 机制的迫切需求

核心矛盾只有一个:KV 太胖,历史太长,搬不动。

优化方向也随之分化成两条路:要么让每个 token 的 KV 变瘦(纵向压缩);要么让需要 attend 的 token 变少(横向稀疏)。

二、压缩 KV Cache:让每个 token 的 KV 变瘦

2.1 MQA:激进共享

核心设计:多个 Query head,只保留一个 K head 和一个 V head。概念就是参数共享

MQA(Multi-Query Attention)是第一个掀桌子的方案。如果 32 个 Query head 都用同一套 KV,KV Cache 直接缩小 32 倍。

代价是什么?精度有损失。把"32 套独立的视角"强行合并成"1 套公共视角",不同 head 之间的多样性丢失了。所以 MQA 一般只用在对速度极度敏感、容忍一定精度下降的场景。

2.2 GQA:折中方案 ⭐

核心设计:引入差值概念,提供介于MQA和MHA之间,在质量与速度之间提供了良好的折中案

GQA(Grouped Query Attention)引入了一个参数 G(分组数):

  • G = num_heads:退化为 MHA,每个 Q head 独享 KV
  • G = 1:退化为 MQA,所有 Q head 共用一套 KV
  • 1 < G < num_heads:GQA 的甜蜜区间,质量与速度的折中

主流模型的压缩幅度一般在4~8倍

python 复制代码
def GQA(X, W_Q, W_K, W_V, W_O, num_q_heads, num_kv_heads, d_k):
    """
    num_q_heads: Query 头数 (e.g., 32)
    num_kv_heads: KV 头数 (e.g., 8),即分组数 G
    每组 num_q_heads//num_kv_heads 个 Q 共享一对 K/V
    """
    B, T, D = X.shape
    groups = num_q_heads // num_kv_heads  # 每个KV head服务的Q head数
    
    # Q 有 num_q_heads 个
    Q = (X @ W_Q).view(B, T, num_q_heads, d_k).transpose(1, 2)     # [B, Hq, T, d_k]
    
    # K/V 只有 num_kv_heads 个
    K = (X @ W_K).view(B, T, num_kv_heads, d_k).transpose(1, 2)    # [B, Hkv, T, d_k]
    V = (X @ W_V).view(B, T, num_kv_heads, d_k).transpose(1, 2)    # [B, Hkv, T, d_k]
    
    # 将 K/V 重复扩展以匹配 Q 的头数 (或用 reshape 广播)
    K = K.repeat_interleave(groups, dim=1)  # [B, Hq, T, d_k]
    V = V.repeat_interleave(groups, dim=1)
    
    scores = (Q @ K.transpose(-2, -1)) / sqrt(d_k)  # [B, Hq, T, T]
    attn = softmax(scores, dim=-1)
    out = attn @ V  # [B, Hq, T, d_k]
    
    out = out.transpose(1, 2).contiguous().view(B, T, D)
    return out @ W_O

2.3 MLA:低秩压缩,DeepSeek 的独门绝技 ⭐

核心设计:不通过参数共享,而是通过矩阵低秩压缩来缩小 KV。

说到低秩压缩,熟悉微调的同学一定不陌生------LoRA 背后就是同一个思想。MLA引入了两个投影矩阵,通过在存储KV时对KV进行下投影压缩到更小的矩阵存储,再在推理时通过上投影矩阵恢复到原有维度,来降低KV Cache的存储量级。

但 MLA 面临一个额外的挑战:KV 里有 RoPE 位置编码。位置编码必须感知序列顺序,不能被低秩压缩"抹平"。所以 MLA 把 K 拆成两部分:

  • NoPE 部分(无位置编码):KV共同走低秩压缩通道,存入压缩潜变量 c_KV
  • RoPE 部分(有位置编码):QK各保留一部分维度经过ROPE处理以保持对位置变量的敏感性不压缩单独保留

最终推理时通过把RoPE和NoPE进行拼接做Attention运算。

python 复制代码
def MLA(X, params, num_heads, d_model, d_c, d_r):
    """
    d_c: KV 压缩后的潜在维度 (远小于 num_heads * d_k)
    d_r: RoPE 分量维度
    """
    B, T, D = X.shape
    
    # === Step 1: KV 压缩(核心!)===
    # 下投影:将 X 压缩为低秩潜变量 c_kv
    c_kv = X @ params['W_DKV']          # [B, T, d_c]  d_c << num_heads*d_k
    # 缓存 c_kv 即可,不需缓存完整 K/V!!
    
    # 从潜变量重建 K/V 的 NoPE 部分(推理时从 c_kv 上投影)
    K_nope = c_kv @ params['W_UK']      # [B, T, num_heads * d_k]
    V      = c_kv @ params['W_UV']      # [B, T, num_heads * d_v]
    
    # === Step 2: RoPE 部分(位置信息保留)===
    # K 的 RoPE 分量,单独从 X 投影(不走压缩路径)
    K_rope = RoPE(X @ params['W_KR'])   # [B, T, d_r]  共享,不按头区分
    
    # Q 也分两部分(低秩压缩 + RoPE)
    c_q = X @ params['W_DQ']            # Q 低秩压缩
    Q_nope = c_q @ params['W_UQ']       # [B, T, num_heads * d_k]
    Q_rope = RoPE(c_q @ params['W_QR']) # [B, T, num_heads * d_r]
    
    # === Step 3: 拼合完整 Q 和 K ===
    # 每个 head 的 Q = [Q_nope_h ; Q_rope_h]
    # 每个 head 的 K = [K_nope_h ; K_rope]  (K_rope 各 head 共享)
    Q = concat([Q_nope, Q_rope], dim=-1)  # [B, T, num_heads, d_k + d_r]
    K = concat([K_nope, K_rope.expand(num_heads)], dim=-1)
    
    # Reshape
    Q = Q.view(B, T, num_heads, -1).transpose(1, 2)
    K = K.view(B, T, num_heads, -1).transpose(1, 2)
    V = V.view(B, T, num_heads, d_v).transpose(1, 2)
    
    # === Step 4: 标准 Attention ===
    d_total = Q.shape[-1]
    scores = (Q @ K.transpose(-2, -1)) / sqrt(d_total)
    attn = softmax(scores, dim=-1)
    out = attn @ V
    
    # === Step 5: 输出 ===
    # 优化技巧:W_UV 可吸收进 W_O,推理时无需显式还原 V
    out = out.transpose(1, 2).view(B, T, -1)
    return out @ params['W_O']

📌 关三种架构的 KV Cache 对比:

Attention KV cache token
MHA 2 × n_h × d_h (per layer)
GQA 2 × n_g × d_h (n_g << n_h)
MLA d_c + d_r (d_c << n_h × d_h)

三、 优化推理效率:让 GPU 更忙,而不是更闲

有了更小的 KV,接下来的问题是:怎么让 GPU 把这些 KV 用得更高效?

这里有两个层次的优化:

  • 计算效率:让每一次注意力运算更快 → Flash Attention
  • 存储效率:让显存的 KV Cache 分配更合理 → Paged Attention

两者不是竞争关系,而是计算协议和存储协议的配合,现代推理框架(如 vLLM)同时使用两者。

3.1 Flash Attention:小锅快炒,别让 GPU 等数据

问题

要理解 Flash Attention,必须先知道 GPU 有一个经常被忽视的内存层级:

GPU 的算力极强,但数据必须先从 HBM(仓库)搬到 SRAM(厨房)才能计算。大量时间不是浪费在"算"上,而是浪费在"等搬运"上。

标准 Attention 有多糟糕?它要把 S×S 的注意力矩阵在 HBM 和 SRAM 之间反复搬运:

复制代码
① 计算 Q×K^T → 写入 HBM(S×S矩阵)
② 读回 HBM → 做 Softmax → 写回 HBM(再存一次)
③ 读回 HBM → 乘 V → 写出最终结果

解决方案:Tiling+ Online Softmax

Flash Attention 的核心思路是小锅快炒:每次只从 HBM 取一小块 Q、K、V,放进 SRAM 里,把打分、Softmax、加权求和全都在 SRAM 里一次性完成,只把最终结果写回 HBM。

关键数学技巧是 Online Softmax:Softmax 本来需要看全所有分数才能归一化,但通过维护两个滚动变量(当前最大值 m 和累加和 l),可以流式地更新归一化结果,数学上与完整的 Softmax 完全等价。

效果对比

指标 标准 Attention Flash Attention
HBM 读写量 O(S²) O(S)(线性!)
中间矩阵显存占用 O(S²) O(1)(不需要写出!)

Paged Attention:停车场不必都停豪华车位

FlashAttention解决的是"计算"效率,PagedAttention解决的是"存储管理"效率。

问题

在大模型生成对话时,KV Cache(键值缓存)会随着字数增加而增长。系统不知道你会聊多长,所以通常会为每个用户预分配一块连续的大空间。

那问题就来了,该预先分配多少内存用于Kv Cache呢?分配少了不够用,分配多了浪费。于是就会出现三种常见的浪费现象

  • 预留浪费: 我预留了能存 2048 个词的空间,结果你只说了 10 个词,剩下的 1900 多个位置全空着,别人也用不了。
  • 碎片浪费: 有些空间零零散散,虽然加起来很大,但因为不连续,新来的长请求放不进去。 这导致显存利用率通常只有 20% - 40%,非常浪费钱。
  • 前缀浪费:因为给每个请求都单独分配kv cache空间,所以当多个请求有相同前缀时,也无法共享,而是要重复N份相同的kv Cache存储。

解决方案:虚拟内存,按需分配

PagedAttention 借用了计算机系统的**"虚拟内存"**概念。

text 复制代码
操作系统:物理内存 → 固定大小的"页" → 页表映射虚拟地址
                    ↕ 一一对应 ↕
Paged Attn:KV Cache → 固定大小的"Block"→ Block Table映射逻辑位置

具体机制:

  1. 显存切块:把全部显存预先切成固定大小的 Block(比如每块存 16 个 token 的 KV)
  2. 按需分配:一开始只分配 1 个 Block,用满再申请下一个,不用就立刻释放
  3. Block Table:维护每个请求的逻辑位置→物理 Block 的映射表,物理上可以不连续
  4. 前缀共享:100 个请求的 System Prompt 物理上只存 1 份,Block Table 里各自指向同一个物理 Block(Copy-on-Write)

引入Paged Attention平均预估throughtput能提升8倍以上

四、长文本稀疏注意力:让需要 Attend 的 Token 变少

前面 MLA、GQA 优化的是"每个 token 的 KV 有多厚"(纵向压缩)。但面对 1M token 的超长上下文,即便每个 token 的 KV 只有 70KB,整个 KV Cache 依然高达 70GB!

更棘手的是计算量:标准 Attention 的复杂度是 O(L²),L=1M 时就是 10¹² 次运算,根本无法承受。

所以我们需要另一个维度的压缩:横向稀疏化,让每个 token 只 attend 最相关的一小部分历史,而不是全部。

4.1 NSA:三镜头摄影系统,DeepSeek 稀疏注意力的起点

NSA是DeepSeek最早提出的,原生可训练的稀疏注意力机制。

NSA 的核心思想可以用一个摄影类比来理解:

📸 三镜头摄影系统

  • 超广角镜(压缩分支):捕捉整体场景,分辨率低但覆盖全
  • 长焦镜(选择分支):锁定远处关键目标,精准清晰
  • 标准镜(滑动窗口):拍清近处细节,最高分辨率

具体的实现方案如下

  1. 分支一:压缩注意力(超广角------宏观感知)

把历史序列按 block_size 分组,每组用一个可学习的 MLP 压缩成 1 个向量。Query 和这些摘要向量做注意力,快速扫视全局。

复制代码
原始序列: [t1 t2 t3 t4] [t5 t6 t7 t8] ... [t_n-3 t_n-2 t_n-1 t_n]
                 ↓ MLP压缩              ↓                    ↓
摘要向量:      [cmp_1]              [cmp_2]   ...        [cmp_k]

计算量: 从 O(L) → O(L/block_size)
  1. 分支二:选择注意力(长焦------精准检索)

    压缩分支已经算过"每个 block 有多相关"了。直接复用这些分数,Top-K 选出最相关的块,回到原始未压缩的 KV 中取出精细内容,再做一次精确注意力。

  2. 滑动窗口(标准镜------局部连贯)

取最近 win_size 个 token 的完整 KV,保证模型对近期上下文的精细理解不受稀疏化影响。

最终输出的是以上三路分支拼接后,通过门控进行加权的结果,让模型学会选择究竟在不同情况下应该使用哪个分支。

NSA其实研究性质的论文,而后面DeepSeek在V3.2,V4两个最新的模型版本中,正式把NSA工程化落地。

4.2 DSA:把 NSA 装进工厂,V3.2 的工程验证

NSA 是方向,DSA(DeepSeek Sparse Attention)是第一次工程落地,随 DeepSeek-V3.2 发布。

DSA 的核心创新叫 Lightning Indexer------一个用低精度(FP8/FP4)、低维度(远小于 d_model)的 Q/K 投影来快速打分的轻量级索引器

python 复制代码
# 核心:Lightning Indexer 实现
class LightningIndexer(nn.Module):
    def forward(self, x):
        # 1. 低维投影,d_indexer 远小于 d_model
        q_idx = linear_fp8(x, W_qi) 
        k_idx = linear_fp8(x, W_ki) 
        
        # 2. Block 聚合 (工程上通常在投影后做 Pool)
        k_idx_blocks = pool1d(k_idx, kernel=block_size, stride=block_size) 
        
        # 3. 计算相关性分数 (此处不使用 Softmax,使用 ReLU 提高硬件吞吐)
        scores = relu(q_idx @ k_idx_blocks.T) 
        return scores

# 4. Top-K 掩码生成

关键工程细节:

  1. 粒度问题:DSA 论文层面是 token 粒度打分,但由于叠加了 MLA 的 MQA 模式(同一 KV entry 被所有 Q head 共享),工程 kernel 层面强制变成了 Block 粒度选择,才能保证计算效率。
  2. ReLU + Dot:softmax的打分机制同样被简化成直接QK乘积过Relu,使用FP8或者FP4低精度
  3. 对齐训练:DSA 不是从头训练,而是在已有稠密模型上做后训练。为了让 Indexer "学会"稠密 Attention 的打分直觉,引入了KL 对齐损失,让 Lightning Indexer 的输出分布逐渐逼近完整 Attention 的权重分布。

DSA更像是一个实验性作品,验证NSA提出的稀疏注意力是能跑通的。

4.3 V4 的 Hybrid Attention:重新排列组合

随后DeepSeek V4推出了Hybrid Attention,进一步完善了NSA的设计,通过把粗粒度记忆、细粒度检索、局部窗口,这三种机制重新排列组合。把CSA(精准稀疏+窗口)和HCA(粗粒度全局+窗口)这两种注意力在层与层之间交替使用, 来实现更优的长文本推理效果。

CSA:先压缩 4×,再精准检索

DeepSeek V4推出的CSA在DSA的基础上,主要增加了三个改进

1. 改进一(最核心):先压缩序列,再做稀疏选择

text 复制代码
DSA: [1M token] → Indexer 打分 → Top-K 选择
CSA: [1M token] → 4× 压缩 → [250K 条目] → Indexer 打分 → Top-K 选择

Indexer 的搜索空间直接缩小 4 倍!
KV Cache 的存储量也缩小 4 倍!

2. 改进二:双流重叠压缩,消除"块切割"边界伪影

这是 CSA 最精妙的设计。固定分块压缩有一个天然缺陷:如果"深度学习"这个词组恰好被切成 深度学习 两个 Block,后续稀疏选择可能只选了其中一个,关键信息就断掉了。

CSA 的解法是引入两个位置重叠的KV流:

复制代码
流 a 和流 b 不是简单的"错位偏移"!
而是对有重叠的两段隐状态 H,用两套独立的投影矩阵产生两套 KV 表示:

C_a = H @ W_aKV   +   位置偏置 B_a   ← [0,1,2,3]
C_b = H @ W_bKV   +   位置偏置 B_b   ← [2,3,4,5]

关键:两个流在 Block 内做跨流联合 Softmax 归一化
模型自动学会:这个位置是流 a 表示得更好,还是流 b?

相似的思路你在RAG的overlap-chunking中也能看到。

3. 改进三:保留滑动窗口分支,补回局部精细细节

压缩后的远端历史 + 未压缩的近端 token,两者拼接后做最终注意力。

HCA:压缩 128×,换取廉价全局感知

HCA 的思路更激进:既然 CSA 已经处理"精准远端检索",那 HCA 就专门负责"粗粒度全局感知"。也就是NSA压缩压缩注意力的实现。

  • 机制压缩比:把CSA1:4的压缩比,放大到1:128
  • 稠密注意力:不再使用稀疏注意力,而是通过稠密注意力,学习粗粒度的全局信息。实现上也没有使用CSA的双流,毕竟压缩比例这么大,追求边界是否清晰已经没有意义。
  • 拼接token粒度滑动窗口注意力

层级交替策略:三种角色各司其职

对于 DeepSeek-V4-Flash,前两层使用纯滑动窗口注意力,其余层交替使用 CSA 和 HCA;对于 DeepSeek-V4-Pro,前两层使用 HCA,其余层交替使用 CSA 和 HCA。

在 1M token 上下文设置下,DeepSeek-V4-Pro 仅需 DeepSeek-V3.2 单 token 推理 FLOPs 的 27% 和 10% 的 KV Cache。


这一篇就先聊到这里啦~

相关推荐
在路上走着走着15 天前
Prompt Engineering 入门指南:从原理到上手
人工智能·prompt
coft15 天前
Loop Engineering — 从“写 prompt“到“设计循环“,AI Agent 的下一次进化
人工智能·prompt
CoLiuRs15 天前
从 Prompt 到 Loop:AI 工程到底在卷什么
人工智能·prompt
AI 小老六15 天前
GEPA 架构拆解:让 Prompt 和 Skill 优化不靠玄学
数据库·人工智能·ai·架构·开源·prompt
凯丨15 天前
从写 Prompt 到Loop Engineering:AI 编程的下一次跃迁
prompt
奋飛15 天前
从 Prompt 到 Agent:LangChain 究竟解决了什么问题
ai·langchain·prompt·agent
沪漂阿龙16 天前
Context Engineering:比 Prompt Engineering 更重要的上下文工程
人工智能·langchain·prompt
猿人谷16 天前
从 Prompt Engineering 到 Loop Engineering:AI 编程正在进入“闭环工程”时代
大数据·人工智能·prompt
取个鸣字真的难16 天前
Image2 生成 PPT 的最后分水岭:Prompt
人工智能·prompt·powerpoint