MLA — 多头潜在注意力深度解析

目录


第一篇:多头潜在注意力(MLA)

1. 引言

标准多头注意力(MHA)的 KV Cache 是大语言模型推理的核心瓶颈。对于一个具有 n h n_h nh 个注意力头、每头维度 d h d_h dh 的模型,每个 token 需要缓存的 KV 大小为 2 n h d h 2 n_h d_h 2nhdh。当序列长度达到数万 token 时,KV Cache 占用的显存远超模型参数本身。

MLA(Multi-head Latent Attention, DeepSeek-V2, 2024) 提出了一种革命性的 KV Cache 压缩方案:通过低秩分解将 KV 压缩到一个远小于原始维度的潜在向量中,推理时通过"吸收技巧"避免显式解压。这一设计实现了:

  1. KV Cache 压缩 57× :从每 token 2 n h d h = 32768 2 n_h d_h = 32768 2nhdh=32768 维压缩到 d c + d n = 576 d_c + d_n = 576 dc+dn=576 维
  2. 推理性能不降:通过矩阵吸收,数学上等价于标准 MHA
  3. 位置编码兼容:通过解耦 RoPE 机制,完美兼容旋转位置编码

2. 理论基础 --- 注意力机制的 KV Cache

2.1 标准多头注意力回顾

标准 MHA 的计算流程。对于第 i i i 个注意力头:

q t ( i ) = W Q ( i ) h t , k t ( i ) = W K ( i ) h t , v t ( i ) = W V ( i ) h t \mathbf{q}_t^{(i)} = \mathbf{W}_Q^{(i)} \mathbf{h}_t, \quad \mathbf{k}_t^{(i)} = \mathbf{W}_K^{(i)} \mathbf{h}_t, \quad \mathbf{v}_t^{(i)} = \mathbf{W}_V^{(i)} \mathbf{h}_t qt(i)=WQ(i)ht,kt(i)=WK(i)ht,vt(i)=WV(i)ht

Attn ( i ) ( q t , k ≤ t , v ≤ t ) = softmax ( ( q t ( i ) ) T k ≤ t ( i ) d h ) v ≤ t ( i ) \text{Attn}^{(i)}(\mathbf{q}t, \mathbf{k}{\leq t}, \mathbf{v}{\leq t}) = \text{softmax}\left(\frac{(\mathbf{q}t^{(i)})^T \mathbf{k}{\leq t}^{(i)}}{\sqrt{d_h}}\right) \mathbf{v}{\leq t}^{(i)} Attn(i)(qt,k≤t,v≤t)=softmax(dh (qt(i))Tk≤t(i))v≤t(i)

其中 h t ∈ R d \mathbf{h}_t \in \mathbb{R}^d ht∈Rd 是第 t t t 个 token 的隐藏状态。

2.2 KV Cache 的显存分析

在自回归推理中,每生成一个新 token,需要将当前层的 k t , v t \mathbf{k}_t, \mathbf{v}_t kt,vt 追加到缓存中。

模型 n h n_h nh d h d_h dh d = n h d h d = n_h d_h d=nhdh 每 token KV Cache
LLaMA-2-7B 32 128 4096 2 × 32 × 128 = 8192 2 \times 32 \times 128 = 8192 2×32×128=8192
LLaMA-2-70B 64 128 8192 2 × 64 × 128 = 16384 2 \times 64 \times 128 = 16384 2×64×128=16384
DeepSeek-V2 128 128 16384 2 × 128 × 128 = 32768 2 \times 128 \times 128 = 32768 2×128×128=32768

对于 DeepSeek-V2 的 128 头注意力,每 token 每层需要缓存 32768 维的 KV 向量。在 64 层模型、序列长度 128K 的情况下:

KV Cache 总量 = 64 × 128000 × 32768 × 2 bytes ≈ 537 GB \text{KV Cache 总量} = 64 \times 128000 \times 32768 \times 2 \text{ bytes} \approx 537 \text{ GB} KV Cache 总量=64×128000×32768×2 bytes≈537 GB

这远远超出了单张 GPU 的显存容量。

2.3 现有 KV Cache 压缩方案

GQA(Grouped-Query Attention):多个查询头共享同一组 KV 头。

KV Cache = 2 × n g × d h , n g < n h \text{KV Cache} = 2 \times n_g \times d_h, \quad n_g < n_h KV Cache=2×ng×dh,ng<nh

MQA(Multi-Query Attention):所有查询头共享单一 KV 头。

KV Cache = 2 × d h \text{KV Cache} = 2 \times d_h KV Cache=2×dh

局限性 :GQA/MQA 通过减少 KV 头数来压缩,但压缩率受限于组数,且共享 KV 会损失表达能力。


3. MLA 的核心创新 --- 低秩 KV 压缩

3.1 核心思想

MLA 的核心洞察:高维 KV 向量存在大量冗余,可以通过低秩投影压缩到一个紧凑的潜在空间

不同于 GQA/MQA 减少头数,MLA 保留所有头的信息,但通过下投影-上投影架构压缩 KV Cache:

h t ∈ R d ⏟ 隐藏状态 → W D K V c t K V ∈ R d c ⏟ 压缩潜在向量 → W U K , W U V k t , v t ⏟ 重建 KV \underbrace{\mathbf{h}t \in \mathbb{R}^d}{\text{隐藏状态}} \xrightarrow{\mathbf{W}{DKV}} \underbrace{\mathbf{c}t^{KV} \in \mathbb{R}^{d_c}}{\text{压缩潜在向量}} \xrightarrow{\mathbf{W}{UK}, \mathbf{W}_{UV}} \underbrace{\mathbf{k}_t, \mathbf{v}t}{\text{重建 KV}} 隐藏状态 ht∈RdWDKV 压缩潜在向量 ctKV∈RdcWUK,WUV 重建 KV kt,vt

3.2 数学形式化

下投影(压缩)

c t K V = W D K V h t ∈ R d c \mathbf{c}t^{KV} = \mathbf{W}{DKV} \mathbf{h}_t \in \mathbb{R}^{d_c} ctKV=WDKVht∈Rdc

其中 W D K V ∈ R d c × d \mathbf{W}_{DKV} \in \mathbb{R}^{d_c \times d} WDKV∈Rdc×d, d c ≪ n h d h d_c \ll n_h d_h dc≪nhdh。

上投影(重建)

k t C = W U K c t K V ∈ R n h d h \mathbf{k}t^C = \mathbf{W}{UK} \mathbf{c}_t^{KV} \in \mathbb{R}^{n_h d_h} ktC=WUKctKV∈Rnhdh

v t = W U V c t K V ∈ R n h d h \mathbf{v}t = \mathbf{W}{UV} \mathbf{c}_t^{KV} \in \mathbb{R}^{n_h d_h} vt=WUVctKV∈Rnhdh

其中 W U K ∈ R n h d h × d c \mathbf{W}{UK} \in \mathbb{R}^{n_h d_h \times d_c} WUK∈Rnhdh×dc, W U V ∈ R n h d h × d c \mathbf{W}{UV} \in \mathbb{R}^{n_h d_h \times d_c} WUV∈Rnhdh×dc。

KV Cache 只需存储 c t K V \mathbf{c}_t^{KV} ctKV ,而非完整的 k t , v t \mathbf{k}_t, \mathbf{v}_t kt,vt。

3.3 低秩分解的信息论解释

将 KV 矩阵视为数据矩阵 X = h 1 , h 2 , ... , h L T ∈ R L × d \mathbf{X} = \\mathbf{h}_1, \\mathbf{h}_2, \\ldots, \\mathbf{h}_L^T \in \mathbb{R}^{L \times d} X=h1,h2,...,hLT∈RL×d,则:

K = X W K T = X W D K V T W U K T \mathbf{K} = \mathbf{X} \mathbf{W}K^T = \mathbf{X} \mathbf{W}{DKV}^T \mathbf{W}_{UK}^T K=XWKT=XWDKVTWUKT

这等价于对 X \mathbf{X} X 做 PCA 降维 后再重建。当 d c d_c dc 足够大时,重建误差可以忽略。

定理 (Eckart-Young):对于秩为 r r r 的矩阵 M \mathbf{M} M,其最优秩- k k k 近似( k ≤ r k \leq r k≤r)为:

M k = ∑ i = 1 k σ i u i v i T \mathbf{M}k = \sum{i=1}^{k} \sigma_i \mathbf{u}_i \mathbf{v}_i^T Mk=i=1∑kσiuiviT

其中 σ i \sigma_i σi 是奇异值, u i , v i \mathbf{u}_i, \mathbf{v}_i ui,vi 是对应的奇异向量。近似误差为:

∥ M − M k ∥ F = ∑ i = k + 1 r σ i 2 \|\mathbf{M} - \mathbf{M}_k\|F = \sqrt{\sum{i=k+1}^{r} \sigma_i^2} ∥M−Mk∥F=i=k+1∑rσi2

当 KV 矩阵的有效秩远小于 n h d h n_h d_h nhdh 时(实践中通常如此),低秩压缩几乎无损。

3.4 为什么 KV 是低秩的?

直觉 1:语义冗余

相邻 token 的 K、V 向量高度相似(如 "the" "cat" "sat" 的语义渐变),存在大量线性相关性。

直觉 2:注意力稀疏

实际注意力矩阵通常是稀疏的(只有少数 token 对被高度关注),这意味着高维 KV 空间中的大部分维度是冗余的。

直觉 3:头间冗余

不同注意力头的 K、V 往往关注相似的模式,导致跨头的 KV 向量存在相关性。


4. 吸收技巧 --- 推理时的等价加速

4.1 问题:上投影的计算开销

如果在推理时显式执行上投影:

k t = W U K c t K V , v t = W U V c t K V \mathbf{k}t = \mathbf{W}{UK} \mathbf{c}_t^{KV}, \quad \mathbf{v}t = \mathbf{W}{UV} \mathbf{c}_t^{KV} kt=WUKctKV,vt=WUVctKV

则每生成一个新 token,需要对所有 历史位置的 c s K V \mathbf{c}_s^{KV} csKV 执行上投影来计算注意力分数。这使得压缩带来的显存节省被计算开销抵消。

4.2 吸收技巧的数学推导

MLA 的关键技巧:将上投影矩阵吸收到查询投影中

注意力分数的计算:

score ( t , s ) = ( q t ( i ) ) T k s ( i ) = ( W Q ( i ) h t ) T ( W U K ( i ) c s K V ) \text{score}(t, s) = (\mathbf{q}_t^{(i)})^T \mathbf{k}_s^{(i)} = (\mathbf{W}_Q^{(i)} \mathbf{h}t)^T (\mathbf{W}{UK}^{(i)} \mathbf{c}_s^{KV}) score(t,s)=(qt(i))Tks(i)=(WQ(i)ht)T(WUK(i)csKV)

= h t T ( W Q ( i ) ) T W U K ( i ) c s K V = \mathbf{h}_t^T (\mathbf{W}Q^{(i)})^T \mathbf{W}{UK}^{(i)} \mathbf{c}_s^{KV} =htT(WQ(i))TWUK(i)csKV

定义吸收后的查询投影

W ~ Q ( i ) = ( W U K ( i ) ) T W Q ( i ) ∈ R d c × d h \tilde{\mathbf{W}}Q^{(i)} = (\mathbf{W}{UK}^{(i)})^T \mathbf{W}_Q^{(i)} \in \mathbb{R}^{d_c \times d_h} W~Q(i)=(WUK(i))TWQ(i)∈Rdc×dh

则:

score ( t , s ) = ( W ~ Q ( i ) h t ) T c s K V = ( q ~ t ( i ) ) T c s K V \text{score}(t, s) = (\tilde{\mathbf{W}}_Q^{(i)} \mathbf{h}_t)^T \mathbf{c}_s^{KV} = (\tilde{\mathbf{q}}_t^{(i)})^T \mathbf{c}_s^{KV} score(t,s)=(W~Q(i)ht)TcsKV=(q~t(i))TcsKV

关键结果 :推理时只需存储 c s K V ∈ R d c \mathbf{c}_s^{KV} \in \mathbb{R}^{d_c} csKV∈Rdc,无需显式重建 k s \mathbf{k}_s ks。查询直接与压缩潜在向量做点积,数学上完全等价于标准 MHA。

4.3 吸收技巧的复杂度分析

操作 标准 MHA MLA(无吸收) MLA(吸收后)
KV Cache 2 n h d h 2 n_h d_h 2nhdh d c d_c dc d c d_c dc
查询投影 O ( d ⋅ d h ) O(d \cdot d_h) O(d⋅dh) O ( d ⋅ d h ) O(d \cdot d_h) O(d⋅dh) O ( d ⋅ d c ) O(d \cdot d_c) O(d⋅dc)
注意力分数 O ( d h ) O(d_h) O(dh) O ( d h ) O(d_h) O(dh) O ( d c ) O(d_c) O(dc)
输出聚合 O ( d h ) O(d_h) O(dh) O ( d h ) O(d_h) O(dh) O ( d h ) O(d_h) O(dh)

吸收后,注意力分数的计算从 O ( d h ) O(d_h) O(dh) 变为 O ( d c ) O(d_c) O(dc),但由于 d c > d h d_c > d_h dc>dh(DeepSeek-V2 中 d c = 512 , d h = 128 d_c = 512, d_h = 128 dc=512,dh=128),单次点积的计算量略有增加。然而,显存的巨大节省远超这点计算开销


5. 解耦 RoPE --- 位置编码的兼容性

5.1 RoPE 对吸收技巧的破坏

RoPE(Rotary Position Embedding)对 Q、K 施加位置相关的旋转:

RoPE ( x , t ) = ( x 1 cos ⁡ ( t θ 1 ) − x 2 sin ⁡ ( t θ 1 ) x 1 sin ⁡ ( t θ 1 ) + x 2 cos ⁡ ( t θ 1 ) ⋮ ) \text{RoPE}(\mathbf{x}, t) = \begin{pmatrix} x_1 \cos(t\theta_1) - x_2 \sin(t\theta_1) \\ x_1 \sin(t\theta_1) + x_2 \cos(t\theta_1) \\ \vdots \end{pmatrix} RoPE(x,t)= x1cos(tθ1)−x2sin(tθ1)x1sin(tθ1)+x2cos(tθ1)⋮

RoPE 后的注意力分数:

score ( t , s ) = ( RoPE ( q t , t ) ) T RoPE ( k s , s ) \text{score}(t, s) = (\text{RoPE}(\mathbf{q}_t, t))^T \text{RoPE}(\mathbf{k}_s, s) score(t,s)=(RoPE(qt,t))TRoPE(ks,s)

问题:RoPE 将位置信息乘入 Q、K,使得:

RoPE ( k s , s ) = RoPE ( W U K c s K V , s ) \text{RoPE}(\mathbf{k}s, s) = \text{RoPE}(\mathbf{W}{UK} \mathbf{c}_s^{KV}, s) RoPE(ks,s)=RoPE(WUKcsKV,s)

由于 RoPE 是非线性的(三角函数),无法将 W U K \mathbf{W}_{UK} WUK 吸收到查询投影中

RoPE ( W U K c s K V , s ) ≠ W U K RoPE ( c s K V , s ) \text{RoPE}(\mathbf{W}_{UK} \mathbf{c}s^{KV}, s) \neq \mathbf{W}{UK} \text{RoPE}(\mathbf{c}_s^{KV}, s) RoPE(WUKcsKV,s)=WUKRoPE(csKV,s)

5.2 解耦 RoPE 的设计

MLA 将 Key 分解为两个独立部分:

k t = k t C ⏟ 内容键 ; k t R ⏟ 位置键 \mathbf{k}_t = \\underbrace{\\mathbf{k}_t\^C}_{\\text{内容键}} ; \\underbrace{\\mathbf{k}_t\^R}_{\\text{位置键}} kt=内容键 ktC;位置键 ktR

q t = q t C ⏟ 内容查询 ; q t R ⏟ 位置查询 \mathbf{q}_t = \\underbrace{\\mathbf{q}_t\^C}_{\\text{内容查询}} ; \\underbrace{\\mathbf{q}_t\^R}_{\\text{位置查询}} qt=内容查询 qtC;位置查询 qtR

内容键(无 RoPE,支持吸收):

k t C = W U K c t K V ∈ R n h d h \mathbf{k}t^C = \mathbf{W}{UK} \mathbf{c}_t^{KV} \in \mathbb{R}^{n_h d_h} ktC=WUKctKV∈Rnhdh

位置键(有 RoPE,独立计算):

k t R = RoPE ( W K R h t , t ) ∈ R d n \mathbf{k}t^R = \text{RoPE}(\mathbf{W}{KR} \mathbf{h}_t, t) \in \mathbb{R}^{d_n} ktR=RoPE(WKRht,t)∈Rdn

其中 W K R ∈ R d n × d \mathbf{W}_{KR} \in \mathbb{R}^{d_n \times d} WKR∈Rdn×d, d n d_n dn 是位置编码维度(通常较小,如 64)。

5.3 解耦注意力分数

最终注意力分数为两部分之和:

score ( t , s ) = ( q t C ) T k s C + ( q t R ) T k s R \text{score}(t, s) = (\mathbf{q}_t^C)^T \mathbf{k}_s^C + (\mathbf{q}_t^R)^T \mathbf{k}_s^R score(t,s)=(qtC)TksC+(qtR)TksR

= ( q ~ t ) T c s K V + ( q t R ) T k s R = (\tilde{\mathbf{q}}_t)^T \mathbf{c}_s^{KV} + (\mathbf{q}_t^R)^T \mathbf{k}_s^R =(q~t)TcsKV+(qtR)TksR

  • 第一项 :使用吸收技巧,KV Cache 只需 c s K V ∈ R d c \mathbf{c}_s^{KV} \in \mathbb{R}^{d_c} csKV∈Rdc
  • 第二项 :位置键需要单独缓存 k s R ∈ R d n \mathbf{k}_s^R \in \mathbb{R}^{d_n} ksR∈Rdn,但 d n d_n dn 很小

总 KV Cache

MLA Cache = d c ⏟ 内容潜在向量 + d n ⏟ 位置键 = 512 + 64 = 576 \text{MLA Cache} = \underbrace{d_c}{\text{内容潜在向量}} + \underbrace{d_n}{\text{位置键}} = 512 + 64 = 576 MLA Cache=内容潜在向量 dc+位置键 dn=512+64=576

对比标准 MHA: 2 × 128 × 128 = 32768 2 \times 128 \times 128 = 32768 2×128×128=32768,压缩比约 57×

5.4 解耦 RoPE 的直觉理解

组件 编码内容 是否需要位置 是否支持吸收
k t C \mathbf{k}_t^C ktC 语义内容
k t R \mathbf{k}_t^R ktR 位置信息

直觉:语义信息是低秩的(可以压缩),位置信息是稀疏的(只需少量维度)。MLA 将两者解耦,分别用最优的方式处理。


6. 完整可运行实现

6.1 MLA 核心实现

python 复制代码
"""
Multi-head Latent Attention (MLA) --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class MLAConfig:
    """MLA 配置"""
    d_model: int = 2048          # 模型维度
    n_heads: int = 16            # 注意力头数
    d_head: int = 128            # 每头维度
    d_c: int = 512               # KV 压缩维度
    d_n: int = 64                # RoPE 维度
    rope_theta: float = 10000.0  # RoPE 基频
    max_seq_len: int = 4096      # 最大序列长度


class RotaryEmbedding(nn.Module):
    """旋转位置编码 (RoPE)"""

    def __init__(self, dim: int, theta: float = 10000.0, max_seq_len: int = 4096):
        super().__init__()
        self.dim = dim
        self.theta = theta

        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(max_seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype),
        )


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """RoPE 的半旋转操作"""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rope(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """应用旋转位置编码"""
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot


class MultiHeadLatentAttention(nn.Module):
    """多头潜在注意力 (MLA)"""

    def __init__(self, config: MLAConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.d_c = config.d_c
        self.d_n = config.d_n
        self.d_model = config.d_model

        # 查询投影
        self.W_Q = nn.Linear(config.d_model, config.n_heads * config.d_head, bias=False)

        # KV 压缩 (下投影)
        self.W_DKV = nn.Linear(config.d_model, config.d_c, bias=False)

        # KV 重建 (上投影)
        self.W_UK = nn.Linear(config.d_c, config.n_heads * config.d_head, bias=False)
        self.W_UV = nn.Linear(config.d_c, config.n_heads * config.d_head, bias=False)

        # 位置编码分支 (解耦 RoPE)
        self.W_KR = nn.Linear(config.d_model, config.d_n, bias=False)
        self.W_QR = nn.Linear(config.d_model, config.n_heads * config.d_n, bias=False)

        # 输出投影
        self.W_O = nn.Linear(config.n_heads * config.d_head, config.d_model, bias=False)

        # RoPE
        self.rope = RotaryEmbedding(config.d_n, config.rope_theta, config.max_seq_len)

        # 吸收后的查询投影 (预计算)
        self._absorbed = False

    def _compute_absorbed_weights(self):
        """预计算吸收后的查询权重"""
        # W_Q: (n_heads * d_head, d_model)
        # W_UK: (n_heads * d_head, d_c)
        # 吸收: W_Q' = W_UK^T @ W_Q, shape: (d_c, d_model) per head
        W_Q_reshaped = self.W_Q.weight.view(self.n_heads, self.d_head, self.d_model)
        W_UK_reshaped = self.W_UK.weight.view(self.n_heads, self.d_head, self.d_c)

        # 吸收后的查询投影: (n_heads, d_c, d_model)
        self.absorbed_W_Q = torch.bmm(W_UK_reshaped.transpose(1, 2), W_Q_reshaped)
        self._absorbed = True

    def forward_train(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        训练时前向传播 (显式上投影,便于理解)

        x: (B, L, D)
        """
        B, L, D = x.shape

        # 1. 查询投影
        q = self.W_Q(x)  # (B, L, n_heads * d_head)
        q = q.view(B, L, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, L, d_h)

        # 2. KV 压缩
        c_kv = self.W_DKV(x)  # (B, L, d_c)

        # 3. KV 重建
        k = self.W_UK(c_kv)  # (B, L, n_heads * d_head)
        k = k.view(B, L, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, L, d_h)
        v = self.W_UV(c_kv)  # (B, L, n_heads * d_head)
        v = v.view(B, L, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, L, d_h)

        # 4. 位置编码分支
        q_r = self.W_QR(x)  # (B, L, n_heads * d_n)
        q_r = q_r.view(B, L, self.n_heads, self.d_n).transpose(1, 2)  # (B, H, L, d_n)
        k_r = self.W_KR(x)  # (B, L, d_n)

        # 应用 RoPE
        cos, sin = self.rope(x, L)
        q_r, k_r = apply_rope(q_r, k_r.unsqueeze(1), cos, sin)

        # 5. 计算注意力分数
        scale = 1.0 / math.sqrt(self.d_head + self.d_n)

        # 内容部分
        attn_content = torch.matmul(q, k.transpose(-2, -1))  # (B, H, L, L)
        # 位置部分
        attn_position = torch.matmul(q_r, k_r.transpose(-2, -1))  # (B, H, L, L)

        attn = (attn_content + attn_position) * scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=-1)

        # 6. 输出
        out = torch.matmul(attn, v)  # (B, H, L, d_h)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)  # (B, L, n_heads * d_h)
        out = self.W_O(out)  # (B, L, D)

        return out

    def forward_inference(
        self,
        x_t: torch.Tensor,
        c_kv_cache: torch.Tensor,
        k_r_cache: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        推理时前向传播 (使用吸收技巧)

        x_t: (B, 1, D) 当前 token
        c_kv_cache: (B, T, d_c) 压缩 KV 缓存
        k_r_cache: (B, 1, T, d_n) 位置键缓存
        position: 当前位置
        """
        B = x_t.shape[0]

        # 1. 计算压缩潜在向量
        c_kv_new = self.W_DKV(x_t)  # (B, 1, d_c)

        # 2. 吸收后的查询投影
        if not self._absorbed:
            self._compute_absorbed_weights()

        # q_absorbed: (B, n_heads, 1, d_c)
        q_absorbed = torch.einsum("bld,hcd->bhlc", x_t, self.absorbed_W_Q)
        q_absorbed = q_absorbed.unsqueeze(2) if q_absorbed.dim() == 3 else q_absorbed

        # 3. 位置分支
        q_r = self.W_QR(x_t)  # (B, 1, n_heads * d_n)
        q_r = q_r.view(B, 1, self.n_heads, self.d_n).transpose(1, 2)  # (B, H, 1, d_n)
        k_r_new = self.W_KR(x_t)  # (B, 1, d_n)

        cos, sin = self.rope(x_t, position + 1)
        cos_t = cos[position:position+1].unsqueeze(0).unsqueeze(0)
        sin_t = sin[position:position+1].unsqueeze(0).unsqueeze(0)
        q_r = q_r * cos_t + rotate_half(q_r) * sin_t
        k_r_new = k_r_new * cos[position:position+1] + rotate_half(k_r_new) * sin[position:position+1]

        # 4. 注意力分数 (吸收后)
        # 内容部分: q_absorbed @ c_kv_cache^T
        attn_content = torch.matmul(q_absorbed, c_kv_cache.transpose(-2, -1))  # (B, H, 1, T)
        # 位置部分: q_r @ k_r_cache^T
        attn_position = torch.matmul(q_r, k_r_cache.transpose(-2, -1))  # (B, H, 1, T)

        attn = (attn_content + attn_position) / math.sqrt(self.d_head + self.d_n)
        attn = F.softmax(attn, dim=-1)

        # 5. 输出 (使用 W_UV 重建 V)
        v = self.W_UV(c_kv_cache)  # (B, T, n_heads * d_head)
        v = v.view(B, -1, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, T, d_h)
        out = torch.matmul(attn, v)  # (B, H, 1, d_h)
        out = out.transpose(1, 2).contiguous().view(B, 1, -1)
        out = self.W_O(out)

        return out, c_kv_new, k_r_new

6.2 KV Cache 对比实验

python 复制代码
def compare_kv_cache_sizes():
    """对比不同注意力机制的 KV Cache 大小"""
    configs = {
        "MHA (128 heads)": {"n_heads": 128, "d_head": 128, "d_c": None, "d_n": None},
        "GQA (8 groups)":  {"n_heads": 128, "d_head": 128, "d_c": None, "d_n": None, "n_groups": 8},
        "MQA (1 KV head)": {"n_heads": 128, "d_head": 128, "d_c": None, "d_n": None, "n_groups": 1},
        "MLA (d_c=512)":   {"n_heads": 128, "d_head": 128, "d_c": 512, "d_n": 64},
    }

    seq_len = 128000
    n_layers = 64
    bytes_per_elem = 2  # FP16

    print("KV Cache 对比 (seq_len=128K, 64层, FP16)")
    print("=" * 60)

    for name, cfg in configs.items():
        if cfg.get("d_c") is not None:
            # MLA
            cache_per_token = cfg["d_c"] + cfg["d_n"]
        elif cfg.get("n_groups") is not None:
            # GQA / MQA
            cache_per_token = 2 * cfg["n_groups"] * cfg["d_head"]
        else:
            # MHA
            cache_per_token = 2 * cfg["n_heads"] * cfg["d_head"]

        total_gb = n_layers * seq_len * cache_per_token * bytes_per_elem / (1024**3)
        print(f"  {name:25s} | 每 token: {cache_per_token:>6d} | 总计: {total_gb:>8.1f} GB")

输出:

复制代码
KV Cache 对比 (seq_len=128K, 64层, FP16)
============================================================
  MHA (128 heads)           | 每 token:  32768 | 总计:    512.0 GB
  GQA (8 groups)            | 每 token:   2048 | 总计:     32.0 GB
  MQA (1 KV head)           | 每 token:    256 | 总计:      4.0 GB
  MLA (d_c=512)             | 每 token:    576 | 总计:      9.0 GB

6.3 吸收技巧等价性验证

python 复制代码
def verify_absorption_equivalence():
    """验证吸收技巧与标准 MHA 的数学等价性"""
    torch.manual_seed(42)

    B, L, D = 2, 16, 64
    n_heads, d_head, d_c = 4, 16, 32

    x = torch.randn(B, L, D)

    # 创建 MLA 模块
    config = MLAConfig(d_model=D, n_heads=n_heads, d_head=d_head, d_c=d_c, d_n=8)
    mla = MultiHeadLatentAttention(config)

    # 标准训练模式 (显式上投影)
    out_train = mla.forward_train(x)

    # 吸收模式
    mla._compute_absorbed_weights()

    # 手动验证吸收后的注意力分数
    c_kv = mla.W_DKV(x)
    k_full = mla.W_UK(c_kv)

    q_full = mla.W_Q(x)

    # 标准注意力分数
    q_reshaped = q_full.view(B, L, n_heads, d_head).transpose(1, 2)
    k_reshaped = k_full.view(B, L, n_heads, d_head).transpose(1, 2)
    score_standard = torch.matmul(q_reshaped, k_reshaped.transpose(-2, -1))

    # 吸收后注意力分数
    q_absorbed = torch.einsum("bld,hcd->bhlc", x, mla.absorbed_W_Q)
    score_absorbed = torch.matmul(q_absorbed, c_kv.unsqueeze(1).transpose(-2, -1))

    max_diff = (score_standard - score_absorbed).abs().max().item()
    print(f"吸收技巧等价性验证:")
    print(f"  标准注意力分数 vs 吸收后分数")
    print(f"  最大绝对误差: {max_diff:.6e}")
    print(f"  数学等价: {max_diff < 1e-5}")

    return max_diff < 1e-5

7. MLA 与其他方案的理论对比

7.1 压缩率对比

机制 KV Cache / token 压缩率 (vs MHA) 是否有损
MHA 2 n h d h 2 n_h d_h 2nhdh
GQA 2 n g d h 2 n_g d_h 2ngdh n h / n g n_h / n_g nh/ng × 有(头间共享)
MQA 2 d h 2 d_h 2dh n h n_h nh × 有(单 KV 头)
MLA d c + d n d_c + d_n dc+dn 2 n h d h / ( d c + d n ) 2 n_h d_h / (d_c + d_n) 2nhdh/(dc+dn) × 近似无损

7.2 表达能力分析

GQA/MQA 的问题:减少 KV 头数直接限制了注意力的表达能力。不同查询头被迫共享相同的 KV,无法学习多样化的注意力模式。

MLA 的优势 :保留所有查询头的独立性,KV 的"压缩"通过低秩投影实现,理论上当 d c ≥ rank ( K ) d_c \geq \text{rank}(\mathbf{K}) dc≥rank(K) 时完全无损。

7.3 MLA 的理论局限

局限 1:查询投影的计算量增加

吸收后,查询投影从 O ( d ⋅ d h ) O(d \cdot d_h) O(d⋅dh) 变为 O ( d ⋅ d c ) O(d \cdot d_c) O(d⋅dc),由于 d c > d h d_c > d_h dc>dh,计算量增加。

局限 2:低秩假设的有效性

如果 KV 矩阵的有效秩接近 n h d h n_h d_h nhdh,低秩压缩会导致信息损失。实践中,深层的 KV 冗余度更高,低秩假设更有效。


8. MLA 数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    MLA (Multi-head Latent Attention) 数学总结                             ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 低秩 KV 压缩:                                                                      ║
║     c_t^{KV} = W_DKV · h_t              (下投影, d → d_c)                              ║
║     k_t^C = W_UK · c_t^{KV}            (上投影, d_c → n_h·d_h)                        ║
║     v_t   = W_UV · c_t^{KV}            (上投影, d_c → n_h·d_h)                        ║
║     KV Cache 只存 c_t^{KV} ∈ R^{d_c}                                                 ║
║                                                                                        ║
║  2. 吸收技巧:                                                                           ║
║     score(t,s) = q_t^T · k_s = q̃_t^T · c_s^{KV}                                      ║
║     其中 q̃_t = W_UK^T · W_Q · h_t    (吸收后查询, 直接与 c^{KV} 点积)                ║
║     推理时无需显式重建 k_s                                                             ║
║                                                                                        ║
║  3. 解耦 RoPE:                                                                         ║
║     k_t = [k_t^C ; k_t^R]                                                             ║
║     k_t^C = W_UK · c_t^{KV}          (内容键, 无 RoPE, 支持吸收)                      ║
║     k_t^R = RoPE(W_KR · h_t, t)      (位置键, 有 RoPE, 独立缓存)                      ║
║     score = (q^C)^T·k^C + (q^R)^T·k^R                                                ║
║                                                                                        ║
║  4. KV Cache 复杂度:                                                                   ║
║     MHA:  2·n_h·d_h = 32768 per token                                                 ║
║     GQA:  2·n_g·d_h = 2048 (n_g=8)                                                   ║
║     MQA:  2·d_h = 256                                                                  ║
║     MLA:  d_c + d_n = 512 + 64 = 576 per token                                        ║
║     压缩比: 32768 / 576 ≈ 57×                                                         ║
║                                                                                        ║
║  5. 推理时 KV Cache 存储:                                                              ║
║     每 token 每层: c_t^{KV} ∈ R^{d_c} + k_t^R ∈ R^{d_n}                              ║
║     总计: L_layers × seq_len × (d_c + d_n) × 2 bytes                                  ║
║     DeepSeek-V2: 64 × 128K × 576 × 2 ≈ 9 GB (vs MHA 512 GB)                          ║
║                                                                                        ║
║  6. 低秩分解的信息论保证:                                                               ║
║     ‖K - K_k‖_F = √(Σ_{i>k} σ_i²)                                                    ║
║     当 KV 矩阵有效秩 << n_h·d_h 时, 压缩近似无损                                     ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第二篇:DeepSeekMoE --- 无辅助损失负载均衡

1. 引言

混合专家模型(MoE)通过稀疏激活实现"大参数量、低计算量"的理想组合。然而,传统 MoE 的负载均衡问题一直是训练稳定性的隐患------如果路由不均匀,部分专家过载而其他专家闲置,导致训练崩溃。

DeepSeekMoE(DeepSeek-V3, 2024) 提出了一种无辅助损失的负载均衡方法:通过动态调整路由偏置项,在不干扰梯度的情况下实现完美的专家负载均衡。


2. MoE 负载均衡的数学背景

2.1 标准 MoE 路由

给定 token 表示 x ∈ R d \mathbf{x} \in \mathbb{R}^d x∈Rd,标准 MoE 的路由计算:

g i = Softmax ( x T e i ) , i = 1 , 2 , ... , N g_i = \text{Softmax}(\mathbf{x}^T \mathbf{e}_i), \quad i = 1, 2, \ldots, N gi=Softmax(xTei),i=1,2,...,N

选择 Top-K 个专家:

T = TopK ( { g 1 , g 2 , ... , g N } , K ) \mathcal{T} = \text{TopK}(\{g_1, g_2, \ldots, g_N\}, K) T=TopK({g1,g2,...,gN},K)

y = ∑ i ∈ T g i ⋅ Expert i ( x ) \mathbf{y} = \sum_{i \in \mathcal{T}} g_i \cdot \text{Expert}_i(\mathbf{x}) y=i∈T∑gi⋅Experti(x)

2.2 负载不均衡问题

定义专家 i i i 的负载为一个训练批次中被路由到该专家的 token 数量:

Load i = ∑ t = 1 B ⋅ L 1 i ∈ T t \text{Load}i = \sum{t=1}^{B \cdot L} \mathbb{1}i \\in \\mathcal{T}_t Loadi=t=1∑B⋅L1i∈Tt

理想状态 :所有专家负载相等, Load i = B ⋅ L ⋅ K N \text{Load}_i = \frac{B \cdot L \cdot K}{N} Loadi=NB⋅L⋅K。

实际问题 :训练过程中会出现赢者通吃现象------少数专家获得更多路由,进而获得更多梯度更新,变得更"强",形成正反馈循环。

2.3 传统解决方案:辅助损失

辅助负载均衡损失(Switch Transformer, GShard):

L aux = α ⋅ N ∑ i = 1 N f i ⋅ p i \mathcal{L}{\text{aux}} = \alpha \cdot N \sum{i=1}^{N} f_i \cdot p_i Laux=α⋅Ni=1∑Nfi⋅pi

其中:

f i = Load i B ⋅ L , p i = 1 B ⋅ L ∑ t g i ( t ) f_i = \frac{\text{Load}i}{B \cdot L}, \quad p_i = \frac{1}{B \cdot L} \sum{t} g_i^{(t)} fi=B⋅LLoadi,pi=B⋅L1t∑gi(t)

问题 :辅助损失与语言建模目标 L LM \mathcal{L}_{\text{LM}} LLM 存在冲突:

L = L LM + α L aux \mathcal{L} = \mathcal{L}{\text{LM}} + \alpha \mathcal{L}{\text{aux}} L=LLM+αLaux

  • α \alpha α 太大:强制均匀路由,损害模型表达能力
  • α \alpha α 太小:负载不均衡,训练不稳定

3. 无辅助损失负载均衡

3.1 核心思想

DeepSeekMoE 的关键洞察:将路由决策与梯度计算解耦

引入偏置项 b = ( b 1 , b 2 , ... , b N ) \mathbf{b} = (b_1, b_2, \ldots, b_N) b=(b1,b2,...,bN),仅用于路由选择,不参与前向传播的梯度计算

3.2 带偏置的路由

路由分数

s i = Sigmoid ( x T e i + b i ) s_i = \text{Sigmoid}(\mathbf{x}^T \mathbf{e}_i + b_i) si=Sigmoid(xTei+bi)

注意:使用 Sigmoid 而非 Softmax,因为 Sigmoid 允许每个专家独立评分。

Top-K 选择 :基于 s i s_i si 选择专家,但前向传播时不使用 b i b_i bi

T = TopK ( { s 1 , s 2 , ... , s N } , K ) \mathcal{T} = \text{TopK}(\{s_1, s_2, \ldots, s_N\}, K) T=TopK({s1,s2,...,sN},K)

g i = Sigmoid ( x T e i ) ∑ j ∈ T Sigmoid ( x T e j ) , i ∈ T g_i = \frac{\text{Sigmoid}(\mathbf{x}^T \mathbf{e}i)}{\sum{j \in \mathcal{T}} \text{Sigmoid}(\mathbf{x}^T \mathbf{e}_j)}, \quad i \in \mathcal{T} gi=∑j∈TSigmoid(xTej)Sigmoid(xTei),i∈T

y = ∑ i ∈ T g i ⋅ Expert i ( x ) \mathbf{y} = \sum_{i \in \mathcal{T}} g_i \cdot \text{Expert}_i(\mathbf{x}) y=i∈T∑gi⋅Experti(x)

关键 :门控权重 g i g_i gi 的计算不包含 b i b_i bi ,因此 b i b_i bi 不影响梯度。

3.3 偏置更新规则

在训练循环之外(如每 T T T 步),根据专家负载动态调整偏置:

b i ← { b i − γ if Load i > Target b i + γ if Load i < Target b_i \leftarrow \begin{cases} b_i - \gamma & \text{if } \text{Load}_i > \text{Target} \\ b_i + \gamma & \text{if } \text{Load}_i < \text{Target} \end{cases} bi←{bi−γbi+γif Loadi>Targetif Loadi<Target

其中 γ > 0 \gamma > 0 γ>0 是调整步长, Target = B ⋅ L ⋅ K N \text{Target} = \frac{B \cdot L \cdot K}{N} Target=NB⋅L⋅K 是理想负载。

3.4 收敛性分析

定理(非正式):在适当条件下,偏置更新规则收敛到均衡状态。

直觉

  • 当专家 i i i 过载时, b i b_i bi 减小 → 路由分数降低 → 更少 token 被路由到 i i i
  • 当专家 i i i 欠载时, b i b_i bi 增大 → 路由分数升高 → 更多 token 被路由到 i i i
  • 系统自动趋向 Load i = Target \text{Load}_i = \text{Target} Loadi=Target 的均衡状态

4. 完整可运行实现

4.1 DeepSeekMoE 实现

python 复制代码
"""
DeepSeekMoE --- 无辅助损失负载均衡 --- 完整可运行实现
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional
from dataclasses import dataclass


@dataclass
class MoEConfig:
    d_model: int = 2048
    d_expert: int = 14336
    n_experts: int = 64
    n_shared: int = 2       # 共享专家数
    top_k: int = 6           # 激活专家数
    bias_update_gamma: float = 0.001
    bias_update_interval: int = 100


class Expert(nn.Module):
    """单个专家 (SwiGLU 结构)"""
    def __init__(self, d_model: int, d_expert: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_expert, bias=False)
        self.w_up = nn.Linear(d_model, d_expert, bias=False)
        self.w_down = nn.Linear(d_expert, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


class DeepSeekMoELayer(nn.Module):
    """DeepSeekMoE 层 --- 无辅助损失负载均衡"""

    def __init__(self, config: MoEConfig):
        super().__init__()
        self.config = config

        # 共享专家 (始终激活)
        self.shared_experts = nn.ModuleList([
            Expert(config.d_model, config.d_expert)
            for _ in range(config.n_shared)
        ])

        # 路由专家
        self.routed_experts = nn.ModuleList([
            Expert(config.d_model, config.d_expert)
            for _ in range(config.n_experts)
        ])

        # 路由器
        self.gate = nn.Linear(config.d_model, config.n_experts, bias=False)

        # 偏置项 (不参与梯度计算)
        self.register_buffer(
            "expert_bias",
            torch.zeros(config.n_experts)
        )

        # 负载统计
        self.register_buffer(
            "expert_load",
            torch.zeros(config.n_experts)
        )
        self.step_count = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, L, D)
        """
        B, L, D = x.shape
        x_flat = x.view(-1, D)  # (B*L, D)
        N = x_flat.shape[0]

        # 1. 共享专家输出
        shared_out = sum(expert(x_flat) for expert in self.shared_experts)

        # 2. 路由分数 (含偏置, 仅用于选择)
        gate_logits = self.gate(x_flat)  # (N, n_experts)
        routing_scores = torch.sigmoid(gate_logits + self.expert_bias)  # 含偏置

        # 3. Top-K 选择
        top_k_scores, top_k_indices = torch.topk(
            routing_scores, self.config.top_k, dim=-1
        )  # (N, top_k)

        # 4. 门控权重 (不含偏置!)
        gate_weights_clean = torch.sigmoid(gate_logits)  # 不含偏置
        top_k_weights = gate_weights_clean.gather(1, top_k_indices)  # (N, top_k)
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

        # 5. 路由专家计算
        routed_out = torch.zeros_like(x_flat)
        for k in range(self.config.top_k):
            expert_idx = top_k_indices[:, k]  # (N,)
            weight = top_k_weights[:, k]  # (N,)

            for e in range(self.config.n_experts):
                mask = (expert_idx == e)
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.routed_experts[e](expert_input)
                    routed_out[mask] += weight[mask].unsqueeze(-1) * expert_output

        # 6. 更新负载统计
        with torch.no_grad():
            for e in range(self.config.n_experts):
                self.expert_load[e] += (top_k_indices == e).sum().item()
            self.step_count += 1

            # 定期更新偏置
            if self.step_count % self.config.bias_update_interval == 0:
                target_load = N * self.config.top_k / self.config.n_experts
                gamma = self.config.bias_update_gamma

                overloaded = self.expert_load > target_load * 1.1
                underloaded = self.expert_load < target_load * 0.9

                self.expert_bias[overloaded] -= gamma
                self.expert_bias[underloaded] += gamma

                # 重置统计
                self.expert_load.zero_()

        # 7. 合并共享专家和路由专家
        return (shared_out + routed_out).view(B, L, D)

4.2 负载均衡对比实验

python 复制代码
def compare_load_balancing():
    """对比辅助损失方法与无辅助损失方法的负载均衡效果"""
    torch.manual_seed(42)

    n_experts = 16
    n_tokens = 1000
    top_k = 2

    # 模拟路由分数 (初始不均匀)
    logits = torch.randn(n_tokens, n_experts)
    logits[:, 0] += 2.0  # 专家 0 获得更多路由

    def compute_load_imbalance(scores, top_k):
        _, indices = torch.topk(scores, top_k, dim=-1)
        loads = torch.zeros(n_experts)
        for i in range(n_experts):
            loads[i] = (indices == i).sum().item()
        return loads.std().item() / (loads.mean().item() + 1e-8)

    # 无偏置
    imbalance_no_bias = compute_load_imbalance(logits, top_k)

    # 辅助损失 (模拟均匀化效果)
    aux_weight = 0.1
    uniform_logits = logits - logits.mean(dim=-1, keepdim=True)
    logits_aux = logits * (1 - aux_weight) + uniform_logits * aux_weight
    imbalance_aux = compute_load_imbalance(logits_aux, top_k)

    # 无辅助损失偏置方法
    bias = torch.zeros(n_experts)
    for _ in range(10):
        scores = torch.sigmoid(logits + bias)
        _, indices = torch.topk(scores, top_k, dim=-1)
        loads = torch.zeros(n_experts)
        for i in range(n_experts):
            loads[i] = (indices == i).sum().item()
        target = n_tokens * top_k / n_experts
        bias[loads > target * 1.1] -= 0.05
        bias[loads < target * 0.9] += 0.05

    imbalance_bias = compute_load_imbalance(torch.sigmoid(logits + bias), top_k)

    print("负载均衡对比:")
    print(f"  无均衡:           负载标准差/均值 = {imbalance_no_bias:.4f}")
    print(f"  辅助损失:         负载标准差/均值 = {imbalance_aux:.4f}")
    print(f"  无辅助损失偏置:   负载标准差/均值 = {imbalance_bias:.4f}")

5. 无辅助损失负载均衡的数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    DeepSeekMoE 无辅助损失负载均衡 数学总结                                ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 带偏置的路由:                                                                       ║
║     s_i = Sigmoid(x^T · e_i + b_i)     (b_i 仅用于路由选择)                            ║
║     T = TopK({s_i}, K)                                                                 ║
║     g_i = Sigmoid(x^T · e_i) / Σ_{j∈T} Sigmoid(x^T · e_j)   (门控不含 b_i)           ║
║     y = Σ_{i∈T} g_i · Expert_i(x)                                                     ║
║                                                                                        ║
║  2. 偏置更新 (训练循环外):                                                              ║
║     if Load_i > Target:  b_i ← b_i - γ                                                ║
║     if Load_i < Target:  b_i ← b_i + γ                                                ║
║     Target = B·L·K / N                                                                 ║
║                                                                                        ║
║  3. 关键性质:                                                                           ║
║     - b_i 不参与梯度计算 → 不干扰语言建模目标                                           ║
║     - b_i 仅影响路由选择 → 间接控制负载均衡                                             ║
║     - 偏置更新在训练循环外 → 不增加计算开销                                             ║
║                                                                                        ║
║  4. 对比辅助损失:                                                                       ║
║     辅助损失: L = L_LM + α·L_aux  (目标冲突)                                           ║
║     无辅助损失: L = L_LM  (纯语言建模, 偏置独立调整)                                    ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第三篇:多 Token 预测(MTP)

1. 引言

标准语言模型每步只预测下一个 token,训练信号稀疏。多 Token 预测(Multi-Token Prediction, MTP) 让模型同时预测多个未来 token,提供更密集的训练信号,并可在推理时用于投机解码加速。


2. MTP 的数学框架

2.1 标准下一 Token 预测

标准自回归语言模型的训练目标:

L NTP = − ∑ t = 1 T log ⁡ p θ ( x t ∣ x < t ) \mathcal{L}{\text{NTP}} = -\sum{t=1}^{T} \log p_\theta(x_t | x_{<t}) LNTP=−t=1∑Tlogpθ(xt∣x<t)

每个位置只提供一个训练信号。

2.2 多 Token 预测

MTP 在每个位置预测未来 D D D 个 token:

L MTP = ∑ k = 1 D λ k L CE ( k ) \mathcal{L}{\text{MTP}} = \sum{k=1}^{D} \lambda_k \mathcal{L}_{\text{CE}}^{(k)} LMTP=k=1∑DλkLCE(k)

L CE ( k ) = − ∑ t = 1 T − k log ⁡ p θ ( k ) ( x t + k ∣ x ≤ t ) \mathcal{L}{\text{CE}}^{(k)} = -\sum{t=1}^{T-k} \log p_\theta^{(k)}(x_{t+k} | x_{\leq t}) LCE(k)=−t=1∑T−klogpθ(k)(xt+k∣x≤t)

其中 p θ ( k ) p_\theta^{(k)} pθ(k) 是第 k k k 个预测头的输出概率, λ k \lambda_k λk 是权重系数。

2.3 DeepSeek-V3 的 MTP 架构

DeepSeek-V3 使用顺序依赖的 MTP 模块:

h t ( k ) = TRM k ( Concat h t ( k − 1 ) , e t + k ) \mathbf{h}_t^{(k)} = \text{TRM}_k\left(\text{Concat}\\mathbf{h}_t\^{(k-1)}, \\mathbf{e}_{t+k}\right) ht(k)=TRMk(Concatht(k−1),et+k)

其中:

  • h t ( k − 1 ) \mathbf{h}_t^{(k-1)} ht(k−1) 是上一层 MTP 模块的输出
  • e t + k \mathbf{e}_{t+k} et+k 是第 t + k t+k t+k 个 token 的嵌入
  • TRM k \text{TRM}_k TRMk 是共享的 Transformer 层

输出头共享 :所有 MTP 深度共享同一个输出投影 W out \mathbf{W}_{\text{out}} Wout:

p ( k ) ( x t + k ) = Softmax ( W out h t ( k ) ) p^{(k)}(x_{t+k}) = \text{Softmax}(\mathbf{W}_{\text{out}} \mathbf{h}_t^{(k)}) p(k)(xt+k)=Softmax(Woutht(k))

2.4 MTP 损失函数

L = L CE + λ ∑ k = 2 D L MTP ( k ) \mathcal{L} = \mathcal{L}{\text{CE}} + \lambda \sum{k=2}^{D} \mathcal{L}_{\text{MTP}}^{(k)} L=LCE+λk=2∑DLMTP(k)

DeepSeek-V3 中 D = 2 D = 2 D=2(预测当前和下一个 token), λ = 0.3 \lambda = 0.3 λ=0.3。


3. MTP 的理论优势

3.1 训练信号密度

标准 NTP:每个位置 1 个梯度信号

MTP( D = 2 D=2 D=2):每个位置 2 个梯度信号,信号密度提升 2×

3.2 表示学习的隐式正则化

MTP 迫使模型的隐藏表示同时编码多个未来 token 的信息,这隐式地鼓励了更丰富的语义表示

形式化 :设 h t \mathbf{h}_t ht 为位置 t t t 的隐藏状态,则:

I ( h t ; x t + 1 , x t + 2 , ... , x t + D ) ≥ ∑ k = 1 D I ( h t ; x t + k ) I(\mathbf{h}t; x{t+1}, x_{t+2}, \ldots, x_{t+D}) \geq \sum_{k=1}^{D} I(\mathbf{h}t; x{t+k}) I(ht;xt+1,xt+2,...,xt+D)≥k=1∑DI(ht;xt+k)

MTP 通过多头预测,隐式最大化了 h t \mathbf{h}_t ht 与未来多个 token 的互信息。

3.3 投机解码加速

MTP 模块在推理时可用于投机解码(Speculative Decoding)

  1. 用 MTP 头快速预测未来 D D D 个 token
  2. 用主模型并行验证这 D D D 个 token
  3. 接受正确的预测,拒绝错误的

加速比 :DeepSeek-V3 报告约 1.8× 吞吐量提升


4. 完整可运行实现

python 复制代码
"""
Multi-Token Prediction (MTP) --- 完整可运行实现
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MTPModule(nn.Module):
    """DeepSeek-V3 风格的 MTP 模块"""

    def __init__(self, d_model: int, vocab_size: int, depth: int = 2):
        super().__init__()
        self.depth = depth
        self.d_model = d_model

        # 共享 Transformer 层
        self.trm_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=d_model * 4,
            batch_first=True, norm_first=True
        )

        # 投影层 (将拼接向量投影回 d_model)
        self.projections = nn.ModuleList([
            nn.Linear(d_model * 2, d_model) for _ in range(depth - 1)
        ])

        # 共享输出头
        self.output_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,  # (B, L, D)
        token_embeddings: torch.Tensor,  # (B, L+D, D)
        targets: torch.Tensor,  # (B, L+D)
        labels: torch.Tensor = None,  # (B, L+D) 标记有效位置
    ):
        """
        计算 MTP 损失
        """
        B, L, D = hidden_states.shape
        losses = []

        # 第 1 层: 标准下一 token 预测
        logits_1 = self.output_head(hidden_states)  # (B, L, V)
        if labels is not None:
            loss_1 = F.cross_entropy(
                logits_1.view(-1, logits_1.size(-1)),
                targets[:, :L].reshape(-1),
                ignore_index=-100
            )
        else:
            loss_1 = F.cross_entropy(
                logits_1.view(-1, logits_1.size(-1)),
                targets[:, :L].reshape(-1)
            )
        losses.append(loss_1)

        # 第 2+ 层: MTP
        current_hidden = hidden_states
        for k in range(1, self.depth):
            # 获取未来第 k 个 token 的嵌入
            future_emb = token_embeddings[:, k:k+L]  # (B, L, D)

            # 拼接并投影
            combined = torch.cat([current_hidden, future_emb], dim=-1)  # (B, L, 2D)
            projected = self.projections[k-1](combined)  # (B, L, D)

            # 通过 Transformer 层
            current_hidden = self.trm_layer(projected)  # (B, L, D)

            # 预测
            logits_k = self.output_head(current_hidden)  # (B, L, V)
            target_k = targets[:, k:k+L]  # (B, L)

            if labels is not None:
                mask_k = labels[:, k:k+L]
                loss_k = F.cross_entropy(
                    logits_k.view(-1, logits_k.size(-1)),
                    target_k.reshape(-1),
                    ignore_index=-100,
                    reduction='sum'
                ) / (mask_k.sum() + 1e-8)
            else:
                loss_k = F.cross_entropy(
                    logits_k.view(-1, logits_k.size(-1)),
                    target_k.reshape(-1)
                )
            losses.append(loss_k)

        return losses


def demonstrate_mtp_training():
    """演示 MTP 训练"""
    torch.manual_seed(42)

    B, L, D, V = 4, 32, 128, 1000
    depth = 2

    hidden_states = torch.randn(B, L, D)
    token_embeddings = torch.randn(B, L + depth, D)
    targets = torch.randint(0, V, (B, L + depth))

    mtp = MTPModule(d_model=D, vocab_size=V, depth=depth)
    losses = mtp(hidden_states, token_embeddings, targets)

    lambda_weight = 0.3
    total_loss = losses[0] + lambda_weight * sum(losses[1:])

    print("MTP 训练损失:")
    print(f"  标准 NTP 损失: {losses[0].item():.4f}")
    for k, l in enumerate(losses[1:], 2):
        print(f"  MTP-{k} 损失:   {l.item():.4f}")
    print(f"  总损失 (λ={lambda_weight}): {total_loss.item():.4f}")

5. MTP 数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    MTP (Multi-Token Prediction) 数学总结                                  ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. MTP 损失函数:                                                                       ║
║     L = L_NTP + λ · Σ_{k=2}^{D} L_MTP^{(k)}                                          ║
║     L_MTP^{(k)} = -Σ_t log p^{(k)}(x_{t+k} | x_{≤t})                                ║
║                                                                                        ║
║  2. 顺序依赖架构:                                                                       ║
║     h_t^{(k)} = TRM_k(Concat[h_t^{(k-1)}, e_{t+k}])                                  ║
║     p^{(k)}(x_{t+k}) = Softmax(W_out · h_t^{(k)})                                    ║
║     共享 W_out → 参数开销最小                                                           ║
║                                                                                        ║
║  3. 训练信号密度:                                                                       ║
║     NTP: 每位置 1 个信号                                                                ║
║     MTP (D=2): 每位置 2 个信号 → 密度 2×                                               ║
║                                                                                        ║
║  4. 投机解码:                                                                           ║
║     MTP 头快速预测 D 个 token → 主模型并行验证 → 1.8× 加速                             ║
║                                                                                        ║
║  5. DeepSeek-V3 配置:                                                                   ║
║     D = 2, λ = 0.3                                                                     ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

参考文献

MLA

  1. DeepSeek-AI. (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv:2405.04434.
  2. Vaswani, A., Shazeer, N., et al. (2017). Attention Is All You Need. NeurIPS 2017.
  3. Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150.
  4. Ainslie, J., Lee-Thorp, J., et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.
  5. Su, J., Lu, Y., et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864.

DeepSeekMoE

  1. DeepSeek-AI. (2024). DeepSeek-V3 Technical Report. arXiv:2412.19437.
  2. Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR.
  3. Lepikhin, D., Lee, H., et al. (2021). GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding. ICLR 2021.

MTP

  1. Gloeckle, F., Badr, F., et al. (2024). Better & Faster Large Language Models via Multi-token Prediction. ICML 2024.
  2. Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. ICML 2023.
相关推荐
Black蜡笔小新1 小时前
企业AI算力工作站DLTM深度学习推理工作站零代码私有化重塑企业AI落地新模式
人工智能·深度学习
吴可可1231 小时前
SolidWorks草图转三维DWG技巧
算法
啦啦啦_99991 小时前
4. Transformer_4_输出部分
人工智能·深度学习·transformer
redaijufeng2 小时前
C++雾中风景7:闭包
c++·算法·风景
小欣加油2 小时前
leetcode287寻找重复数
数据结构·c++·算法·leetcode
DogDaoDao3 小时前
【GitHub】VoxCPM2 实战全解析:原理、部署与效果对比
深度学习·大模型·github·音频·语音模型·tss·文本生成语音
尽兴-3 小时前
2.1 向量基础:Embedding、余弦相似度、欧氏距离、向量检索
算法·embedding·欧氏距离·向量检索·余弦相似度
Black蜡笔小新3 小时前
自动化AI算法训练服务器DLTM训推一体工作站赋能多行业智能化升级
人工智能·算法·自动化
怪兽学LLM3 小时前
LeetCode 438 找到字符串中所有字母异位词(Python 固定滑动窗口+字符计数解法)
python·算法·leetcode