vLLM Analysis (8): DeepSeek V4 Attention (SWA + CSA + HCA)

DeepseekV4Attention


DeepseekV4MultiHeadLatentAttentionWrapper

kv_cache

DeepseekV4IndexerCache

class DeepseekV4Indexer(nn.Module):
    def __init__(self, *args, **kwargs):  # excerpt: real parameters and earlier setup elided
        assert cache_config is not None, "Deepseek V4 indexer requires cache_config"
        # NOTE(yifan): FP8 indexer cache uses the same layout as V3.2:
        # head_dim bytes = 128 fp8 + 4 fp32 scale = 132.
        # For FP4 indexer cache, we still allocate the same amount of memory as FP8,
        # but only use the first half of the memory.
        k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4
        self.k_cache = DeepseekV4IndexerCache(
            head_dim=k_cache_head_dim,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
            compress_ratio=self.compress_ratio,
        )
        self.compressor = DeepseekCompressor(
            vllm_config=vllm_config,
            compress_ratio=self.compress_ratio,
            hidden_size=hidden_size,
            head_dim=self.head_dim,
            rotate=True,
            prefix=f"{prefix}.compressor",
            k_cache_prefix=self.k_cache.prefix,
            use_fp4_cache=self.use_fp4_kv,
        )

        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
            skip_k_cache_insert=True,
            use_fp4_cache=self.use_fp4_kv,
        )

DeepseekV4IndexerCache.kv_cache

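In DeepseekV4Indexer, the KV entries compressed by the compressor are used by SparseAttnIndexer to compute the top-k indices. The compressed KV entries are stored in the indexer's k_cache (a DeepseekV4IndexerCache), i.e. in DeepseekV4IndexerCache.kv_cache.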
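As a quick check of the byte budget described in the NOTE inside DeepseekV4Indexer.__init__, the sketch below recomputes the per-token size of an indexer cache row. It assumes head_dim = 128 and quant_block_size = 128 (one fp32 scale per row), which is what the "128 fp8 + 4 fp32 scale = 132" comment implies; the function names are hypothetical helpers, not vLLM APIs.

```python
def indexer_cache_row_bytes(head_dim: int = 128, quant_block_size: int = 128) -> int:
    """Bytes allocated per token: 1 byte per FP8 value + 4 bytes per fp32 scale.

    Mirrors k_cache_head_dim = head_dim + head_dim // quant_block_size * 4.
    """
    num_scales = head_dim // quant_block_size
    return head_dim + num_scales * 4


def fp4_used_bytes(head_dim: int = 128, quant_block_size: int = 128) -> int:
    """Per the NOTE, the FP4 cache allocates the FP8-sized row but only
    uses its first half."""
    return indexer_cache_row_bytes(head_dim, quant_block_size) // 2


print(indexer_cache_row_bytes())  # FP8 row size
print(fp4_used_bytes())           # portion actually used by the FP4 cache
```

Keeping FP4 and FP8 rows the same size lets both variants share one allocation path; the cost is that half of each FP4 row stays unused.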

DeepseekV4SWACache

class DeepseekV4SWACache(torch.nn.Module, AttentionLayerBase):
    def __init__(
        self,
        head_dim: int,
        window_size: int,
        dtype: torch.dtype,
        prefix: str,
        cache_config: CacheConfig,
    ):
        super().__init__()
        self.kv_cache = torch.tensor([])
        self.head_dim = head_dim
        self.window_size = window_size
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

DeepseekV4SWACache.kv_cache

DeepseekV4MLAAttention

class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
    def __init__(self, *args, **kwargs):  # excerpt: real parameters and earlier setup elided
        self.kv_cache_dtype = kv_cache_dtype

        # Register with compilation context for metadata lookup
        compilation_config = vllm_config.compilation_config
        if prefix and prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        if prefix:
            compilation_config.static_forward_context[prefix] = self

        self.kv_cache = torch.tensor([])

DeepseekV4MLAAttention.kv_cache

In the CSA and HCA cases, the KV entries compressed by DeepseekCompressor are stored in DeepseekV4MLAAttention.kv_cache.

CompressorStateCache

class CompressorStateCache(torch.nn.Module, AttentionLayerBase):
    def __init__(
        self,
        state_dim: int,
        dtype: torch.dtype,
        compress_ratio: int,
        prefix: str,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.dtype = dtype
        self.prefix = prefix
        self.kv_cache = torch.tensor([])
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

CompressorStateCache.kv_cache

KV cache tensor allocation analysis

ref: vLLM Analysis (4): KV cache initialization

DeepseekCompressor

DeepseekCompressor: responsible for compressing KV blocks. It is used by Heavily Compressed Attention (HCA), Compressed Sparse Attention (CSA), and DeepseekV4Indexer.

class DeepseekCompressor(nn.Module):
    """DeepSeek V4 KV/score compressor.

    Owns the linear / norm / state-cache / ape state and the shared forward
    prologue (kv/score split, save_partial_states launch). The
    compress → norm → RoPE → store step is dispatched to a triton kernel
    (``compress_norm_rope_store_triton``) by default, except for the NVIDIA
    head_dim=512 main compressor path which uses the cutedsl kernel
    (``compress_norm_rope_store_cutedsl``) for better performance.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        compress_ratio: int,
        hidden_size: int,
        head_dim: int,
        rotate: bool = False,
        prefix: str = "",
        k_cache_prefix="",
        use_fp4_cache: bool = False,
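    ):
        super().__init__()
        # excerpt: assignments such as self.head_dim, self.compress_ratio,
        # self.rope_head_dim and self.rms_norm_eps are elided here
        self.overlap = compress_ratio == 4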
        self.coff = 1 + self.overlap

        state_dtype = torch.float32
        self.ape = nn.Parameter(
            torch.empty(
                (compress_ratio, self.coff * self.head_dim),
                dtype=state_dtype,
                device=self.device,
            ),
            requires_grad=False,
        )

        self.fused_wkv_wgate = MergedColumnParallelLinear(
            self.hidden_size,
            [self.coff * self.head_dim, self.coff * self.head_dim],
            bias=False,
            return_bias=False,
            quant_config=None,
            disable_tp=True,
            prefix=f"{prefix}.fused_wkv_wgate",
        )
        self.norm = RMSNorm(self.head_dim, self.rms_norm_eps)

        self.state_cache = CompressorStateCache(
            state_dim=2 * self.coff * self.head_dim,  # kv_state + score_state
            dtype=state_dtype,
            compress_ratio=compress_ratio,
            prefix=f"{prefix}.state_cache",
        )

    def forward(
        self,
        # [num_tokens, 2 * self.coff * self.head_dim]
        kv_score: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        rotary_emb,
    ) -> None:
        # Each of shape [num_tokens, coff * self.head_dim]
        # input bf16, output are fp32
        kv, score = kv_score.split(
            [self.coff * self.head_dim, self.coff * self.head_dim], dim=-1
        )

        # Get the metadata and handle dummy profiling run.
        attn_metadata = get_forward_context().attn_metadata
        if not isinstance(attn_metadata, dict):
            return

        state_metadata = cast(
            CompressorMetadata, attn_metadata[self.state_cache.prefix]
        )
        token_to_req_indices = state_metadata.token_to_req_indices
        slot_mapping = state_metadata.slot_mapping
        num_actual = slot_mapping.shape[0]
        block_table = state_metadata.block_table
        block_size = state_metadata.block_size

        # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim
        state_cache = self.state_cache.kv_cache
        # kv_state stored in first half, score_state stored in second half
        state_width = state_cache.shape[-1] // 2
        pdl_kwargs = (
            {}
            if current_platform.is_rocm() or current_platform.is_xpu()
            else {"launch_pdl": False}
        )

        # Store the KV and score (with fused APE addition) in the state.
        # NOTE: PDL is disabled --- both this kernel and the compress kernels
        # below depend on preceding kernel outputs (kv/score from the cublas
        # GEMM; state_cache from this kernel) but neither emits/waits on PDL
        # grid dependency primitives, so launch_pdl=True caused a
        # read-after-write race and non-deterministic output.
        save_partial_states(
            kv=kv,
            score=score,
            ape=self.ape,
            positions=positions,
            state_cache=state_cache,
            slot_mapping=slot_mapping,
            block_size=block_size,
            state_width=state_width,
            compress_ratio=self.compress_ratio,
            pdl_kwargs=pdl_kwargs,
        )

        # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
        # RoPE requirements (kernel applies forward GPT-J style rotation):
        # - is_neox_style=False (interleaved pairs, NOT split-half)
        # - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos,
        #   second half sin (per-pair, length rope_head_dim // 2 each)
        # - applied to LAST rope_head_dim elements of head_dim
        # - position used: (positions // compress_ratio) * compress_ratio
        cos_sin_cache = rotary_emb.cos_sin_cache
        k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix])
        kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache

        if current_platform.is_cuda():
            # NVIDIA GPUs.
            if self.head_dim == 512:
                from .nvidia.ops import compress_norm_rope_store_cutedsl

                # Main compressor path.
                # Use a cutedsl kernel for better performance.
                compress_norm_rope_store_fn = compress_norm_rope_store_cutedsl
            else:
                # Indexer path (head_dim == 128).
                # Use a triton kernel.
                compress_norm_rope_store_fn = compress_norm_rope_store_triton
        else:
            # AMD GPUs.
            # Always use a triton kernel.
            compress_norm_rope_store_fn = compress_norm_rope_store_triton

        compress_norm_rope_store_fn(
            state_cache=state_cache,
            num_actual=num_actual,
            token_to_req_indices=token_to_req_indices,
            positions=positions,
            slot_mapping=slot_mapping,
            block_table=block_table,
            block_size=block_size,
            state_width=state_width,
            cos_sin_cache=cos_sin_cache,
            kv_cache=kv_cache,
            k_cache_metadata=k_cache_metadata,
            pdl_kwargs=pdl_kwargs,
            head_dim=self.head_dim,
            rope_head_dim=self.rope_head_dim,
            compress_ratio=self.compress_ratio,
            overlap=self.overlap,
            use_fp4_cache=self.use_fp4_cache,
            rms_norm_weight=self.norm.weight,
            rms_norm_eps=self.rms_norm_eps,
            quant_block=self._quant_block,
            token_stride=self._token_stride,
            scale_dim=self._scale_dim,
        )
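The shape arithmetic in the constructor can be summarized in a few lines. This minimal sketch only mirrors the formulas in DeepseekCompressor.__init__ (overlap = compress_ratio == 4, coff = 1 + overlap, state_dim = 2 * coff * head_dim); the helper name compressor_dims is hypothetical, and the (compress_ratio, head_dim) pairs in the usage loop are assumptions based on this article (CSA: 4 and 512, HCA: 128 and 512, indexer: 4 and 128).

```python
def compressor_dims(compress_ratio: int, head_dim: int) -> dict:
    """Mirror the shape arithmetic of DeepseekCompressor.__init__ (sketch)."""
    overlap = compress_ratio == 4   # only ratio-4 compressors use overlapping windows
    coff = 1 + int(overlap)         # 2 with overlap, 1 otherwise
    return {
        "overlap": overlap,
        "coff": coff,
        "ape_shape": (compress_ratio, coff * head_dim),
        # fused_wkv_wgate emits [kv | score], each coff * head_dim wide
        "fused_out_dim": 2 * coff * head_dim,
        # CompressorStateCache row: kv_state + score_state
        "state_dim": 2 * coff * head_dim,
    }


for name, ratio, dim in [("CSA", 4, 512), ("HCA", 128, 512), ("indexer", 4, 128)]:
    print(name, compressor_dims(ratio, dim))
```

Because the state row stores kv and score side by side, the state width always equals the fused projection's output width.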

CompressorStateCache

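The kv_cache space of CompressorStateCache is managed with a sliding-window mechanism.

From SlidingWindowMLASpec to SlidingWindowManager

As the window slides, SlidingWindowManager keeps recomputing the valid range of the current window. When blocks fall outside that range and are judged "no longer needed", the corresponding blocks are freed by remove_skipped_blocks.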
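The freeing rule can be illustrated with a simplified calculation. The sketch below is not vLLM's SlidingWindowManager code. It only assumes that the next token needs the last `window` positions, so every block made up entirely of earlier positions can be released. The function names and the example numbers are hypothetical.

```python
def first_useful_token(num_computed_tokens: int, window: int) -> int:
    """Earliest position the next token can still see inside the window."""
    return max(0, num_computed_tokens - window + 1)


def num_skippable_blocks(num_computed_tokens: int, window: int, block_size: int) -> int:
    """Leading blocks whose tokens all lie before the window and can be freed."""
    return first_useful_token(num_computed_tokens, window) // block_size


# 100 tokens computed, window of 8, 16 tokens per block:
# positions 93..100 are still needed, so blocks 0..4 (positions 0..79) can go.
print(num_skippable_blocks(100, 8, 16))
```

Only whole blocks are freed: block 5 (positions 80..95) still contains useful positions, so it stays even though most of its tokens are already outside the window.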

save_partial_states

compress_norm_rope_store_triton

How the Compressor works in CSA layers

Animated figure of the KV entry computation (c4a attention illustration)

head_dim = 512.

$$C_a = H \cdot W_a^{KV}, \quad C_b = H \cdot W_b^{KV}$$

$$Z_a = H \cdot W_a^{Z}, \quad Z_b = H \cdot W_b^{Z}$$

$$C_{\text{Comp}} = \text{Softmax}_{\text{row}}\left( \begin{bmatrix} Z_a + B_a \\ Z_b + B_b \end{bmatrix} \right)^{\top} \odot \begin{bmatrix} C_a \\ C_b \end{bmatrix}$$

wkv holds the weights $W_a^{KV}$ and $W_b^{KV}$; wgate holds the weights $W_a^{Z}$ and $W_b^{Z}$.

fused_wkv_wgate fuses all four: $W_a^{KV}$, $W_b^{KV}$, $W_a^{Z}$ and $W_b^{Z}$.

The ape parameter in the code and in the figure fuses the $B_a$ and $B_b$ terms of the formula.
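To make the formula concrete, here is a pure-Python sketch of compressing one window into one KV entry. It assumes that the softmax normalizes each channel independently across the 2 * compress_ratio stacked rows, which is how Softmax_row followed by the element-wise product is read here, and that the weighted rows are then summed over tokens. The function name and toy inputs are illustrative only. With coff = 1 (HCA) there would be a single group instead of the [a; b] pair.

```python
import math


def compress_group(c_a, c_b, z_a, z_b, b_a, b_b):
    """Compress one window into a single KV entry (sketch).

    c_a, c_b: KV rows (compress_ratio x d) for the previous / current group.
    z_a, z_b: score rows of the same shape; b_a, b_b: position biases (ape).
    Each channel is softmax-normalized over the 2 * compress_ratio rows,
    then the KV rows are summed with those weights.
    """
    values = c_a + c_b                            # stack [C_a; C_b]
    logits = [[z + b for z, b in zip(zr, br)]     # Z + B, row by row
              for zr, br in zip(z_a + z_b, b_a + b_b)]
    out = []
    for ch in range(len(values[0])):
        col = [row[ch] for row in logits]
        m = max(col)                              # subtract the max for stability
        weights = [math.exp(x - m) for x in col]
        total = sum(weights)
        out.append(sum(w / total * v[ch] for w, v in zip(weights, values)))
    return out


zeros = [[0.0, 0.0], [0.0, 0.0]]
# equal scores give every token the same weight, i.e. the per-channel mean
print(compress_group([[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]],
                     zeros, zeros, zeros, zeros))
```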

save_partial_states writes the newly produced kv and score into state_cache.kv_cache. The state layout it writes when compress_ratio = 4 and coff = 2:

┌─────────────────────────────────────┬─────────────────────────────────────┐
│            kv_state                 │          score_state                │
│      (STATE_WIDTH elements)         │      (STATE_WIDTH elements)         │
├─────────────────┬───────────────────┼─────────────────┬───────────────────┤
│   block0        │   block1          │   block0        │   block1          │
│ (head_dim elems)│ (head_dim elems)  │ (head_dim elems)│ (head_dim elems)  │
└─────────────────┴───────────────────┴─────────────────┴───────────────────┘
offset: 0       head_dim            STATE_WIDTH   STATE_WIDTH+head_dim
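A simplified per-token model of what save_partial_states writes may help. The sketch assumes, based on the formula above where the bias B sits next to Z, that the ape row selected by position % compress_ratio is added to the score half, while kv is stored unchanged in the first half of the row. The real kernel works on whole state_cache blocks addressed through slot_mapping, which this sketch skips; save_partial_state is a hypothetical name.

```python
def save_partial_state(state_row, kv, score, ape, position, compress_ratio):
    """Write one token's kv and score (+ ape) into its state-cache row (sketch).

    state_row: list of 2 * state_width floats laid out as [kv_state | score_state].
    kv, score: lists of state_width floats (coff * head_dim each).
    ape: compress_ratio rows of state_width floats.
    """
    state_width = len(state_row) // 2
    assert len(kv) == len(score) == state_width
    bias = ape[position % compress_ratio]           # position inside its group
    state_row[:state_width] = kv                    # first half: kv_state
    state_row[state_width:] = [s + b for s, b in zip(score, bias)]  # score_state
    return state_row


ape = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(save_partial_state([0.0] * 4, [1.0, 2.0], [10.0, 20.0], ape,
                         position=6, compress_ratio=4))
```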

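Suppose the token positions are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, ...

The compression boundaries fall at 3, 7, 11, ... (i.e. where position + 1 is a multiple of 4).

At a compression boundary ((position + 1) % 4 == 0) one compression is triggered; its window covers (1 + overlap) * compress_ratio = 8 tokens.

| Compression boundary | Window (actual positions) | Half used by each token |
|---|---|---|
| pos=3 | -4, -3, -2, -1, 0, 1, 2, 3 | tokens -4...-1: block0 (previous group); tokens 0...3: block1 (current group) |
| pos=7 | 0, 1, 2, 3, 4, 5, 6, 7 | tokens 0...3: block0 (previous group); tokens 4...7: block1 (current group) |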
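The table can be reproduced with a small helper that lists, for a given position, which tokens enter the compression window and which half of the coff * head_dim projection each one contributes. It is a sketch of this article's description (half 0 = block0 for the previous group, half 1 = block1 for the current group). Negative positions are kept only to match the table; how the kernel masks them is not covered here. compression_window is a hypothetical name.

```python
def compression_window(pos, compress_ratio=4, overlap=True):
    """Tokens used by the compression triggered at `pos` (sketch).

    Returns [(token_position, half)], or [] when pos is not a boundary.
    half 0 = first head_dim slice (block0), half 1 = second slice (block1).
    """
    if (pos + 1) % compress_ratio != 0:
        return []                                   # not a compression boundary
    start = pos + 1 - compress_ratio                # first token of current group
    if not overlap:                                 # coff = 1: one group, one half
        return [(p, 0) for p in range(start, pos + 1)]
    previous = [(p, 0) for p in range(start - compress_ratio, start)]
    current = [(p, 1) for p in range(start, pos + 1)]
    return previous + current


for pos in (3, 7):
    print(pos, compression_window(pos))
```

With overlap, every group of 4 tokens is used twice: once through block1 as the current group and once through block0 as the previous group of the next boundary.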

For CSA, compress_norm_rope_store_triton calls _fused_kv_compress_norm_rope_insert_sparse_attn.

Computation:

[state_cache] ──┐
[score]        ──┼─ softmax + weighted sum → compressed KV
[block_table]  ──┘
                        ↓
                  [RMSNorm]
                        ↓
            ┌───────────┴───────────┐
            ↓                       ↓
        nope (448)               rope (64)
            ↓                       ↓
   FP8 block quant          RoPE (GPT-J style)
            ↓                       ↓
  [FP8 data] [scale]          [bf16 data]
            └───────────┬───────────┘
                        ↓
                 [k_cache]  (per-token layout)
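The RoPE step can be sketched in plain Python by following the requirements listed in the kernel comments: GPT-J style rotation of interleaved pairs, a cos_sin_cache row whose first half holds cos and second half sin, rotation applied only to the last rope_head_dim elements, and the position rounded down to (positions // compress_ratio) * compress_ratio. The base of 10000 used to build the table and all helper names are assumptions for illustration; the model's actual RoPE parameters are not reproduced.

```python
import math


def rope_position(position: int, compress_ratio: int) -> int:
    # a compressed entry is rotated with the first position of its group
    return (position // compress_ratio) * compress_ratio


def cos_sin_row(position: int, rope_head_dim: int, base: float = 10000.0) -> list:
    # [cos_0 .. cos_{n-1}, sin_0 .. sin_{n-1}] with n = rope_head_dim // 2
    half = rope_head_dim // 2
    angles = [position * base ** (-2 * i / rope_head_dim) for i in range(half)]
    return [math.cos(a) for a in angles] + [math.sin(a) for a in angles]


def apply_gptj_rope(x: list, cs: list) -> list:
    half = len(x) // 2
    cos, sin = cs[:half], cs[half:]
    out = list(x)
    for i in range(half):
        a, b = x[2 * i], x[2 * i + 1]   # interleaved pair, not split-half
        out[2 * i] = a * cos[i] - b * sin[i]
        out[2 * i + 1] = a * sin[i] + b * cos[i]
    return out


def rope_tail(entry: list, position: int, compress_ratio: int, rope_head_dim: int = 64) -> list:
    """Rotate only the last rope_head_dim elements; the nope part is untouched."""
    nope, rope = entry[:-rope_head_dim], entry[-rope_head_dim:]
    cs = cos_sin_row(rope_position(position, compress_ratio), rope_head_dim)
    return nope + apply_gptj_rope(rope, cs)


entry = [float(i) for i in range(512)]       # head_dim = 448 nope + 64 rope
print(rope_tail(entry, 9, 4)[446:452])       # boundary between nope and rope
```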

Output layout:

Cache block layout:
[0, bs*576):       token data (448 fp8 + 128 bf16 each)
[bs*576, +bs*8):   uint8 UE8M0 scales (7 real + 1 pad each)

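Per token, the first 448 bytes hold the FP8 (E4M3) data of the nope part, stored as uint8.

The next 128 bytes hold the rope part as bfloat16 (64 values).

These are followed, once per block, by kv_cache_block_size * 8 bytes of scale factors: for each token, 7 valid uint8 scales + 1 padding byte, used to dequantize the FP8 data.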
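The byte offsets inside one cache block follow directly from this layout. The sketch below assumes 64-element quantization blocks for the 448 nope values (which is what 7 scales per token implies) and decodes a UE8M0 byte as 2 ** (e - 127), the usual convention for that unsigned, exponent-only format. The helper names are hypothetical.

```python
NOPE_DIM, ROPE_DIM = 448, 64
TOKEN_BYTES = NOPE_DIM * 1 + ROPE_DIM * 2   # fp8 nope + bf16 rope = 576
SCALES_PER_TOKEN = NOPE_DIM // 64           # 7 real UE8M0 scales
SCALE_BYTES = 8                             # 7 scales + 1 pad byte


def block_bytes(block_size: int) -> int:
    """Total bytes of one cache block (token data region + scale region)."""
    return block_size * (TOKEN_BYTES + SCALE_BYTES)


def token_offsets(slot: int, block_size: int) -> dict:
    """Byte offsets of one token's pieces inside its cache block."""
    data = slot * TOKEN_BYTES
    return {
        "nope": data,                                     # 448 bytes of FP8 E4M3
        "rope": data + NOPE_DIM,                          # 128 bytes of bf16
        "scales": block_size * TOKEN_BYTES + slot * SCALE_BYTES,
    }


def ue8m0_to_float(e: int) -> float:
    """Decode an exponent-only UE8M0 scale byte."""
    return 2.0 ** (e - 127)


print(token_offsets(1, 64), block_bytes(64))
```

Grouping all scales at the end of the block keeps each token's 576-byte data record contiguous.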

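HCA Compressor workflow

The flowchart in reference [2] illustrates the HCA compression process.

For HCA, compress_ratio = 128: every 128 KV entries are compressed into a single KV entry. Because overlap is enabled only when compress_ratio == 4, HCA uses coff = 1 and non-overlapping windows.

head_dim = 512.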
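The effect of the two ratios on the number of stored entries is simple arithmetic. Assuming an entry is emitted once per completed group of compress_ratio tokens (the boundary rule from the CSA section), the hypothetical helper below compares CSA (ratio 4) and HCA (ratio 128) for a 1M-token context.

```python
def num_compressed_entries(seq_len: int, compress_ratio: int) -> int:
    # one entry per full group, i.e. per position where (pos + 1) % ratio == 0
    return seq_len // compress_ratio


seq_len = 1 << 20  # 1,048,576 tokens
for name, ratio in [("CSA", 4), ("HCA", 128)]:
    print(name, num_compressed_entries(seq_len, ratio))
```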

DeepseekV4Indexer


qr ──→ wq_b ──→ Q ──→ fused_indexer_q_rope_quant ──→ (q_quant, weights)
                              ↑                               │
                         indexer_weights                      │
                              positions                        │
compressed_kv_score ──→ compressor ──→ k ─────────────────────┤
                              │                               │
                              └──→ self.k_cache (write)       │
                                                               ▼
hidden_states ─────────────────────────────────────→ SparseAttnIndexer
                                                               │
                                                               ▼
                                                          output tensor

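fused_indexer_q_rope_quant and DeepseekCompressor run in parallel: maybe_execute_in_parallel launches them on different streams.

For each query, SparseAttnIndexer performs a top-k selection (topk_tokens) and writes the selected indices into self.topk_indices_buffer, which is consumed outside the indexer.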
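Conceptually, the top-k step reduces to picking the indices of the highest scores for each query. The sketch below does that on the CPU for one query. The padding value -1 for queries with fewer than topk_tokens candidates is an assumption, and select_topk is a hypothetical helper rather than the SparseAttnIndexer API.

```python
import heapq


def select_topk(scores: list, k: int, pad: int = -1) -> list:
    """Indices of the k largest scores (highest first), padded to length k."""
    best = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
    return best + [pad] * (k - len(best))


print(select_topk([0.1, 0.9, 0.5, 0.7], k=2))   # indices of 0.9 and 0.7
print(select_topk([0.3], k=3))                  # one real candidate, two pads
```

A fixed-length, padded result matches the idea of a preallocated index buffer: every query row has the same width regardless of how much history it can see.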

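DeepseekV4Indexer computation illustration (image source)

Token-level Compressor computation: see the CSA-layer Compressor walkthrough above.

SparseAttnIndexer computation illustration (image source)

SparseAttnIndexer selects, for each query token, the most relevant key-value tokens.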

For the t-th query token, a relevance score is computed against the key $\mathbf{k}_s^{I}$ of every earlier token s:

$$I_{t,s} = \sum_{j=1}^{H_I} w_{t,j}^{I} \cdot \text{ReLU}\left( \mathbf{q}_{t,j}^{I} \cdot \mathbf{k}_{s}^{I} \right)$$

  • $H_I$: number of indexer heads (fixed at 64).
  • $\mathbf{q}_{t,j}^{I}$: query vector of the t-th token on the j-th indexer head.
  • $\mathbf{k}_{s}^{I}$: key vector of the s-th token. There is only one key vector per token, shared by all 64 indexer heads; in the figure, the arrow after indexer wk and its box are drawn with dashed lines.
  • $w_{t,j}^{I}$: weight expressing the importance of the j-th head for token t.
  • $\text{ReLU}$: activation function, chosen because it is cheap to compute and keeps throughput high.
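The score formula can be written out directly. This reference-style sketch computes $I_{t,s}$ for one (t, s) pair, with H_I query vectors and H_I weights for the query token and a single key shared by all heads. The real indexer computes this in FP8 on the GPU; the function names and numbers here are illustrative.

```python
def indexer_score(q_heads, w_heads, k):
    """I_{t,s} = sum_j w_{t,j} * ReLU(q_{t,j} . k_s) for one (t, s) pair."""
    total = 0.0
    for q, w in zip(q_heads, w_heads):
        dot = sum(qi * ki for qi, ki in zip(q, k))   # q_{t,j} . k_s
        total += w * max(dot, 0.0)                    # w_{t,j} * ReLU(...)
    return total


def indexer_scores(q_heads, w_heads, keys):
    """Scores of query t against every candidate key s."""
    return [indexer_score(q_heads, w_heads, k) for k in keys]


q = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]   # H_I = 3 toy heads
w = [0.5, 2.0, 1.0]
print(indexer_scores(q, w, [[2.0, 3.0], [0.0, 0.0]]))
```

Note how the third head contributes nothing for the first key: its dot product is negative, so ReLU zeroes it before the weight is applied.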

Attention computation

deepseek_v4_attention

[Input] hidden_states, positions
         │
         ▼
╔══════════════════════════════════════════════════════════════════╗
║ Stage 1: parallel GEMM projections (attn_gemm_parallel_execute)  ║
╠══════════════════════════════════════════════════════════════════╣
║  default stream: fused_wqa_wkv → qr_kv (main projection)         ║
║          ↓                                                       ║
║  aux_streams (up to 3, optionally in parallel):                  ║
║    - compressor_kv_score (if compress_ratio > 1)                 ║
║    - indexer_weights_proj (if indexer exists)                    ║
║    - indexer_compressor_kv_score (if indexer exists)             ║
║                                                                  ║
║  Sync: start_event broadcast → aux waits → wait on done_events   ║
╚══════════════════════════════════════════════════════════════════╝
         │
         ▼
   qr_kv, kv_score, indexer_kv_score, indexer_weights
         │
         ▼
╔══════════════════════════════════════════════════════════════════╗
║ Stage 2: RMSNorm (fused_q_kv_rmsnorm)                            ║
╠══════════════════════════════════════════════════════════════════╣
║  qr, kv = split(qr_kv, [q_lora_rank, head_dim])                  ║
║  qr ← RMSNorm(qr, q_norm.weight)                                 ║
║  kv ← RMSNorm(kv, kv_norm.weight)                                ║
╚══════════════════════════════════════════════════════════════════╝
         │
         ▼
╔══════════════════════════════════════════════════════════════════╗
║ Stage 3: Q/KV transform + cache writes (multi-stream branches)   ║
╠══════════════════════════════════════════════════════════════════╣
║ ┌──────────────────────────────────────────────────────────────┐ ║
║ │ Branch A: indexer present (a compressor then always exists)  │ ║
║ │   default: wq_b_kv_insert                                    │ ║
║ │     → wq_b(qr) → [n_heads, head_dim]                         │ ║
║ │     → _fused_qnorm_rope_kv_insert (fused kernel):            │ ║
║ │         - Q: per-head RMSNorm + RoPE + pad to padded_heads   │ ║
║ │         - KV: RoPE + FP8 quant + write to SWA cache          │ ║
║ │   aux0: indexer.forward (own wq_b + quant + sparse index)    │ ║
║ │   aux1: compressor.forward (kv_score → compressed KV cache)  │ ║
║ │   Sync: default waits for aux0, aux1 → returns q_padded      │ ║
║ ├──────────────────────────────────────────────────────────────┤ ║
║ │ Branch B: compressor only (no indexer)                       │ ║
║ │   default: wq_b_kv_insert (same as above)                    │ ║
║ │   aux: compressor.forward                                    │ ║
║ │   Sync: maybe_execute_in_parallel                            │ ║
║ ├──────────────────────────────────────────────────────────────┤ ║
║ │ Branch C: no compressor and no indexer (pure SWA)            │ ║
║ │   wq_b_kv_insert runs sequentially                           │ ║
║ └──────────────────────────────────────────────────────────────┘ ║
╚══════════════════════════════════════════════════════════════════╝
         │
         ▼
   q_padded ( [num_tokens, padded_heads, head_dim] )
   kv (original tensor; the kernel actually reads the cache)
         │
         ▼
╔══════════════════════════════════════════════════════════════════╗
║ Stage 4: sparse attention (mla_attn)                             ║
╠══════════════════════════════════════════════════════════════════╣
║  backend: FlashMLASparseBackend (NVIDIA) / ROCm AITER            ║
║  inputs: q_padded, kv, positions                                 ║
║  caches: SWA cache (FP8) + indexer sparse indices (if any)       ║
║  output: out [num_tokens, padded_heads, head_dim] (preallocated) ║
║  note: reads KV only from the SWA cache; input kv is ignored     ║
╚══════════════════════════════════════════════════════════════════╝
         │
         ▼
[Output] out (followed by inverse RoPE + FP8 einsum + wo_b to produce the final result)
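Stage 2 of the diagram is easy to reproduce in isolation. The sketch splits the fused qr_kv row at q_lora_rank and applies RMSNorm to each part with its own weight. eps = 1e-6 and the helper names are assumptions, and the real fused_q_kv_rmsnorm works on batched GPU tensors rather than Python lists.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), scaled element-wise by weight
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def split_qr_kv(qr_kv, q_lora_rank, q_norm_weight, kv_norm_weight, eps=1e-6):
    """Split the main projection output and normalize both parts."""
    qr, kv = qr_kv[:q_lora_rank], qr_kv[q_lora_rank:]
    return rms_norm(qr, q_norm_weight, eps), rms_norm(kv, kv_norm_weight, eps)


qr, kv = split_qr_kv([3.0, 4.0, 1.0, 1.0, 1.0], 2, [1.0, 1.0], [2.0, 2.0, 2.0])
print(qr, kv)
```

Fusing both norms into one kernel, as the diagram shows, saves a launch and a pass over qr_kv; the math per part is unchanged.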

Computation illustration, taking CSA as an example (image source)

A diagram of the CSA computation, drawn from the vLLM code:

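References

[1] DeepSeek V4 in vLLM: Efficient Long-context Attention

[2] Implementing DeepSeek-V4 from Scratch (3): HCA

[3] DeepSeek v4 Compressor: the KV cache compression module

[4] DeepSeek V4-vLLM preview

[5] Illustrated DeepSeek V4: a detailed walkthrough of the computation flow

[6] DeepSeek-V4 model architecture and source code analysis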
