长文本CP 切分,共2次All2All
第一次AlltoAll,输入按Seq维度汇总,按Head维度切。(切输入,非TP维度的切参数)
s/c, b, n/t, h\] -AlltoAll-\> \[s, b, n/(t\*c), h
第二次AlltoAll,恢复按Seq维度切,按Head维度汇总。
s, b, n/(t\*c), h\] -AlltoAll-\> \[s/c, b, n/t, h
其中t 为TP, c 为CP, n = nHead数
举例: CP = 2, TP =4 , H = 8192, nHead = 16
| 阶段 | 形状 | 说明 |
|---|---|---|
| 输入 | [s/2, b, 8192] |
CP 切分后,每 rank 持有半个序列 |
| MLA 解压后 Q/K/V | [s/2, b, 16, 192] |
16 heads/rank(64 heads ÷ TP=4),经过了TP的降维 |
| A2A 后(scatter head,gather seq) | [s, b, 8, 192] |
全序列,head 减半 |
| Flash Attention 输出 | [s, b, 8, 128] |
全序列本地计算 |
| A2A 后(scatter seq,gather head) | [s/2, b, 16, 128] |
还原序列分片 |
| o_proj 后 | [s/2, b, 8192] |
还原 hidden_states, 经过TP升维 |
python
compressed_kv [s, b, 576] ← kv_a_proj 压缩后的 latent,是 _preprocess 的输入
│
├── split → ct_kv [s, b, 512] ← kv_lora_rank 部分
│ k_pe [s, b, 64] ← rope 部分
│
├── kv_a_layernorm(ct_kv)
│
└── kv_b_proj (Up-projection, 解压)
[s, b, 512] → [s, b, 16heads, 128+128]
k_nope [s, b, 16, 128]
v [s, b, 16, 128]
q_b_input (经过 q_b_proj 解压)
q_nope [s, b, 16, 128]
q_pe [s, b, 16, 64]
最终拼接:
query_states [s, b, 16, 192] = q_nope + q_pe
key_states [s, b, 16, 192] = k_nope + k_pe
value_states [s, b, 16, 128]
MLA attention:
python
DeepseekV2Attention
└── self.core_attention_flash = FlashAttention(...) # 基础 flash attn
↓ (当 CP + alltoall 时自动包装)
└── self.core_attention_flash = DistributedAttention(FlashAttention, cp_group)