Contents
- Part 1: Multi-head Latent Attention (MLA)
- Part 2: DeepSeekMoE with Auxiliary-Loss-Free Load Balancing
- Part 3: Multi-Token Prediction (MTP)
- References
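Part 1: Multi-head Latent Attention (MLA)
1. Introduction
The KV cache of standard multi-head attention (MHA) is a central bottleneck in large language model inference. For a model with $n_h$ attention heads of dimension $d_h$, each token must cache $2 n_h d_h$ values per layer. Once the sequence grows to tens of thousands of tokens, the KV cache can take more memory than the model weights themselves.
MLA (Multi-head Latent Attention, DeepSeek-V2, 2024) compresses the KV cache through a low-rank factorization: keys and values are projected into a latent vector far smaller than the original dimension, and an "absorption trick" avoids explicit decompression at inference time. With the DeepSeek-V2 dimensions this gives:
- A KV cache about 57× smaller: from $2 n_h d_h = 32768$ values per token per layer down to $d_c + d_n = 576$
- No change in the computed result: after matrix absorption the output is mathematically identical to MLA with explicit up-projection (it is not identical to an unconstrained MHA, because keys and values are restricted to rank $d_c$)
- Compatibility with positional encoding: a decoupled RoPE branch keeps rotary position embeddings usable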
2. Background: the KV cache in attention
2.1 Standard multi-head attention
For the $i$-th attention head, standard MHA computes:
$$\mathbf{q}_t^{(i)} = \mathbf{W}_Q^{(i)} \mathbf{h}_t, \quad \mathbf{k}_t^{(i)} = \mathbf{W}_K^{(i)} \mathbf{h}_t, \quad \mathbf{v}_t^{(i)} = \mathbf{W}_V^{(i)} \mathbf{h}_t$$
$$\text{Attn}^{(i)}(\mathbf{q}_t, \mathbf{k}_{\leq t}, \mathbf{v}_{\leq t}) = \text{softmax}\left(\frac{(\mathbf{q}_t^{(i)})^T \mathbf{k}_{\leq t}^{(i)}}{\sqrt{d_h}}\right) \mathbf{v}_{\leq t}^{(i)}$$
where $\mathbf{h}_t \in \mathbb{R}^d$ is the hidden state of the $t$-th token.
2.2 Memory analysis of the KV cache
In autoregressive decoding, every newly generated token appends its $\mathbf{k}_t, \mathbf{v}_t$ to the cache of each layer.
| Model | $n_h$ | $d_h$ | $n_h d_h$ | KV cache per token per layer (values) |
|---|---|---|---|---|
| LLaMA-2-7B | 32 | 128 | 4096 | $2 \times 32 \times 128 = 8192$ |
| LLaMA-2-70B | 64 | 128 | 8192 | $2 \times 64 \times 128 = 16384$ |
| DeepSeek-V2 | 128 | 128 | 16384 | $2 \times 128 \times 128 = 32768$ |
The table assumes plain MHA in every row. The released LLaMA-2-70B uses GQA with 8 KV heads, so its real cache is smaller, and for DeepSeek-V2 the product $n_h d_h = 16384$ is not the model's hidden size.
With the 128 heads of DeepSeek-V2, plain MHA would cache 32768 values per token per layer. For a 64-layer model at a sequence length of 128K in FP16 (2 bytes per value):
$$\text{KV cache total} = 64 \times 128000 \times 32768 \times 2 \text{ bytes} \approx 537 \text{ GB} \ (\approx 500 \text{ GiB})$$
This is far beyond the memory of a single GPU.
2.3 Existing KV cache compression schemes
GQA (Grouped-Query Attention): several query heads share one group of KV heads.
$$\text{KV cache} = 2 \times n_g \times d_h, \quad n_g < n_h$$
MQA (Multi-Query Attention): all query heads share a single KV head.
$$\text{KV cache} = 2 \times d_h$$
Limitation: GQA and MQA compress by reducing the number of KV heads, so the compression ratio is tied to the number of groups ($n_h / n_g$), and sharing KV across heads reduces expressiveness.
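As a rough illustration of the head sharing described above, the sketch below caches only $n_g$ key heads and expands them to $n_h$ query heads at attention time. The sizes (`n_h = 8`, `n_g = 2`, and so on) are arbitrary values chosen for the example, not settings of any particular model.

```python
import torch

B, T, d_h = 1, 4, 16
n_h, n_g = 8, 2                      # assumed sizes: 8 query heads share 2 KV groups

# Only n_g key heads are cached per token (values would be handled the same way).
k_cache = torch.randn(B, n_g, T, d_h)

# At attention time each cached head is repeated for the query heads of its group.
k_expanded = k_cache.repeat_interleave(n_h // n_g, dim=1)   # (B, n_h, T, d_h)

cached_per_token = 2 * n_g * d_h     # keys + values actually stored
mha_per_token = 2 * n_h * d_h        # what plain MHA would store
print(k_expanded.shape, cached_per_token, mha_per_token)
```

Query heads 0 to 3 see exactly the same keys here, which is the loss of per-head diversity that the limitation above refers to.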
3. The core idea of MLA: low-rank KV compression
3.1 Core idea
MLA builds on the observation that high-dimensional KV vectors are highly redundant and can be compressed into a compact latent space by a low-rank projection.
Instead of reducing the number of heads as GQA/MQA do, MLA keeps a separate key and value for every head and shrinks the cache with a down-projection / up-projection architecture:
$$\underbrace{\mathbf{h}_t \in \mathbb{R}^d}_{\text{hidden state}} \xrightarrow{\mathbf{W}_{DKV}} \underbrace{\mathbf{c}_t^{KV} \in \mathbb{R}^{d_c}}_{\text{compressed latent}} \xrightarrow{\mathbf{W}_{UK},\ \mathbf{W}_{UV}} \underbrace{\mathbf{k}_t, \mathbf{v}_t}_{\text{reconstructed KV}}$$
3.2 Mathematical formulation
Down-projection (compression):
$$\mathbf{c}_t^{KV} = \mathbf{W}_{DKV} \mathbf{h}_t \in \mathbb{R}^{d_c}$$
where $\mathbf{W}_{DKV} \in \mathbb{R}^{d_c \times d}$ and $d_c \ll n_h d_h$.
Up-projection (reconstruction):
$$\mathbf{k}_t^C = \mathbf{W}_{UK} \mathbf{c}_t^{KV} \in \mathbb{R}^{n_h d_h}$$
$$\mathbf{v}_t = \mathbf{W}_{UV} \mathbf{c}_t^{KV} \in \mathbb{R}^{n_h d_h}$$
where $\mathbf{W}_{UK} \in \mathbb{R}^{n_h d_h \times d_c}$ and $\mathbf{W}_{UV} \in \mathbb{R}^{n_h d_h \times d_c}$.
The KV cache only needs to store $\mathbf{c}_t^{KV}$ rather than the full $\mathbf{k}_t, \mathbf{v}_t$.
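3.3 An information-theoretic view of the low-rank factorization
Stack the hidden states into a data matrix $\mathbf{X} = [\mathbf{h}_1, \mathbf{h}_2, \ldots, \mathbf{h}_L]^T \in \mathbb{R}^{L \times d}$. The content keys are then
$$\mathbf{K} = \mathbf{X} \mathbf{W}_K^T = \mathbf{X} \mathbf{W}_{DKV}^T \mathbf{W}_{UK}^T$$
that is, the key projection is constrained to the factored form $\mathbf{W}_K = \mathbf{W}_{UK} \mathbf{W}_{DKV}$ of rank at most $d_c$. This resembles projecting onto $d_c$ directions and reconstructing, as PCA does, except that both factors are learned end to end rather than obtained from an eigendecomposition. When $d_c$ is large enough, the loss relative to an unconstrained $\mathbf{W}_K$ becomes negligible.
Theorem (Eckart-Young): for a matrix $\mathbf{M}$ of rank $r$, the best rank-$k$ approximation ($k \leq r$) in Frobenius norm is
$$\mathbf{M}_k = \sum_{i=1}^{k} \sigma_i \mathbf{u}_i \mathbf{v}_i^T$$
where $\sigma_1 \geq \sigma_2 \geq \cdots$ are the singular values and $\mathbf{u}_i, \mathbf{v}_i$ the corresponding singular vectors. The approximation error is
$$\|\mathbf{M} - \mathbf{M}_k\|_F = \sqrt{\sum_{i=k+1}^{r} \sigma_i^2}$$
When the effective rank of the KV matrix is much smaller than $n_h d_h$ (which the article takes to be the usual case in practice), low-rank compression is nearly lossless.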
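The Eckart-Young error formula can be checked numerically. The sketch below builds a synthetic "key matrix" with effective rank 8 plus small noise (the sizes and the noise level are assumptions made for the demo, not measurements from a real model) and compares the actual truncation error with the value predicted from the discarded singular values.

```python
import torch

def truncated_svd_error(M: torch.Tensor, k: int):
    """Return (actual, predicted) Frobenius error of the rank-k SVD truncation."""
    U, S, Vh = torch.linalg.svd(M, full_matrices=False)
    M_k = (U[:, :k] * S[:k]) @ Vh[:k, :]           # best rank-k approximation
    actual = torch.linalg.norm(M - M_k)            # Frobenius norm for a 2-D input
    predicted = torch.sqrt((S[k:] ** 2).sum())     # sqrt of discarded sigma_i^2
    return actual.item(), predicted.item()

torch.manual_seed(0)
L, d, r = 256, 64, 8                               # hypothetical sizes
K = torch.randn(L, r, dtype=torch.float64) @ torch.randn(r, d, dtype=torch.float64)
K = K + 1e-3 * torch.randn(L, d, dtype=torch.float64)   # small full-rank noise

actual, predicted = truncated_svd_error(K, k=8)
rel = actual / torch.linalg.norm(K).item()         # relative reconstruction error
print(f"actual={actual:.6f} predicted={predicted:.6f} relative={rel:.2e}")
```

Because the matrix is close to rank 8, keeping 8 of 64 directions leaves only a small relative error; a matrix with a flat singular spectrum would not compress this well.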
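3.4 Why would KV be low-rank?
Intuition 1: semantic redundancy
The K and V vectors of nearby tokens tend to be similar (for example the gradual semantic drift across "the", "cat", "sat"), which produces many linear dependencies.
Intuition 2: sparse attention
Attention matrices are usually sparse in practice (only a few token pairs receive high weight), which suggests that many dimensions of the high-dimensional KV space are redundant.
Intuition 3: redundancy across heads
Different attention heads often attend to similar patterns, so the KV vectors of different heads are correlated.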
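4. The absorption trick: an equivalent speed-up at inference time
4.1 The problem: cost of the up-projection
Suppose the up-projection is carried out explicitly at inference time:
$$\mathbf{k}_t = \mathbf{W}_{UK} \mathbf{c}_t^{KV}, \quad \mathbf{v}_t = \mathbf{W}_{UV} \mathbf{c}_t^{KV}$$
Then every newly generated token has to up-project the cached $\mathbf{c}_s^{KV}$ of all previous positions before attention scores can be computed, and the memory saved by compression is paid back in extra computation.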
4.2 Deriving the absorption trick
The key trick in MLA is to absorb the up-projection matrix into the query projection.
For head $i$, with $\mathbf{W}_Q^{(i)} \in \mathbb{R}^{d_h \times d}$ and $\mathbf{W}_{UK}^{(i)} \in \mathbb{R}^{d_h \times d_c}$, the attention score is:
$$\text{score}(t, s) = (\mathbf{q}_t^{(i)})^T \mathbf{k}_s^{(i)} = (\mathbf{W}_Q^{(i)} \mathbf{h}_t)^T (\mathbf{W}_{UK}^{(i)} \mathbf{c}_s^{KV})$$
$$= \mathbf{h}_t^T (\mathbf{W}_Q^{(i)})^T \mathbf{W}_{UK}^{(i)} \mathbf{c}_s^{KV}$$
Define the absorbed query projection:
$$\tilde{\mathbf{W}}_Q^{(i)} = (\mathbf{W}_{UK}^{(i)})^T \mathbf{W}_Q^{(i)} \in \mathbb{R}^{d_c \times d}$$
Then:
$$\text{score}(t, s) = (\tilde{\mathbf{W}}_Q^{(i)} \mathbf{h}_t)^T \mathbf{c}_s^{KV} = (\tilde{\mathbf{q}}_t^{(i)})^T \mathbf{c}_s^{KV}$$
Key result: inference only needs to store $\mathbf{c}_s^{KV} \in \mathbb{R}^{d_c}$ and never rebuilds $\mathbf{k}_s$ explicitly. The query is dotted directly with the compressed latent vector, and the score is exactly the one obtained with explicit up-projection.
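The derivation above covers the key side. The same associativity argument applies on the value side, where $\mathbf{W}_{UV}$ can be folded into the output projection so that attention weights are applied to the cached latents directly. The sketch below checks this numerically; the tensor sizes are arbitrary, the weight layout follows the `nn.Linear` convention `(out_features, in_features)`, and the name `W_OV` is introduced here for illustration only.

```python
import torch

torch.manual_seed(0)
B, T, H, d_h, d_c, D = 2, 5, 4, 8, 16, 32          # assumed toy sizes
f64 = torch.float64
c = torch.randn(B, T, d_c, dtype=f64)              # cached latents c_s^{KV}
attn = torch.softmax(torch.randn(B, H, 1, T, dtype=f64), dim=-1)
W_UV = torch.randn(H * d_h, d_c, dtype=f64)        # value up-projection
W_O = torch.randn(D, H * d_h, dtype=f64)           # output projection

# Explicit path: rebuild V for every cached token, then aggregate and project.
v = (c @ W_UV.T).view(B, T, H, d_h).transpose(1, 2)            # (B, H, T, d_h)
out_explicit = (attn @ v).transpose(1, 2).reshape(B, 1, H * d_h) @ W_O.T

# Absorbed path: fold W_UV into W_O per head, aggregate latents directly.
W_UV_h = W_UV.view(H, d_h, d_c)                    # (H, d_h, d_c)
W_O_h = W_O.view(D, H, d_h).permute(1, 0, 2)       # (H, D, d_h)
W_OV = torch.bmm(W_O_h, W_UV_h)                    # (H, D, d_c), hypothetical name
ctx = attn @ c.unsqueeze(1)                        # (B, H, 1, d_c)
out_absorbed = torch.einsum("bhlc,hdc->bld", ctx, W_OV)

print((out_explicit - out_absorbed).abs().max().item())
```

In this form neither keys nor values are ever materialized for the cached positions; only the per-head context vectors in latent space are.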
4.3 Complexity of the absorption trick
| Operation | Standard MHA | MLA (no absorption) | MLA (absorbed) |
|---|---|---|---|
| KV cache per token | $2 n_h d_h$ | $d_c$ | $d_c$ |
| Query projection (per head) | $O(d \cdot d_h)$ | $O(d \cdot d_h)$ | $O(d \cdot d_c)$ |
| Attention score (per head, per cached token) | $O(d_h)$ | $O(d_h)$ | $O(d_c)$ |
| Output aggregation (per head, per cached token) | $O(d_h)$ | $O(d_h)$ | $O(d_h)$ |
The "no absorption" column does not count the cost of rebuilding keys and values from the cache, which adds $O(d_c \cdot d_h)$ per head per cached token at every decoding step.
After absorption the score computation changes from $O(d_h)$ to $O(d_c)$. Since $d_c > d_h$ (in DeepSeek-V2, $d_c = 512$ and $d_h = 128$), each dot product costs somewhat more, but the memory saving far outweighs this extra computation.
5. Decoupled RoPE: compatibility with positional encoding
5.1 How RoPE breaks the absorption trick
RoPE (Rotary Position Embedding) applies a position-dependent rotation to Q and K:
$$\text{RoPE}(\mathbf{x}, t) = \begin{pmatrix} x_1 \cos(t\theta_1) - x_2 \sin(t\theta_1) \\ x_1 \sin(t\theta_1) + x_2 \cos(t\theta_1) \\ \vdots \end{pmatrix}$$
The attention score after RoPE is:
$$\text{score}(t, s) = (\text{RoPE}(\mathbf{q}_t, t))^T \text{RoPE}(\mathbf{k}_s, s)$$
Problem: RoPE multiplies position information into Q and K, so that
$$\text{RoPE}(\mathbf{k}_s, s) = \text{RoPE}(\mathbf{W}_{UK} \mathbf{c}_s^{KV}, s)$$
RoPE is linear in $\mathbf{x}$: it is multiplication by a rotation matrix $\mathbf{R}_s$. The obstacle is that this matrix depends on the position and sits between the two projections. The score becomes $\mathbf{h}_t^T \mathbf{W}_Q^T \mathbf{R}_{s-t} \mathbf{W}_{UK} \mathbf{c}_s^{KV}$, and the middle product changes with the offset $s - t$, so it cannot be precomputed as one fixed absorbed matrix. Equivalently, the rotation does not commute with the up-projection:
$$\text{RoPE}(\mathbf{W}_{UK} \mathbf{c}_s^{KV}, s) \neq \mathbf{W}_{UK}\, \text{RoPE}(\mathbf{c}_s^{KV}, s)$$
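A small numerical check can make the obstacle concrete. The sketch below builds RoPE as an explicit rotation matrix (pairing adjacent coordinates, as in the formula above; implementations that split the vector into halves are equivalent up to a permutation) and shows that the product between the query and key projections depends on the relative offset. The helper `rope_matrix` and all sizes are assumptions made for this illustration.

```python
import torch

def rope_matrix(pos: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Dense RoPE rotation matrix R_pos of shape (dim, dim)."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    R = torch.zeros(dim, dim, dtype=torch.float64)
    for j, angle in enumerate(pos * inv_freq):
        c, s = torch.cos(angle), torch.sin(angle)
        R[2 * j, 2 * j], R[2 * j, 2 * j + 1] = c, -s
        R[2 * j + 1, 2 * j], R[2 * j + 1, 2 * j + 1] = s, c
    return R

torch.manual_seed(0)
d_h, d_c, d = 8, 16, 32                            # assumed toy sizes
W_Q = torch.randn(d_h, d, dtype=torch.float64)
W_UK = torch.randn(d_h, d_c, dtype=torch.float64)

# Relative-position property: R_t^T R_s = R_{s-t}.
t, s = 7, 3
lhs = rope_matrix(t, d_h).T @ rope_matrix(s, d_h)
rel = rope_matrix(s - t, d_h)

# The would-be absorbed matrix W_Q^T R_{s-t} W_UK differs for each offset.
M_a = W_Q.T @ rope_matrix(-4, d_h) @ W_UK
M_b = W_Q.T @ rope_matrix(-1, d_h) @ W_UK
gap = (M_a - M_b).abs().max().item()
print(torch.allclose(lhs, rel), gap)
```

A separate absorbed matrix would be needed for every offset, which is why the position-dependent part is moved into its own small branch.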
5.2 The decoupled RoPE design
MLA splits each key and each query into two independent parts, a content part and a position part, concatenated along the feature dimension:
$$\mathbf{k}_t = \left[\underbrace{\mathbf{k}_t^C}_{\text{content key}} ;\ \underbrace{\mathbf{k}_t^R}_{\text{position key}}\right]$$
$$\mathbf{q}_t = \left[\underbrace{\mathbf{q}_t^C}_{\text{content query}} ;\ \underbrace{\mathbf{q}_t^R}_{\text{position query}}\right]$$
Content key (no RoPE, supports absorption):
$$\mathbf{k}_t^C = \mathbf{W}_{UK} \mathbf{c}_t^{KV} \in \mathbb{R}^{n_h d_h}$$
Position key (with RoPE, computed separately and shared by all heads):
$$\mathbf{k}_t^R = \text{RoPE}(\mathbf{W}_{KR} \mathbf{h}_t, t) \in \mathbb{R}^{d_n}$$
where $\mathbf{W}_{KR} \in \mathbb{R}^{d_n \times d}$ and $d_n$ is the dimension of the position branch (typically small, for example 64).
5.3 The decoupled attention score
The final attention score is the sum of the two parts:
$$\text{score}(t, s) = (\mathbf{q}_t^C)^T \mathbf{k}_s^C + (\mathbf{q}_t^R)^T \mathbf{k}_s^R$$
$$= (\tilde{\mathbf{q}}_t)^T \mathbf{c}_s^{KV} + (\mathbf{q}_t^R)^T \mathbf{k}_s^R$$
- First term: uses the absorption trick, so the cache only needs $\mathbf{c}_s^{KV} \in \mathbb{R}^{d_c}$
- Second term: the position key $\mathbf{k}_s^R \in \mathbb{R}^{d_n}$ has to be cached separately, but $d_n$ is small
Total KV cache per token per layer:
$$\text{MLA cache} = \underbrace{d_c}_{\text{content latent}} + \underbrace{d_n}_{\text{position key}} = 512 + 64 = 576$$
Compared with standard MHA at $2 \times 128 \times 128 = 32768$, the compression ratio is about 57×.
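5.4 Intuition behind decoupled RoPE
| Component | Encodes | Needs position? | Supports absorption? |
|---|---|---|---|
| $\mathbf{k}_t^C$ | Semantic content | No | Yes |
| $\mathbf{k}_t^R$ | Position information | Yes | No |
Intuition: semantic content is treated as low-rank and therefore compressible, while position information needs only a few dimensions. MLA decouples the two and handles each in the way that suits it.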
6. A complete runnable implementation
6.1 Core MLA implementation
```python
"""
Multi-head Latent Attention (MLA): a runnable reference implementation.
Dependencies: torch >= 2.0
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class MLAConfig:
    """MLA configuration."""
    d_model: int = 2048          # model (hidden) dimension
    n_heads: int = 16            # number of attention heads
    d_head: int = 128            # per-head dimension
    d_c: int = 512               # KV compression (latent) dimension
    d_n: int = 64                # RoPE (position branch) dimension
    rope_theta: float = 10000.0  # RoPE base frequency
    max_seq_len: int = 4096      # maximum sequence length


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with precomputed cos/sin tables."""

    def __init__(self, dim: int, theta: float = 10000.0, max_seq_len: int = 4096):
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        t = torch.arange(max_seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)        # (max_seq_len, dim / 2)
        emb = torch.cat([freqs, freqs], dim=-1)      # (max_seq_len, dim)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype),
        )


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Half-rotation used by RoPE: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rope(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to q and k (cos/sin broadcast over batch and heads)."""
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot


class MultiHeadLatentAttention(nn.Module):
    """Multi-head Latent Attention (MLA)."""

    def __init__(self, config: MLAConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.d_c = config.d_c
        self.d_n = config.d_n
        self.d_model = config.d_model
        # Query projection
        self.W_Q = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)
        # KV compression (down-projection)
        self.W_DKV = nn.Linear(config.d_model, config.d_c, bias=False)
        # KV reconstruction (up-projection)
        self.W_UK = nn.Linear(config.d_c, config.n_heads * config.d_head, bias=False)
        self.W_UV = nn.Linear(config.d_c, config.n_heads * config.d_head, bias=False)
        # Position branch (decoupled RoPE); the position key is shared by all heads
        self.W_KR = nn.Linear(config.d_model, config.d_n, bias=False)
        self.W_QR = nn.Linear(config.d_model, config.n_heads * config.d_n, bias=False)
        # Output projection
        self.W_O = nn.Linear(config.n_heads * config.d_head, config.d_model, bias=False)
        # RoPE tables
        self.rope = RotaryEmbedding(config.d_n, config.rope_theta, config.max_seq_len)
        # Absorbed query projection is computed lazily
        self._absorbed = False

    def _compute_absorbed_weights(self):
        """Precompute the absorbed query weights (call again if the weights change)."""
        # W_Q.weight:  (n_heads * d_head, d_model)
        # W_UK.weight: (n_heads * d_head, d_c)
        # Per head: W_Q' = W_UK^T @ W_Q, shape (d_c, d_model)
        with torch.no_grad():
            W_Q_reshaped = self.W_Q.weight.view(self.n_heads, self.d_head, self.d_model)
            W_UK_reshaped = self.W_UK.weight.view(self.n_heads, self.d_head, self.d_c)
            # Absorbed query projection: (n_heads, d_c, d_model)
            self.absorbed_W_Q = torch.bmm(W_UK_reshaped.transpose(1, 2), W_Q_reshaped)
        self._absorbed = True

    def forward_train(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Training-time forward pass (explicit up-projection, easier to follow).
        x: (B, L, D); mask: optional, broadcastable to (B, H, L, L), 0 = masked out
        """
        B, L, D = x.shape
        # 1. Query projection
        q = self.W_Q(x)                                                # (B, L, H * d_h)
        q = q.view(B, L, self.n_heads, self.d_head).transpose(1, 2)    # (B, H, L, d_h)
        # 2. KV compression
        c_kv = self.W_DKV(x)                                           # (B, L, d_c)
        # 3. KV reconstruction
        k = self.W_UK(c_kv)                                            # (B, L, H * d_h)
        k = k.view(B, L, self.n_heads, self.d_head).transpose(1, 2)    # (B, H, L, d_h)
        v = self.W_UV(c_kv)                                            # (B, L, H * d_h)
        v = v.view(B, L, self.n_heads, self.d_head).transpose(1, 2)    # (B, H, L, d_h)
        # 4. Position branch
        q_r = self.W_QR(x)                                             # (B, L, H * d_n)
        q_r = q_r.view(B, L, self.n_heads, self.d_n).transpose(1, 2)   # (B, H, L, d_n)
        k_r = self.W_KR(x)                                             # (B, L, d_n)
        cos, sin = self.rope(x, L)
        q_r, k_r = apply_rope(q_r, k_r.unsqueeze(1), cos, sin)         # k_r: (B, 1, L, d_n)
        # 5. Attention scores: content part + position part
        scale = 1.0 / math.sqrt(self.d_head + self.d_n)
        attn_content = torch.matmul(q, k.transpose(-2, -1))            # (B, H, L, L)
        attn_position = torch.matmul(q_r, k_r.transpose(-2, -1))       # (B, H, L, L)
        attn = (attn_content + attn_position) * scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        # 6. Output
        out = torch.matmul(attn, v)                                    # (B, H, L, d_h)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)          # (B, L, H * d_h)
        return self.W_O(out)                                           # (B, L, D)

    def forward_inference(
        self,
        x_t: torch.Tensor,
        c_kv_cache: torch.Tensor,
        k_r_cache: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Inference-time forward pass for one token (uses the absorption trick).
        x_t:        (B, 1, D)       current token
        c_kv_cache: (B, T, d_c)     compressed KV cache of the previous tokens
        k_r_cache:  (B, 1, T, d_n)  position-key cache of the previous tokens
        position:   index of the current token
        Returns (out, c_kv_new, k_r_new); the caller appends c_kv_new (B, 1, d_c)
        and k_r_new.unsqueeze(1) (B, 1, 1, d_n) to its caches.
        """
        B = x_t.shape[0]
        # 1. Compressed latent for the current token
        c_kv_new = self.W_DKV(x_t)                                     # (B, 1, d_c)
        # 2. Absorbed query projection
        if not self._absorbed:
            self._compute_absorbed_weights()
        q_absorbed = torch.einsum("bld,hcd->bhlc", x_t, self.absorbed_W_Q)  # (B, H, 1, d_c)
        # 3. Position branch
        q_r = self.W_QR(x_t)                                           # (B, 1, H * d_n)
        q_r = q_r.view(B, 1, self.n_heads, self.d_n).transpose(1, 2)   # (B, H, 1, d_n)
        k_r_new = self.W_KR(x_t)                                       # (B, 1, d_n)
        cos, sin = self.rope(x_t, position + 1)
        cos_t = cos[position:position + 1]                             # (1, d_n)
        sin_t = sin[position:position + 1]
        q_r = q_r * cos_t + rotate_half(q_r) * sin_t
        k_r_new = k_r_new * cos_t + rotate_half(k_r_new) * sin_t
        # The current token must attend to itself, so include its own entries.
        c_kv_all = torch.cat([c_kv_cache, c_kv_new], dim=1)            # (B, T+1, d_c)
        k_r_all = torch.cat([k_r_cache, k_r_new.unsqueeze(1)], dim=2)  # (B, 1, T+1, d_n)
        # 4. Attention scores (absorbed): query dotted with the latents directly
        attn_content = torch.matmul(q_absorbed, c_kv_all.unsqueeze(1).transpose(-2, -1))
        attn_position = torch.matmul(q_r, k_r_all.transpose(-2, -1))   # (B, H, 1, T+1)
        attn = (attn_content + attn_position) / math.sqrt(self.d_head + self.d_n)
        attn = F.softmax(attn, dim=-1)
        # 5. Output (V is rebuilt with W_UV here for clarity)
        v = self.W_UV(c_kv_all)                                        # (B, T+1, H * d_h)
        v = v.view(B, -1, self.n_heads, self.d_head).transpose(1, 2)   # (B, H, T+1, d_h)
        out = torch.matmul(attn, v)                                    # (B, H, 1, d_h)
        out = out.transpose(1, 2).contiguous().view(B, 1, -1)
        out = self.W_O(out)
        return out, c_kv_new, k_r_new
```
6.2 KV cache comparison
```python
def compare_kv_cache_sizes():
    """Compare the KV cache size of different attention mechanisms."""
    configs = {
        "MHA (128 heads)": {"n_heads": 128, "d_head": 128, "d_c": None, "d_n": None},
        "GQA (8 groups)": {"n_heads": 128, "d_head": 128, "d_c": None, "d_n": None, "n_groups": 8},
        "MQA (1 KV head)": {"n_heads": 128, "d_head": 128, "d_c": None, "d_n": None, "n_groups": 1},
        "MLA (d_c=512)": {"n_heads": 128, "d_head": 128, "d_c": 512, "d_n": 64},
    }
    seq_len = 128 * 1024   # 128K tokens, taken as 131072
    n_layers = 64
    bytes_per_elem = 2     # FP16
    print("KV cache comparison (seq_len=128K, 64 layers, FP16)")
    print("=" * 60)
    for name, cfg in configs.items():
        if cfg.get("d_c") is not None:
            # MLA: latent vector plus shared position key
            cache_per_token = cfg["d_c"] + cfg["d_n"]
        elif cfg.get("n_groups") is not None:
            # GQA / MQA: keys and values for each KV group
            cache_per_token = 2 * cfg["n_groups"] * cfg["d_head"]
        else:
            # MHA: keys and values for every head
            cache_per_token = 2 * cfg["n_heads"] * cfg["d_head"]
        total_gib = n_layers * seq_len * cache_per_token * bytes_per_elem / (1024**3)
        print(f"  {name:25s} | per token: {cache_per_token:>6d} | total: {total_gib:>8.1f} GiB")


if __name__ == "__main__":
    compare_kv_cache_sizes()
```
Output:
```
KV cache comparison (seq_len=128K, 64 layers, FP16)
============================================================
  MHA (128 heads)           | per token:  32768 | total:    512.0 GiB
  GQA (8 groups)            | per token:   2048 | total:     32.0 GiB
  MQA (1 KV head)           | per token:    256 | total:      4.0 GiB
  MLA (d_c=512)             | per token:    576 | total:      9.0 GiB
```
6.3 Verifying the absorption trick
```python
def verify_absorption_equivalence():
    """Check that absorbed content scores match scores from explicit up-projection."""
    torch.manual_seed(42)
    B, L, D = 2, 16, 64
    n_heads, d_head, d_c = 4, 16, 32
    x = torch.randn(B, L, D)
    # Build a small MLA module
    config = MLAConfig(d_model=D, n_heads=n_heads, d_head=d_head, d_c=d_c, d_n=8)
    mla = MultiHeadLatentAttention(config)
    # Precompute the absorbed query weights
    mla._compute_absorbed_weights()
    c_kv = mla.W_DKV(x)
    k_full = mla.W_UK(c_kv)
    q_full = mla.W_Q(x)
    # Content scores with explicit up-projection
    q_reshaped = q_full.view(B, L, n_heads, d_head).transpose(1, 2)
    k_reshaped = k_full.view(B, L, n_heads, d_head).transpose(1, 2)
    score_standard = torch.matmul(q_reshaped, k_reshaped.transpose(-2, -1))
    # Content scores with the absorbed query, dotted directly with the latents
    q_absorbed = torch.einsum("bld,hcd->bhlc", x, mla.absorbed_W_Q)
    score_absorbed = torch.matmul(q_absorbed, c_kv.unsqueeze(1).transpose(-2, -1))
    max_diff = (score_standard - score_absorbed).abs().max().item()
    print("Absorption trick equivalence check:")
    print("  explicit scores vs absorbed scores")
    print(f"  max absolute difference: {max_diff:.6e}")
    print(f"  equal within tolerance: {max_diff < 1e-5}")
    return max_diff < 1e-5
```
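7. MLA compared with other schemes
7.1 Compression ratio
| Mechanism | KV cache / token | Compression (vs MHA) | Lossy? |
|---|---|---|---|
| MHA | $2 n_h d_h$ | 1× | No |
| GQA | $2 n_g d_h$ | $n_h / n_g$ × | Yes (KV shared within a group) |
| MQA | $2 d_h$ | $n_h$ × | Yes (single KV head) |
| MLA | $d_c + d_n$ | $2 n_h d_h / (d_c + d_n)$ × | Approximately lossless |
7.2 Expressiveness
The problem with GQA/MQA: reducing the number of KV heads directly limits the expressiveness of attention. Different query heads are forced to share the same KV and cannot learn equally diverse attention patterns.
The advantage of MLA: every query head keeps its own keys and values. The "compression" comes from a low-rank projection, which in theory is lossless whenever $d_c \geq \text{rank}(\mathbf{K})$.
7.3 Theoretical limitations of MLA
Limitation 1: more computation in the query projection
After absorption the per-head query projection grows from $O(d \cdot d_h)$ to $O(d \cdot d_c)$, and since $d_c > d_h$ the cost increases.
Limitation 2: validity of the low-rank assumption
If the effective rank of the KV matrix is close to $n_h d_h$, low-rank compression loses information. According to the article, deeper layers tend to show more KV redundancy in practice, so the assumption holds better there.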
8. MLA formula summary
╔════════════════════════════════════════════════════════════════════════════════╗
║                MLA (Multi-head Latent Attention): math summary                 ║
╠════════════════════════════════════════════════════════════════════════════════╣
║                                                                                ║
║  1. Low-rank KV compression:                                                   ║
║     c_t^{KV} = W_DKV · h_t   (down-projection, d -> d_c)                       ║
║     k_t^C    = W_UK · c_t^{KV}   (up-projection, d_c -> n_h·d_h)               ║
║     v_t      = W_UV · c_t^{KV}   (up-projection, d_c -> n_h·d_h)               ║
║     The KV cache stores only c_t^{KV} in R^{d_c}                               ║
║                                                                                ║
║  2. Absorption trick:                                                          ║
║     score(t,s) = q_t^T · k_s = q̃_t^T · c_s^{KV}                                ║
║     with q̃_t = W_UK^T · W_Q · h_t   (absorbed query, dotted with c^{KV})       ║
║     No explicit reconstruction of k_s at inference time                        ║
║                                                                                ║
║  3. Decoupled RoPE:                                                            ║
║     k_t = [k_t^C ; k_t^R]                                                      ║
║     k_t^C = W_UK · c_t^{KV}       (content key, no RoPE, absorbable)           ║
║     k_t^R = RoPE(W_KR · h_t, t)   (position key, with RoPE, cached)            ║
║     score = (q^C)^T·k^C + (q^R)^T·k^R                                          ║
║                                                                                ║
║  4. KV cache size (values per token per layer):                                ║
║     MHA: 2·n_h·d_h = 32768                                                     ║
║     GQA: 2·n_g·d_h = 2048   (n_g = 8)                                          ║
║     MQA: 2·d_h     = 256                                                       ║
║     MLA: d_c + d_n = 512 + 64 = 576                                            ║
║     Compression ratio: 32768 / 576 ≈ 57×                                       ║
║                                                                                ║
║  5. KV cache storage at inference time:                                        ║
║     Per token per layer: c_t^{KV} in R^{d_c} plus k_t^R in R^{d_n}             ║
║     Total: n_layers × seq_len × (d_c + d_n) × 2 bytes                          ║
║     Example: 64 × 128K × 576 × 2 bytes ≈ 9 GiB (vs 512 GiB for MHA)            ║
║                                                                                ║
║  6. Low-rank approximation guarantee (Eckart-Young):                           ║
║     ‖K - K_k‖_F = √(Σ_{i>k} σ_i²)                                              ║
║     Near-lossless when the effective rank of KV is << n_h·d_h                  ║
║                                                                                ║
╚════════════════════════════════════════════════════════════════════════════════╝
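Part 2: DeepSeekMoE with Auxiliary-Loss-Free Load Balancing
1. Introduction
Mixture-of-Experts (MoE) models use sparse activation to combine a large parameter count with a low compute cost. Load balancing, however, has long been a threat to training stability in conventional MoE: if routing is uneven, some experts are overloaded while others sit idle, and training can collapse.
The DeepSeekMoE architecture as used in DeepSeek-V3 (2024) adopts an auxiliary-loss-free load balancing method: a per-expert routing bias is adjusted dynamically, which steers the expert load toward balance without adding any term to the gradient.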
2. Mathematical background of MoE load balancing
2.1 Standard MoE routing
Given a token representation $\mathbf{x} \in \mathbb{R}^d$ and expert centroids $\mathbf{e}_1, \ldots, \mathbf{e}_N$, standard MoE routing computes a softmax over all experts:
$$g_i = \text{Softmax}_i\left(\{\mathbf{x}^T \mathbf{e}_j\}_{j=1}^{N}\right), \quad i = 1, 2, \ldots, N$$
The Top-K experts are selected:
$$\mathcal{T} = \text{TopK}(\{g_1, g_2, \ldots, g_N\}, K)$$
$$\mathbf{y} = \sum_{i \in \mathcal{T}} g_i \cdot \text{Expert}_i(\mathbf{x})$$
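2.2 The load imbalance problem
Define the load of expert $i$ as the number of tokens in a training batch that are routed to it:
$$\text{Load}_i = \sum_{t=1}^{B \cdot L} \mathbb{1}[i \in \mathcal{T}_t]$$
Ideal state: all experts carry the same load, $\text{Load}_i = \frac{B \cdot L \cdot K}{N}$.
Problem in practice: training tends toward a winner-takes-all pattern. A few experts receive more tokens, therefore more gradient updates, therefore become "stronger", which closes a positive feedback loop.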
2.3 The conventional fix: an auxiliary loss
Auxiliary load balancing loss (Switch Transformer, GShard):
$$\mathcal{L}_{\text{aux}} = N \sum_{i=1}^{N} f_i \cdot p_i$$
where:
$$f_i = \frac{\text{Load}_i}{B \cdot L}, \quad p_i = \frac{1}{B \cdot L} \sum_{t} g_i^{(t)}$$
Problem: the auxiliary loss competes with the language modeling objective $\mathcal{L}_{\text{LM}}$ in the combined loss, where $\alpha$ is the weight of the auxiliary term:
$$\mathcal{L} = \mathcal{L}_{\text{LM}} + \alpha \mathcal{L}_{\text{aux}}$$
- $\alpha$ too large: routing is forced toward uniform, which hurts model quality
- $\alpha$ too small: the load stays unbalanced and training is unstable
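To make the quantities $f_i$ and $p_i$ concrete, here is a minimal sketch of a Switch-style auxiliary loss computed from router logits. The function name `switch_aux_loss` and the value of `alpha` are placeholders chosen for this example; $f_i$ follows the definition in the text (load divided by the number of tokens), so with Top-K routing the $f_i$ sum to $K$.

```python
import torch

def switch_aux_loss(router_logits: torch.Tensor, top_k: int, alpha: float = 0.01) -> torch.Tensor:
    """alpha * N * sum_i f_i * p_i for logits of shape (n_tokens, n_experts)."""
    n_tokens, n_experts = router_logits.shape
    probs = torch.softmax(router_logits, dim=-1)            # g_i^{(t)}
    top_idx = probs.topk(top_k, dim=-1).indices             # (n_tokens, top_k)
    load = torch.zeros(n_experts).scatter_add_(
        0, top_idx.reshape(-1), torch.ones(top_idx.numel())
    )                                                       # Load_i (no gradient)
    f = load / n_tokens                                     # fraction routed to i
    p = probs.mean(dim=0)                                   # mean router probability
    return alpha * n_experts * (f * p).sum()

torch.manual_seed(0)
balanced = torch.zeros(512, 8)                              # uniform router
skewed = torch.randn(512, 8)
skewed[:, 0] += 5.0                                         # expert 0 dominates
loss_balanced = switch_aux_loss(balanced, top_k=2).item()
loss_skewed = switch_aux_loss(skewed, top_k=2).item()
print(loss_balanced, loss_skewed)
```

Only $p_i$ carries gradient, since the counts behind $f_i$ come from a discrete Top-K choice; the loss therefore pushes router probabilities away from heavily loaded experts, which is where the conflict with the language modeling objective arises.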
3. Auxiliary-loss-free load balancing
3.1 Core idea
The key idea is to decouple the routing decision from the gradient computation.
A bias vector $\mathbf{b} = (b_1, b_2, \ldots, b_N)$ is introduced that is used only for expert selection. It does not enter the gate values in the forward pass, so it receives no gradient.
3.2 Routing with a bias
Routing score:
$$s_i = \text{Sigmoid}(\mathbf{x}^T \mathbf{e}_i + b_i)$$
Sigmoid is used instead of Softmax because it scores each expert independently. Note that the DeepSeek-V3 report adds the bias outside the sigmoid and selects by $\text{Sigmoid}(\mathbf{x}^T \mathbf{e}_i) + b_i$; the form used in this article places it inside, which changes the selection rule slightly but follows the same principle.
Top-K selection: experts are chosen by $s_i$, but the forward pass does not use $b_i$:
$$\mathcal{T} = \text{TopK}(\{s_1, s_2, \ldots, s_N\}, K)$$
$$g_i = \frac{\text{Sigmoid}(\mathbf{x}^T \mathbf{e}_i)}{\sum_{j \in \mathcal{T}} \text{Sigmoid}(\mathbf{x}^T \mathbf{e}_j)}, \quad i \in \mathcal{T}$$
$$\mathbf{y} = \sum_{i \in \mathcal{T}} g_i \cdot \text{Expert}_i(\mathbf{x})$$
Key point: the gate weight $g_i$ is computed without $b_i$, so $b_i$ does not affect the gradient.
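The separation between selection and gating can be sketched in a few lines. The function `route` below is a hypothetical helper written for this illustration, using the article's formulation with the bias inside the sigmoid. The bias is deliberately marked as requiring gradients, to show that none reaches it.

```python
import torch

def route(logits: torch.Tensor, bias: torch.Tensor, top_k: int):
    """Select experts with the biased score, compute gate weights without the bias."""
    select_scores = torch.sigmoid(logits + bias)           # used for Top-K only
    idx = select_scores.topk(top_k, dim=-1).indices        # discrete, no gradient path
    clean = torch.sigmoid(logits).gather(-1, idx)          # bias-free affinities
    gates = clean / clean.sum(dim=-1, keepdim=True)        # normalized over selected
    return idx, gates

torch.manual_seed(0)
logits = torch.randn(6, 8, requires_grad=True)             # 6 tokens, 8 experts
bias = torch.zeros(8)
bias[0] = 10.0                                             # strongly favor expert 0
bias.requires_grad_(True)

idx, gates = route(logits, bias, top_k=2)
expert_out = torch.randn(6, 2)                             # stand-in for expert outputs
(gates * expert_out).sum().backward()

print(idx[:, 0].tolist(), gates.sum(dim=-1).tolist(), bias.grad)
```

Expert 0 is selected for every token because of its bias, yet `bias.grad` stays `None`: the bias changes which experts run, not how their outputs are weighted.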
3.3 Bias update rule
Outside the gradient step (for example every $T$ steps), the bias is adjusted according to the observed expert load:
$$b_i \leftarrow \begin{cases} b_i - \gamma & \text{if } \text{Load}_i > \text{Target} \\ b_i + \gamma & \text{if } \text{Load}_i < \text{Target} \end{cases}$$
where $\gamma > 0$ is the update step size and $\text{Target} = \frac{B \cdot L \cdot K}{N}$ is the ideal load.
3.4 Convergence
Informal claim: under suitable conditions, the bias update rule drives the system toward a balanced state. The article gives no proof; the argument is a negative-feedback intuition:
- When expert $i$ is overloaded, $b_i$ decreases → its routing score drops → fewer tokens are routed to $i$
- When expert $i$ is underloaded, $b_i$ increases → its routing score rises → more tokens are routed to $i$
- The system is pushed toward the balanced state $\text{Load}_i = \text{Target}$
4. A complete runnable implementation
4.1 DeepSeekMoE implementation
```python
"""
DeepSeekMoE with auxiliary-loss-free load balancing: a runnable reference implementation.
"""
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class MoEConfig:
    d_model: int = 2048
    d_expert: int = 14336            # expert FFN hidden size (lower it for quick experiments)
    n_experts: int = 64              # number of routed experts
    n_shared: int = 2                # number of shared experts
    top_k: int = 6                   # routed experts activated per token
    bias_update_gamma: float = 0.001
    bias_update_interval: int = 100


class Expert(nn.Module):
    """A single expert (SwiGLU feed-forward block)."""

    def __init__(self, d_model: int, d_expert: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_expert, bias=False)
        self.w_up = nn.Linear(d_model, d_expert, bias=False)
        self.w_down = nn.Linear(d_expert, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


class DeepSeekMoELayer(nn.Module):
    """DeepSeekMoE layer with auxiliary-loss-free load balancing."""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config
        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([
            Expert(config.d_model, config.d_expert)
            for _ in range(config.n_shared)
        ])
        # Routed experts
        self.routed_experts = nn.ModuleList([
            Expert(config.d_model, config.d_expert)
            for _ in range(config.n_experts)
        ])
        # Router
        self.gate = nn.Linear(config.d_model, config.n_experts, bias=False)
        # Selection bias: a buffer, so it never receives gradients
        self.register_buffer("expert_bias", torch.zeros(config.n_experts))
        # Load statistics accumulated between bias updates
        self.register_buffer("expert_load", torch.zeros(config.n_experts))
        self.step_count = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, L, D)"""
        B, L, D = x.shape
        n_experts = self.config.n_experts
        x_flat = x.reshape(-1, D)                                       # (B*L, D)
        # 1. Shared expert output
        shared_out = sum(expert(x_flat) for expert in self.shared_experts)
        # 2. Routing scores (with bias, used for selection only)
        gate_logits = self.gate(x_flat)                                 # (N, n_experts)
        routing_scores = torch.sigmoid(gate_logits + self.expert_bias)
        # 3. Top-K selection
        _, top_k_indices = torch.topk(routing_scores, self.config.top_k, dim=-1)
        # 4. Gate weights (without the bias)
        gate_weights_clean = torch.sigmoid(gate_logits)
        top_k_weights = gate_weights_clean.gather(1, top_k_indices)     # (N, top_k)
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
        # 5. Routed expert computation (simple loop, written for clarity not speed)
        routed_out = torch.zeros_like(x_flat)
        for k in range(self.config.top_k):
            expert_idx = top_k_indices[:, k]                            # (N,)
            weight = top_k_weights[:, k]                                # (N,)
            for e in range(n_experts):
                mask = expert_idx == e
                if mask.any():
                    expert_output = self.routed_experts[e](x_flat[mask])
                    routed_out[mask] += weight[mask].unsqueeze(-1) * expert_output
        # 6. Update load statistics and, periodically, the bias (training only)
        if self.training:
            with torch.no_grad():
                counts = torch.bincount(top_k_indices.reshape(-1), minlength=n_experts)
                self.expert_load += counts.to(self.expert_load.dtype)
                self.step_count += 1
                if self.step_count % self.config.bias_update_interval == 0:
                    # The target is the mean load over the same accumulation window
                    # (a per-batch target would mark every expert as overloaded).
                    target_load = self.expert_load.sum() / n_experts
                    gamma = self.config.bias_update_gamma
                    overloaded = self.expert_load > target_load * 1.1
                    underloaded = self.expert_load < target_load * 0.9
                    self.expert_bias[overloaded] -= gamma
                    self.expert_bias[underloaded] += gamma
                    # Reset statistics for the next window
                    self.expert_load.zero_()
        # 7. Combine shared and routed experts
        return (shared_out + routed_out).view(B, L, D)
```
4.2 Load balancing comparison
```python
import torch


def compare_load_balancing():
    """Compare a stand-in for the auxiliary loss with the bias-based method (toy simulation)."""
    torch.manual_seed(42)
    n_experts = 16
    n_tokens = 1000
    top_k = 2
    # Simulated router logits, skewed so that expert 0 attracts more tokens
    logits = torch.randn(n_tokens, n_experts)
    logits[:, 0] += 2.0

    def expert_loads(scores: torch.Tensor) -> torch.Tensor:
        _, indices = torch.topk(scores, top_k, dim=-1)
        return torch.bincount(indices.reshape(-1), minlength=n_experts).float()

    def load_imbalance(scores: torch.Tensor) -> float:
        loads = expert_loads(scores)
        return (loads.std() / (loads.mean() + 1e-8)).item()

    # No balancing
    imbalance_no_bias = load_imbalance(logits)
    # Auxiliary loss: crude stand-in that shrinks each expert's average logit.
    # (Subtracting a per-token mean would not change Top-K at all.)
    aux_weight = 0.1
    logits_aux = logits - aux_weight * logits.mean(dim=0, keepdim=True)
    imbalance_aux = load_imbalance(logits_aux)
    # Auxiliary-loss-free bias method
    bias = torch.zeros(n_experts)
    target = n_tokens * top_k / n_experts
    for _ in range(10):
        loads = expert_loads(torch.sigmoid(logits + bias))
        bias[loads > target * 1.1] -= 0.05
        bias[loads < target * 0.9] += 0.05
    imbalance_bias = load_imbalance(torch.sigmoid(logits + bias))
    print("Load balancing comparison:")
    print(f"  no balancing:        load std / mean = {imbalance_no_bias:.4f}")
    print(f"  auxiliary loss:      load std / mean = {imbalance_aux:.4f}")
    print(f"  auxiliary-loss-free: load std / mean = {imbalance_bias:.4f}")
```
5. Formula summary for auxiliary-loss-free load balancing
╔════════════════════════════════════════════════════════════════════════════════╗
║          DeepSeekMoE auxiliary-loss-free load balancing: math summary          ║
╠════════════════════════════════════════════════════════════════════════════════╣
║                                                                                ║
║  1. Routing with a bias:                                                       ║
║     s_i = Sigmoid(x^T · e_i + b_i)   (b_i is used for selection only)          ║
║     T = TopK({s_i}, K)                                                         ║
║     g_i = Sigmoid(x^T · e_i) / Σ_{j∈T} Sigmoid(x^T · e_j)   (gate has no b_i)  ║
║     y = Σ_{i∈T} g_i · Expert_i(x)                                              ║
║                                                                                ║
║  2. Bias update (outside the gradient step):                                   ║
║     if Load_i > Target: b_i ← b_i - γ                                          ║
║     if Load_i < Target: b_i ← b_i + γ                                          ║
║     Target = B·L·K / N                                                         ║
║                                                                                ║
║  3. Key properties:                                                            ║
║     - b_i receives no gradient -> no extra term in the language-modeling loss  ║
║     - b_i only affects expert selection -> indirect control of the load        ║
║     - the update runs outside backprop -> negligible compute cost              ║
║                                                                                ║
║  4. Compared with an auxiliary loss:                                           ║
║     Auxiliary loss:      L = L_LM + α·L_aux   (conflicting objectives)         ║
║     Auxiliary-loss-free: L = L_LM             (bias adjusted separately)       ║
║                                                                                ║
╚════════════════════════════════════════════════════════════════════════════════╝
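Part 3: Multi-Token Prediction (MTP)
1. Introduction
A standard language model predicts only the next token at each step, so the training signal is sparse. Multi-Token Prediction (MTP) has the model predict several future tokens at once, which gives a denser training signal and can also be used for speculative decoding at inference time.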
2. The mathematical framework of MTP
2.1 Standard next-token prediction
The training objective of a standard autoregressive language model is:
$$\mathcal{L}_{\text{NTP}} = -\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$
Each position provides a single training signal.
2.2 Multi-token prediction
MTP predicts the next $D$ tokens at every position:
$$\mathcal{L}_{\text{MTP}} = \sum_{k=1}^{D} \lambda_k \mathcal{L}_{\text{CE}}^{(k)}$$
$$\mathcal{L}_{\text{CE}}^{(k)} = -\sum_{t=1}^{T-k} \log p_\theta^{(k)}(x_{t+k} \mid x_{\leq t})$$
where $p_\theta^{(k)}$ is the output distribution of the $k$-th prediction head and $\lambda_k$ is its weight.
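The index bookkeeping in $\mathcal{L}_{\text{CE}}^{(k)}$ is easy to get wrong, so a small sketch of the target construction may help. The helper `mtp_targets` is a name made up for this example: for head $k$, position $t$ is paired with token $x_{t+k}$, and only positions for which all $D$ targets exist are kept.

```python
import torch

def mtp_targets(tokens: torch.Tensor, depth: int):
    """For tokens of shape (B, T), return one target tensor per head k = 1..depth.

    Head k at position t (0-based) is trained to predict tokens[:, t + k].
    Only the first T - depth positions have targets for every head.
    """
    T = tokens.shape[1]
    n = T - depth                                   # usable positions
    return [tokens[:, k:k + n] for k in range(1, depth + 1)]

tokens = torch.arange(10).view(1, 10)               # toy sequence 0, 1, ..., 9
targets = mtp_targets(tokens, depth=2)
print(targets[0].tolist())                          # next-token targets
print(targets[1].tolist())                          # targets two steps ahead
```

With these targets the total loss is the weighted sum $\sum_k \lambda_k \mathcal{L}_{\text{CE}}^{(k)}$, with each head's logits truncated to the same $T - D$ positions.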
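2.3 The MTP architecture in DeepSeek-V3
DeepSeek-V3 uses sequentially dependent MTP modules. Keeping this article's convention that head $k = 1$ is the main model and head $k \geq 2$ predicts $x_{t+k}$:
$$\mathbf{h}_t^{(k)} = \text{TRM}_k\left(\mathbf{M}_k \left[\mathbf{h}_t^{(k-1)} ;\ \mathbf{e}_{t+k-1}\right]\right)$$
where:
- $\mathbf{h}_t^{(k-1)}$ is the output of the previous depth, with $\mathbf{h}_t^{(1)}$ the final hidden state of the main model
- $\mathbf{e}_{t+k-1}$ is the embedding of the ground-truth token at position $t+k-1$, the last token visible before $x_{t+k}$ (feeding $\mathbf{e}_{t+k}$ itself would leak the target)
- $\mathbf{M}_k$ projects the concatenation from $2d$ back to $d$ (the report also applies RMSNorm to both inputs first)
- $\text{TRM}_k$ is the Transformer block of depth $k$; in the DeepSeek-V3 report each depth has its own block, while the embedding layer and the output head are shared with the main model
Shared output head: all MTP depths share the same output projection $\mathbf{W}_{\text{out}}$:
$$p^{(k)}(x_{t+k}) = \text{Softmax}(\mathbf{W}_{\text{out}} \mathbf{h}_t^{(k)})$$
2.4 The MTP loss
$$\mathcal{L} = \mathcal{L}_{\text{CE}} + \lambda \sum_{k=2}^{D} \mathcal{L}_{\text{MTP}}^{(k)}$$
In DeepSeek-V3, $D = 2$ in this convention (the next token plus one additional token; the report counts only the extra module and calls this MTP depth 1), with $\lambda = 0.3$ for the early part of pre-training and $0.1$ afterwards.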
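3. Theoretical advantages of MTP
3.1 Training-signal density
Standard NTP: one gradient signal per position
MTP ($D=2$): two gradient signals per position, a 2× increase in signal density
3.2 Implicit regularization of the representation
MTP forces the hidden representation to encode information about several future tokens at once, which implicitly encourages richer semantic representations.
Formally, let $\mathbf{h}_t$ be the hidden state at position $t$. Because adding variables never decreases mutual information:
$$I(\mathbf{h}_t; x_{t+1}, x_{t+2}, \ldots, x_{t+D}) \geq \max_{1 \leq k \leq D} I(\mathbf{h}_t; x_{t+k})$$
(A bound by the sum $\sum_k I(\mathbf{h}_t; x_{t+k})$ does not hold in general, since the future tokens carry overlapping information.) By training several prediction heads, MTP implicitly increases the mutual information between $\mathbf{h}_t$ and the block of future tokens rather than the next token alone.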
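3.3 Speculative decoding
At inference time the MTP module can be used for speculative decoding:
- The MTP head quickly drafts the next few tokens
- The main model verifies the drafted tokens in parallel
- Correct predictions are accepted, incorrect ones rejected
Speed-up: DeepSeek-V3 reports about 1.8× higher decoding throughput.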
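The accept/reject step in the list above can be sketched as follows. This is the simple greedy variant, in which a drafted token is accepted only if it equals the main model's own greedy choice; the rejection-sampling scheme of Leviathan et al. is more general and is not shown. The function name and the list-based interface are assumptions made for the example.

```python
from typing import List

def accept_draft(draft: List[int], verified: List[int]) -> List[int]:
    """Greedy verification of drafted tokens.

    draft:    tokens proposed by the draft head
    verified: the main model's greedy token at each drafted position, plus one
              extra token for the position after the last draft token
              (so len(verified) == len(draft) + 1)
    Returns the tokens actually emitted in this step.
    """
    emitted: List[int] = []
    for d, v in zip(draft, verified):
        if d != v:
            emitted.append(v)          # first mismatch: take the main model's token
            return emitted
        emitted.append(d)              # draft token confirmed
    emitted.append(verified[len(draft)])   # all accepted: bonus token comes for free
    return emitted

print(accept_draft([5, 7, 9], [5, 7, 2, 4]))   # third draft token rejected
print(accept_draft([5, 7], [5, 7, 1]))         # everything accepted
```

Every step emits at least one token chosen by the main model, so greedy output is unchanged; the gain comes from steps in which several drafted tokens are confirmed by a single verification pass.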
4. A complete runnable implementation
```python
"""
Multi-Token Prediction (MTP): a runnable reference implementation.
"""
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MTPModule(nn.Module):
    """MTP module in the style of DeepSeek-V3."""

    def __init__(self, d_model: int, vocab_size: int, depth: int = 2):
        super().__init__()
        self.depth = depth
        self.d_model = d_model
        # One Transformer layer reused by all MTP depths (a simplification;
        # the DeepSeek-V3 report uses a separate block per depth)
        self.trm_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=d_model * 4,
            batch_first=True, norm_first=True
        )
        # Projections that map the concatenated vector back to d_model
        self.projections = nn.ModuleList([
            nn.Linear(d_model * 2, d_model) for _ in range(depth - 1)
        ])
        # Shared output head
        self.output_head = nn.Linear(d_model, vocab_size, bias=False)

    @staticmethod
    def _cross_entropy(logits, target, valid):
        """Mean cross-entropy; positions where valid == 0 are ignored."""
        if valid is not None:
            target = target.masked_fill(valid == 0, -100)
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), target.reshape(-1), ignore_index=-100
        )

    def forward(
        self,
        hidden_states: torch.Tensor,            # (B, L, D)
        token_embeddings: torch.Tensor,         # (B, L + depth, D)
        targets: torch.Tensor,                  # (B, L + depth), already shifted by one
        labels: Optional[torch.Tensor] = None,  # (B, L + depth), 1 = valid position
    ) -> List[torch.Tensor]:
        """Return the list of losses [NTP loss, MTP loss for depth 2, ...]."""
        B, L, D = hidden_states.shape
        losses = []
        # Depth 1: standard next-token prediction
        logits_1 = self.output_head(hidden_states)                        # (B, L, V)
        valid_1 = labels[:, :L] if labels is not None else None
        losses.append(self._cross_entropy(logits_1, targets[:, :L], valid_1))
        # Causal mask so that no position attends to later positions
        causal = torch.triu(
            torch.full((L, L), float("-inf"), device=hidden_states.device), diagonal=1
        )
        # Depth 2 and beyond: MTP
        current_hidden = hidden_states
        for k in range(1, self.depth):
            # Embedding of the token k steps ahead
            future_emb = token_embeddings[:, k:k + L]                     # (B, L, D)
            # Concatenate and project
            combined = torch.cat([current_hidden, future_emb], dim=-1)    # (B, L, 2D)
            projected = self.projections[k - 1](combined)                 # (B, L, D)
            # Transformer layer
            current_hidden = self.trm_layer(projected, src_mask=causal)   # (B, L, D)
            # Predict
            logits_k = self.output_head(current_hidden)                   # (B, L, V)
            target_k = targets[:, k:k + L]                                # (B, L)
            valid_k = labels[:, k:k + L] if labels is not None else None
            losses.append(self._cross_entropy(logits_k, target_k, valid_k))
        return losses


def demonstrate_mtp_training():
    """Demonstrate an MTP loss computation on random data."""
    torch.manual_seed(42)
    B, L, D, V = 4, 32, 128, 1000
    depth = 2
    hidden_states = torch.randn(B, L, D)
    token_embeddings = torch.randn(B, L + depth, D)
    targets = torch.randint(0, V, (B, L + depth))
    mtp = MTPModule(d_model=D, vocab_size=V, depth=depth)
    losses = mtp(hidden_states, token_embeddings, targets)
    lambda_weight = 0.3
    total_loss = losses[0] + lambda_weight * sum(losses[1:])
    print("MTP training losses:")
    print(f"  standard NTP loss: {losses[0].item():.4f}")
    for k, loss_k in enumerate(losses[1:], 2):
        print(f"  MTP-{k} loss: {loss_k.item():.4f}")
    print(f"  total loss (lambda={lambda_weight}): {total_loss.item():.4f}")
```
5. MTP formula summary
╔════════════════════════════════════════════════════════════════════════════════╗
║                   MTP (Multi-Token Prediction): math summary                   ║
╠════════════════════════════════════════════════════════════════════════════════╣
║                                                                                ║
║  1. MTP loss:                                                                  ║
║     L = L_NTP + λ · Σ_{k=2}^{D} L_MTP^{(k)}                                    ║
║     L_MTP^{(k)} = -Σ_t log p^{(k)}(x_{t+k} | x_{≤t})                           ║
║                                                                                ║
║  2. Sequentially dependent architecture:                                       ║
║     h_t^{(k)} = TRM_k(M_k · [h_t^{(k-1)} ; e_{t+k-1}])                         ║
║     p^{(k)}(x_{t+k}) = Softmax(W_out · h_t^{(k)})                              ║
║     Shared W_out and embedding -> small parameter overhead                     ║
║                                                                                ║
║  3. Training-signal density:                                                   ║
║     NTP: 1 signal per position                                                 ║
║     MTP (D=2): 2 signals per position -> 2× density                            ║
║                                                                                ║
║  4. Speculative decoding:                                                      ║
║     MTP head drafts tokens -> main model verifies in parallel -> about 1.8×    ║
║                                                                                ║
║  5. DeepSeek-V3 configuration:                                                 ║
║     D = 2 (one extra token beyond the next one), λ = 0.3 early, 0.1 later      ║
║                                                                                ║
╚════════════════════════════════════════════════════════════════════════════════╝
References
MLA
- DeepSeek-AI. (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv:2405.04434.
- Vaswani, A., Shazeer, N., et al. (2017). Attention Is All You Need. NeurIPS 2017.
- Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150.
- Ainslie, J., Lee-Thorp, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
- Su, J., Lu, Y., et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864.
DeepSeekMoE
- DeepSeek-AI. (2024). DeepSeek-V3 Technical Report. arXiv:2412.19437.
- Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR.
- Lepikhin, D., Lee, H., et al. (2021). GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding. ICLR 2021.
MTP
- Gloeckle, F., Badr, F., et al. (2024). Better & Faster Large Language Models via Multi-token Prediction. ICML 2024.
- Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. ICML 2023.