LLMs-from-scratch:多头潜在注意力(MLA)

代码链接:https://github.com/rasbt/LLMs-from-scratch/blob/main/ch04/05_mla/README.md

多头潜在注意力(MLA)

这份扩展材料展示了在常规多头注意力(MHA)之上使用多头潜在注意力(MLA)时的内存节省效果。

简介

.../04_gqa 中,我们讨论了分组查询注意力(GQA)作为提升 MHA 计算效率的一种替代方案。多项消融研究(例如 原始 GQA 论文Llama 2 论文)显示,其在大语言模型的建模性能上与标准 MHA 相当。

现在,多头潜在注意力(MLA),被应用于 DeepSeek V2、V3 和 R1,提供了另一种与 KV 缓存非常契合的内存节省策略。与 GQA 通过共享键/值头不同,MLA 在将键和值张量存入 KV 缓存之前,先将它们压缩到更低维的空间。

在推理阶段,这些压缩后的张量会在使用前投影回原始大小,如下图所示。这会增加一次矩阵乘法,但可降低内存占用。

(顺带一提,查询也会被压缩,但仅在训练期间,推理时不会。)

另外,正如前文所述,MLA 并非 DeepSeek V3 的新内容,其前身 DeepSeek V2 也使用并引入了它。此外,V2 论文包含一些有趣的消融研究,或许能解释为什么 DeepSeek 团队选择 MLA 而非 GQA(见下图)。

如上图所示,GQA 的表现似乎劣于 MHA,而 MLA 的建模性能优于 MHA,这很可能是 DeepSeek 团队选择 MLA 而非 GQA 的原因。(如果还能看到 MLA 与 GQA 在"每个 token 的 KV 缓存"节省上的比较就更有趣了!)

本节小结:在进入下一架构组件前,MLA 是一种巧妙的技巧,可以降低 KV 缓存的内存占用,并在建模性能上略优于 MHA。

MLA 内存节省

内存节省主要体现在 KV 存储上。我们可以用如下公式计算 KV 存储大小:

bytes ≈ batch_size × seqlen × n_layers × latent_dim × bytes_per_elem

相比之下,MHA 的 KV 缓存内存按如下方式计算:

bytes ≈ batch_size × seqlen × n_layers × embed_dim × 2 (K,V) × bytes_per_elem

这意味着,在 MLA 中,我们将"embed_dim × 2(K、V)"缩减为"latent_dim",因为我们只存储压缩后的潜在表示,而不是完整的键和值向量(如前图所示)。

你可以使用本目录中的脚本 <memory_estimator_mla.py>,针对不同模型配置计算使用 MLA 相较于 MHA 能节省多少内存:

bash 复制代码
➜ uv run memory_estimator_mla.py \
  --context_length 8192 \
  --emb_dim 2048 \
  --n_heads 24 \
  --n_layers 48 \
  --n_kv_groups 4 \
  --batch_size 1 \
  --dtype bf16 \
  --latent_dim 1024
==== Config ====
context_length   : 8192
emb_dim          : 2048
n_heads          : 24
n_layers         : 48
n_kv_groups      : 4
latent_dim       : 1024
batch_size       : 1
dtype            : bf16 (2 Bytes/elem)
head_dim         : 86
GQA n_kv_heads   : 6

==== KV-cache totals across all layers ====
MHA total KV cache  : 3.25 GB
GQA total KV cache  : 0.81 GB
MLA total KV cache  : 0.81 GB
Ratio (MHA / GQA)   : 4.00x
Savings (GQA vs MHA): 75.00%
Ratio (MHA / MLA)   : 4.03x
Savings (MLA vs MHA): 75.19%

注意,上述压缩(--emb_dim 2048 -> latent_dim 1024)实现了与 GQA 相近的节省效果。实际使用中,压缩程度是需要仔细研究的超参数,因为将 latent_dim 设得过小会对建模性能产生负面影响(类似于在 GQA 中选择过多的 n_kv_groups)。

下图进一步展示了在不同 latent_dim 取值下,随上下文长度变化时 MLA 相较于 MHA 的节省:

你可以通过执行 uv run plot_memory_estimates_mla.py 复现该图。

MLA 代码示例

本目录中的 <gpt_with_kv_mha.py> 与 <gpt_with_kv_mla.py> 脚本提供了在 GPT 模型实现背景下比较 MHA 与 MLA 内存占用的实操示例。

此处的 MLA 代码受该实现启发:https://huggingface.co/bird-of-paradise/deepseek-mla

需要注意,MLA 也可以与 GQA 结合使用,但为简化起见,这里没有这样做。(目前我也不了解有知名的 LLM 同时采用两者。)

另请注意,该模型未经过训练,因此会生成无意义的文本。不过,你可以在第 5-7 章中将其作为标准 GPT 模型的替代实现并进行训练。

最后,该实现使用了 另一扩展章节 中解释的 KV 缓存,因此内存节省更为显著。

bash 复制代码
uv run gpt_with_kv_mha.py \
--max_new_tokens 32768 \
--n_heads 24 \
--n_layers 12 \
--emb_dim 768

...

Time: 453.81 sec
72 tokens/sec
Max memory allocated: 1.54 GB
bash 复制代码
uv run gpt_with_kv_mla.py \
--max_new_tokens 32768 \
--n_heads 24 \
--n_layers 12 \
--emb_dim 768 \
--latent_dim 192 # 4x compression

...

Time: 487.21 sec
67 tokens/sec
Max memory allocated: 0.68 GB

之所以没有像上面的图表那样看到更大的节省,原因有二:

  1. 为了让模型在合理时间内完成生成,我使用了较小的配置。
  2. 更重要的是,我们这里观察的是整个模型,而不仅仅是注意力机制;模型中的全连接层占据了大部分内存(这本身值得另行分析)。
相关推荐
Mintopia8 小时前
🤖 2025 年的人类还需要 “Prompt 工程师” 吗?
人工智能·llm·aigc
Mintopia8 小时前
意图驱动编程(Intent-Driven Programming)
人工智能·llm·aigc
想用offer打牌9 小时前
逃出结构化思维:从向量,向量数据库到RAG
后端·架构·llm
想用offer打牌9 小时前
Reasoning + Acting: ReAct范式与ReAct Agent
人工智能·后端·llm
爱可生开源社区10 小时前
在数据库迁移中,如何让 AI 真正“可用、可信、可落地”?
数据库·sql·llm
Elwin Wong13 小时前
关于熵的一些概念及其计算
人工智能·大模型·llm
hzp6661 天前
新兴存储全景与未来架构走向
大数据·大模型·llm·aigc·数据存储
EdisonZhou1 天前
MAF快速入门(8)条件路由工作流
llm·aigc·agent·.net core
暴风鱼划水1 天前
大型语言模型(入门篇)B
人工智能·语言模型·大模型·llm
xhxxx1 天前
别再让 AI 自由发挥了!用 LangChain + Zod 强制它输出合法 JSON
前端·langchain·llm