【vllm】(v1 Sample)vLLM V1 Sample—Part 3 投机采样拒绝器与Triton Kernel

vLLM V1 Sample 模块超深度架构分析 --- Part 3: 投机采样拒绝器与Triton Kernel

分析范围 : rejection_sampler.py + ops/topk_topp_triton.py


目录

  • [第十四章 RejectionSampler 投机解码拒绝采样器](#第十四章 RejectionSampler 投机解码拒绝采样器)
    • [14.1 设计背景与算法原理](#14.1 设计背景与算法原理)
    • [14.2 类结构与初始化](#14.2 类结构与初始化)
    • [14.3 forward() 主流程深度解析](#14.3 forward() 主流程深度解析)
    • [14.4 _get_logprobs_tensors() 日志概率重构](#14.4 _get_logprobs_tensors() 日志概率重构)
    • [14.5 parse_output() 输出解析](#14.5 parse_output() 输出解析)
    • [14.6 apply_logits_processors() 投机解码版](#14.6 apply_logits_processors() 投机解码版)
    • [14.7 apply_penalties() 投机解码版](#14.7 apply_penalties() 投机解码版)
    • [14.8 _combine_outputs_with_spec_tokens() 输出合并](#14.8 _combine_outputs_with_spec_tokens() 输出合并)
  • [第十五章 rejection_sample() 核心拒绝采样函数](#第十五章 rejection_sample() 核心拒绝采样函数)
    • [15.1 函数签名与输入验证](#15.1 函数签名与输入验证)
    • [15.2 输出缓冲区初始化](#15.2 输出缓冲区初始化)
    • [15.3 贪心路径: rejection_greedy_sample_kernel](#15.3 贪心路径: rejection_greedy_sample_kernel)
    • [15.4 随机路径: rejection_random_sample_kernel](#15.4 随机路径: rejection_random_sample_kernel)
    • [15.5 uniform_probs生成与float64精度](#15.5 uniform_probs生成与float64精度)
    • [15.6 sample_recovered_tokens() 恢复token采样](#15.6 sample_recovered_tokens() 恢复token采样)
    • [15.7 apply_sampling_constraints() 约束应用](#15.7 apply_sampling_constraints() 约束应用)
    • [15.8 expand_batch_to_tokens() 批次展平](#15.8 expand_batch_to_tokens() 批次展平)
    • [15.9 expand_kernel Triton内核](#15.9 expand_kernel Triton内核)
  • [第十六章 Triton Top-K/Top-P Kernel超深度解析](#第十六章 Triton Top-K/Top-P Kernel超深度解析)
    • [16.1 算法背景: Qrita论文](#16.1 算法背景: Qrita论文)
    • [16.2 查找表设计](#16.2 查找表设计)
    • [16.3 _update_min_larger_stats() 辅助函数](#16.3 _update_min_larger_stats() 辅助函数)
    • [16.4 _topk_topp_kernel 整体结构](#16.4 _topk_topp_kernel 整体结构)
    • [16.5 Top-K路径详细解析](#16.5 Top-K路径详细解析)
    • [16.6 Top-P路径详细解析](#16.6 Top-P路径详细解析)
    • [16.7 组合Top-K+Top-P路径](#16.7 组合Top-K+Top-P路径)
    • [16.8 最终遮罩应用](#16.8 最终遮罩应用)
    • [16.9 重复logit处理](#16.9 重复logit处理)
    • [16.10 apply_top_k_top_p_triton() Python包装](#16.10 apply_top_k_top_p_triton() Python包装)
    • [16.11 缓冲区与查找表缓存](#16.11 缓冲区与查找表缓存)
    • [16.12 reset_buffer_cache() 缓存清理](#16.12 reset_buffer_cache() 缓存清理)
  • [第十七章 sample_recovered_tokens_kernel Triton内核](#第十七章 sample_recovered_tokens_kernel Triton内核)
    • [17.1 算法原理](#17.1 算法原理)
    • [17.2 逐行解析](#17.2 逐行解析)
    • [17.3 无draft_probs模式(ngram spec decode)](#17.3 无draft_probs模式(ngram spec decode))
  • [附录E 投机解码端到端时序图](#附录E 投机解码端到端时序图)
  • [附录F Triton Kernel计算流程图](#附录F Triton Kernel计算流程图)
  • [附录G 拒绝采样数学证明](#附录G 拒绝采样数学证明)
  • [附录H 术语表补充](#附录H 术语表补充)

第十四章 RejectionSampler 投机解码拒绝采样器

14.1 设计背景与算法原理

投机解码(Speculative Decoding) 是一种加速LLM推理的技术:

  1. 用小型draft模型快速生成K个候选token
  2. 用大型target模型并行验证这些token
  3. 接受与target分布一致的token,拒绝不一致的
  4. 对被拒绝的位置,从修正后的分布中重新采样

拒绝采样(Rejection Sampling) 的数学原理来自论文 Fast Inference from Transformers via Speculative Decoding

对于draft概率 q(x) 和target概率 p(x)

  • 接受条件u < p(x)/q(x),其中 u ~ Uniform(0,1)
  • 接受时:使用draft token
  • 拒绝时 :从 max(p - q, 0) 分布中采样恢复token
  • 全部接受时:额外采样一个bonus token

术语对照表

术语 含义
accepted tokens 通过 p/q ≥ u 验证的draft token
recovered tokens 拒绝后从 max(p-q,0) 采样的修正token
bonus tokens 全部接受后额外采样的token
output tokens accepted + recovered + bonus

Synthetic模式 :一种特殊的拒绝采样变体,使用预设的接受率而非 p/q 比率。用于ngram spec decode等不提供draft概率的场景。
Yes
No
Yes
No
Draft模型生成

K个候选token
Target模型并行推理

验证K+1个位置
对比draft和target概率
p(x)/q(x) ≥ u?
接受draft token

继续检查下一个
拒绝→从max(p-q,0)采样

recovered token
还有更多draft?
全部接受→采样bonus token
output = accepted + recovered
output = accepted + bonus

14.2 类结构与初始化

python 复制代码
class RejectionSampler(nn.Module):
    def __init__(
        self,
        sampler: Sampler,                       # 内部使用的标准采样器
        spec_config: SpeculativeConfig | None,   # 投机解码配置
        device: torch.device | None,
    ):
        super().__init__()
        self.sampler = sampler
        
        # 确定logprobs模式
        logprobs_mode = self.sampler.logprobs_mode
        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
        
        # Synthetic模式:预设接受率
        self.synthetic_conditional_rates: torch.Tensor | None = None
        if spec_config is not None and spec_config.rejection_sample_method == "synthetic":
            assert spec_config.synthetic_acceptance_rates is not None
            # 将无条件接受率转为条件接受率
            self.synthetic_conditional_rates = torch.tensor(
                unconditional_to_conditional_rates(
                    spec_config.synthetic_acceptance_rates
                ),
                dtype=torch.float32,
                device=device,
            )
        self.synthetic_mode = self.synthetic_conditional_rates is not None

无条件→条件接受率转换

α_1, α_2, ..., α_K 为无条件接受率(每个位置独立接受的概率),则条件接受率为:

复制代码
β_1 = α_1
β_2 = α_2 / (1 - α_1)   # 在第1个已接受的前提下
β_3 = α_3 / ((1 - α_1)(1 - α_2))  # 在前2个已接受的前提下
...

这确保了联合接受概率与无条件设定一致。
RejectionSampler
+sampler: Sampler
+is_processed_logprobs_mode: bool
+is_logits_logprobs_mode: bool
+synthetic_conditional_rates: Tensor|None
+synthetic_mode: bool
+forward(metadata, draft_probs, logits, sampling_metadata) : SamplerOutput
+_get_logprobs_tensors(max_num, metadata, logits, target_logits, bonus_logits, sampled) : LogprobsTensors
+parse_output(output_token_ids, vocab_size, discard_req_indices, logprobs_tensors) : tuple
+apply_logits_processors(logits, sampling_metadata, metadata) : Tensor
+apply_penalties(logits, sampling_metadata, metadata, repeat_indices, output_token_ids) : Tensor
+_combine_outputs_with_spec_tokens(output, spec) : list
Sampler
+forward(logits, metadata) : SamplerOutput

14.3 forward() 主流程深度解析

python 复制代码
def forward(
    self,
    metadata: SpecDecodeMetadata,           # 投机解码元数据
    draft_probs: torch.Tensor | None,       # [num_tokens, vocab_size] draft概率
    logits: torch.Tensor,                   # [num_tokens + batch_size, vocab_size] target logits
    sampling_metadata: SamplingMetadata,
) -> SamplerOutput:

逐行解析

python 复制代码
    assert metadata.max_spec_len <= MAX_SPEC_LEN  # 最大128个spec tokens
    
    # ===== Phase 1: 采样bonus token =====
    bonus_logits_indices = metadata.bonus_logits_indices
    target_logits_indices = metadata.target_logits_indices
    
    # NOTE: PyTorch索引创建新张量,与原logits存储分离
    # 因此对bonus_logits的原地操作不影响原始logits
    assert logits is not None
    bonus_logits = logits[bonus_logits_indices]  # [batch_size, vocab_size]
    
    # 使用标准Sampler采样bonus token
    # 特殊设置:
    # - max_num_logprobs=-1: 特殊模式(spec decode需要所有logits用于后续计算)
    # - predict_bonus_token=True: 标记这是bonus token
    # - logprobs_mode_override: 使用processed_logits或raw_logits模式
    #   因为bonus logits需要用于后续计算accepted token的logprobs
    bonus_sampler_output = self.sampler(
        logits=bonus_logits,
        sampling_metadata=replace(
            sampling_metadata,
            max_num_logprobs=-1,
        ),
        predict_bonus_token=True,
        logprobs_mode_override=(
            "processed_logits" if self.is_processed_logprobs_mode else "raw_logits"
        ),
    )
    bonus_token_ids = bonus_sampler_output.sampled_token_ids  # [batch_size, 1]

为什么bonus token需要override logprobs_mode

  • 在spec decode中,bonus token的logits需要用于计算accepted token的logprobs
  • 如果用 raw_logprobs 模式,计算的是 log(softmax(raw_logits)),但accepted token的logprobs应该基于target logits(经过处理后的)
  • 因此强制使用 processed_logitsraw_logits 模式,保存处理后的logits供后续使用
python 复制代码
    # ===== Phase 2: 处理target logits =====
    raw_target_logits = logits[target_logits_indices]  # [num_tokens, vocab_size]
    raw_target_logits = raw_target_logits.to(torch.float32)
    
    target_logits = raw_target_logits
    if not self.is_processed_logprobs_mode:
        # 非processed模式:需要clone保留原始logits
        # 因为apply_logits_processors会原地修改
        target_logits = target_logits.clone()
    
    # 应用logits处理器(包含penalties、bad_words、thinking budget等)
    target_logits = self.apply_logits_processors(
        target_logits, sampling_metadata, metadata
    )
    
    # 应用采样约束(温度、top-k、top-p)
    target_logits = apply_sampling_constraints(
        target_logits,
        metadata.cu_num_draft_tokens,
        sampling_metadata,
    )
python 复制代码
    # ===== Phase 3: 拒绝采样 =====
    output_token_ids = rejection_sample(
        metadata.draft_token_ids,
        metadata.num_draft_tokens,
        metadata.max_spec_len,
        metadata.cu_num_draft_tokens,
        draft_probs,
        target_logits,
        bonus_token_ids,
        sampling_metadata,
        synthetic_mode=self.synthetic_mode,
        synthetic_conditional_rates=self.synthetic_conditional_rates,
    )
    # output_token_ids: [batch_size, max_spec_len + 1]
    # 被拒绝的位置填入PLACEHOLDER_TOKEN_ID (-1)
    
    # ===== Phase 4: 计算logprobs =====
    logprobs_tensors = None
    if sampling_metadata.max_num_logprobs is not None:
        logprobs_tensors = self._get_logprobs_tensors(
            sampling_metadata.max_num_logprobs,
            metadata,
            logits,
            target_logits if self.is_processed_logprobs_mode else raw_target_logits,
            bonus_sampler_output.logprobs_tensors.logprobs,
            output_token_ids,
        )
    
    return SamplerOutput(
        sampled_token_ids=output_token_ids,
        logprobs_tensors=logprobs_tensors,
    )

Yes
No
RejectionSampler.forward()
Phase 1: 采样bonus token

sampler(bonus_logits)
bonus_token_ids [B, 1]
Phase 2: 处理target logits
clone(raw_target_logits)
apply_logits_processors()
apply_sampling_constraints()

(temp, top-k, top-p)
processed target_logits
Phase 3: rejection_sample()
output_token_ids [B, max_spec+1]

rejected → -1
max_num_logprobs≠None?
_get_logprobs_tensors()
logprobs = None
SamplerOutput

14.4 _get_logprobs_tensors() 日志概率重构

python 复制代码
def _get_logprobs_tensors(
    self,
    max_num_logprobs: int,
    metadata: SpecDecodeMetadata,
    logits: torch.Tensor,          # 原始全量logits [num_tokens+B, V]
    target_logits: torch.Tensor,   # 处理后的target logits [num_tokens, V]
    bonus_logits: torch.Tensor,    # bonus token的logits
    sampled_token_ids: torch.Tensor,  # [B, max_spec+1]
) -> LogprobsTensors:
    # 计算每个请求的起始token偏移
    cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
    cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
    # cu_num_sampled_tokens[i] = 第i个请求在全局token中的起始索引
    
    # 合并target和bonus的logits
    bonus_logits_indices = metadata.bonus_logits_indices
    target_logits_indices = metadata.target_logits_indices
    final_logits = torch.zeros_like(logits, dtype=torch.float32)
    final_logits[target_logits_indices] = target_logits.to(torch.float32)
    final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
    
    # 计算每个采样token在final_logits中的行索引
    # 注意:包括被拒绝的token(后续在parse_output中过滤)
    logit_start_indices = cu_num_sampled_tokens
    offsets = torch.arange(
        sampled_token_ids.shape[-1],
        device=logit_start_indices.device,
        dtype=logit_start_indices.dtype,
    )
    accepted_logit_indices = (
        logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
    ).flatten()
    accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)
    
    # 替换被拒绝token的id为0(避免gather越界)
    accepted_tokens = sampled_token_ids.clone().flatten()
    accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0
    
    # 提取accepted token对应的logits
    accepted_logits = final_logits[accepted_logit_indices]
    
    # 计算logprobs
    accepted_logprobs = (
        accepted_logits
        if self.is_logits_logprobs_mode
        else self.sampler.compute_logprobs(accepted_logits)
    )
    
    # Gather top-N logprobs
    return self.sampler.gather_logprobs(
        accepted_logprobs,
        max_num_logprobs,
        accepted_tokens.to(torch.int64),
    )

设计要点

  • 被拒绝的token也会计算logit索引(虽然后续会被过滤),避免CPU-GPU同步
  • accepted_tokens[==PLACEHOLDER] = 0:将-1替换为0,确保gather不越界;但logprob值无意义,会在parse_output中过滤

14.5 parse_output() 输出解析

python 复制代码
@staticmethod
def parse_output(
    output_token_ids: torch.Tensor,       # [B, max_spec+1]
    vocab_size: int,
    discard_req_indices: Sequence[int] = (),  # 需要丢弃的请求索引
    logprobs_tensors: LogprobsTensors | None = None,
) -> tuple[list[list[int]], LogprobsLists | None]:
    """解析拒绝采样输出
    
    将PLACEHOLDER_TOKEN_ID替换的行过滤掉,
    保留实际接受的token序列
    """
    output_token_ids_np = output_token_ids.cpu().numpy()
    
    # 创建有效token掩码
    valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
        output_token_ids_np < vocab_size
    )
    # valid_mask[i, j] = True → 第i个请求的第j个token有效
    
    # 处理logprobs
    output_logprobs = None
    if logprobs_tensors is not None:
        cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
        filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
        output_logprobs = filtered_tensors.tolists(cu_num_tokens)
    
    # 丢弃指定请求(如被调度器标记为preempted的)
    if len(discard_req_indices) > 0:
        valid_mask[discard_req_indices] = False
    
    # 提取有效token
    outputs = [
        row[valid_mask[i]].tolist()
        for i, row in enumerate(output_token_ids_np)
    ]
    
    return outputs, output_logprobs

14.6 apply_logits_processors() 投机解码版

python 复制代码
def apply_logits_processors(
    self,
    logits: torch.Tensor,               # [num_tokens, vocab_size]
    sampling_metadata: SamplingMetadata,
    metadata: SpecDecodeMetadata,
) -> torch.Tensor:
    has_penalties = not sampling_metadata.no_penalties
    any_penalties_or_bad_words = (
        sampling_metadata.bad_words_token_ids or has_penalties
    )
    holder = sampling_metadata.thinking_budget_state_holder
    needs_thinking = holder is not None and holder.has_tracked_requests()
    
    output_token_ids = sampling_metadata.output_token_ids
    
    # 如果有惩罚/bad_words/thinking,需要合并output和spec tokens
    if any_penalties_or_bad_words or needs_thinking:
        output_token_ids = self._combine_outputs_with_spec_tokens(
            output_token_ids,
            sampling_metadata.spec_token_ids,
        )
        # 合并原因:惩罚需要考虑draft tokens已生成的部分
    
    # ===== 计算repeat_indices =====
    # repeat_indices: 将请求级参数展平到token级
    repeat_indices: torch.Tensor | None = None
    need_repeat_indices = (
        sampling_metadata.allowed_token_ids_mask is not None
        or has_penalties
        or needs_thinking
    )
    if need_repeat_indices:
        num_requests = len(metadata.num_draft_tokens)
        num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
        original_indices = torch.arange(num_requests, device="cpu")
        # repeat_interleave: 将每个请求索引重复其draft token数次
        # 例如: [0, 1, 2] with [2, 3, 1] → [0, 0, 1, 1, 1, 2]
        repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
        repeat_indices = repeat_indices_cpu.to(device=logits.device, non_blocking=True)
        
        # 应用惩罚
        logits = self.apply_penalties(
            logits, sampling_metadata, metadata, repeat_indices, output_token_ids
        )
        
        # 应用token白名单
        if sampling_metadata.allowed_token_ids_mask is not None:
            token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
            logits.masked_fill_(token_mask, float("-inf"))
    
    # ===== 应用bad_words =====
    if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
        apply_bad_words_with_drafts(
            logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
        )
    
    # ===== 应用MinTokens(spec decode版本)=====
    for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
        if isinstance(processor, MinTokensLogitsProcessor):
            logits = processor.apply_with_spec_decode(
                logits, metadata.num_draft_tokens
            )
    
    # ===== 应用thinking budget =====
    if holder is not None and holder.has_tracked_requests():
        logits = holder.apply_to_logits(
            logits,
            predict_bonus_token=False,
            spec_token_ids=sampling_metadata.spec_token_ids,
        )
    
    return logits

14.7 apply_penalties() 投机解码版

python 复制代码
@staticmethod
def apply_penalties(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    metadata: SpecDecodeMetadata,
    repeat_indices: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    if sampling_metadata.no_penalties:
        return logits
    
    assert sampling_metadata.prompt_token_ids is not None
    
    # 使用repeat_indices将请求级参数展平到token级
    prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
    presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
    frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
    repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]
    
    return apply_all_penalties(
        logits,
        prompt_token_ids,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
        output_token_ids,
    )

14.8 _combine_outputs_with_spec_tokens() 输出合并

python 复制代码
@staticmethod
def _combine_outputs_with_spec_tokens(
    output_token_ids: list[list[int]],
    spec_token_ids: list[list[int]] | None = None,
) -> list[list[int]]:
    """将output tokens与spec tokens合并
    
    用于惩罚计算:每个draft位置需要知道"到目前为止已生成的所有token"
    
    Example:
      output = [[10, 20]]          # 原始输出
      spec   = [[30, 40]]          # 2个draft tokens
      
      result = [[10, 20],          # draft位置0: output + spec[0]
                [10, 20, 30]]      # draft位置1: output + spec[0] + spec[1]
    """
    if spec_token_ids is None:
        return output_token_ids
    
    result = []
    for out, spec in zip(output_token_ids, spec_token_ids):
        if len(spec) == 0:
            continue
        result.append(out)  # 第一个draft位置:只有原始output
        for i in range(len(spec) - 1):
            # 后续位置:追加之前的spec tokens
            result.append([*result[-1], spec[i]])
    return result

output
input
output = [[10, 20]]
spec = [[30, 40]]
result[0] = [10, 20]
result[1] = [10, 20, 30]


第十五章 rejection_sample() 核心拒绝采样函数

15.1 函数签名与输入验证

python 复制代码
def rejection_sample(
    draft_token_ids: torch.Tensor,          # [num_tokens] draft token id
    num_draft_tokens: list[int],            # [batch_size] 每请求draft数
    max_spec_len: int,                      # 最大spec长度
    cu_num_draft_tokens: torch.Tensor,      # [batch_size] 累积draft token偏移
    draft_probs: torch.Tensor | None,       # [num_tokens, vocab_size] draft概率
    target_logits: torch.Tensor,            # [num_tokens, vocab_size] 处理后target logits
    bonus_token_ids: torch.Tensor,          # [batch_size, 1] bonus token id
    sampling_metadata: SamplingMetadata,
    synthetic_mode: bool = False,
    synthetic_conditional_rates: torch.Tensor | None = None,
) -> torch.Tensor:
    """执行拒绝采样,返回 [batch_size, max_spec_len + 1] 的输出token矩阵"""
    
    # ===== 输入验证 =====
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_logits.ndim == 2
    
    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_logits.shape[-1]
    device = target_logits.device
    
    # 连续性检查(Triton kernel要求连续内存)
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_logits.shape == (num_tokens, vocab_size)

15.2 输出缓冲区初始化

python 复制代码
    # 创建输出缓冲区
    output_token_ids = torch.full(
        (batch_size, max_spec_len + 1),
        PLACEHOLDER_TOKEN_ID,  # -1,表示"未填充"
        dtype=torch.int32,     # 与SamplerOutput.sampled_token_ids一致
        device=device,
    )
    # 形状 [batch_size, max_spec_len + 1]
    # "+1" 是给bonus token的位置

15.3 贪心路径: rejection_greedy_sample_kernel

python 复制代码
    # ===== 判断贪心/随机 =====
    if sampling_metadata.all_greedy:
        is_greedy = None  # 全贪心,不需要逐请求判断
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
        # [batch_size] bool,True=贪心,False=随机
    
    # ===== 生成均匀随机数 =====
    uniform_probs: torch.Tensor | None = None
    if synthetic_mode or not sampling_metadata.all_greedy:
        # synthetic模式或非全贪心:需要均匀随机数
        uniform_probs = generate_uniform_probs(
            num_tokens,
            num_draft_tokens,
            sampling_metadata.generators,
            device,
        )
    
    # ===== 贪心请求的拒绝采样 =====
    if not sampling_metadata.all_random:
        target_argmax = target_logits.argmax(dim=-1)  # [num_tokens]
        rejection_greedy_sample_kernel[(batch_size,)](
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
            uniform_probs,
            synthetic_conditional_rates,
            SYNTHETIC_MODE=synthetic_mode,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids  # 全贪心:直接返回

rejection_greedy_sample_kernel 逐行解析

python 复制代码
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
    output_token_ids_ptr,      # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,   # [batch_size]
    draft_token_ids_ptr,       # [num_tokens]
    target_argmax_ptr,         # [num_tokens]
    bonus_token_ids_ptr,       # [batch_size]
    is_greedy_ptr,             # [batch_size] or None
    max_spec_len,
    uniform_probs_ptr,         # [num_tokens] (synthetic mode)
    synthetic_conditional_rates_ptr,  # [num_spec_tokens] (synthetic mode)
    SYNTHETIC_MODE: tl.constexpr,
):
    req_idx = tl.program_id(0)  # 每个program处理一个请求
    
    # 读取is_greedy标记
    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
    if not is_greedy:
        return  # 非贪心请求:跳过(由random kernel处理)
    
    # 读取该请求的draft token范围
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx
    
    # 逐个验证draft token
    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos).to(tl.int32)
            
            if SYNTHETIC_MODE:
                # Synthetic模式:用预设接受率判断
                uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
                rate = tl.load(synthetic_conditional_rates_ptr + pos)
                accepted = uniform_prob < rate
                token_id = draft_token_id if accepted else target_argmax_id
                rejected = not accepted
            else:
                # 标准模式:draft == target_argmax 即接受
                token_id = target_argmax_id
                rejected = draft_token_id != target_argmax_id
            
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                token_id,
            )
    
    # 如果全部接受,追加bonus token
    if not rejected:
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
            bonus_token_id,
        )

贪心模式的拒绝逻辑

  • 贪心采样时,target模型的输出是确定的(argmax)
  • 因此只需检查draft token是否与target的argmax一致
  • 不一致即拒绝,直接使用target的argmax作为recovered token
  • 无需计算概率比率或采样

No
Yes
Yes
No
Yes
No
Accept
Reject
Yes
No
rejection_greedy_sample_kernel
is_greedy?
return (由random kernel处理)
遍历draft tokens
SYNTHETIC_MODE?
uniform_prob < rate?

接受: draft_token

拒绝: target_argmax
draft == target_argmax?
接受draft token
拒绝→用target_argmax

rejected=True
用draft_token
用target_argmax

rejected=True
还有draft?
后续位置保持PLACEHOLDER
追加bonus_token

15.4 随机路径: rejection_random_sample_kernel

python 复制代码
    # ===== 计算target概率分布 =====
    target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
    assert target_probs.is_contiguous()
    
    # ===== 采样recovered tokens =====
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )
    
    # ===== 随机请求的拒绝采样 =====
    assert uniform_probs is not None
    rejection_random_sample_kernel[(batch_size,)](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        synthetic_conditional_rates,
        NO_DRAFT_PROBS=draft_probs is None,
        SYNTHETIC_MODE=synthetic_mode,
    )
    return output_token_ids

rejection_random_sample_kernel 逐行解析

python 复制代码
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,
    cu_num_draft_tokens_ptr,
    draft_token_ids_ptr,
    draft_probs_ptr,              # [num_tokens, vocab_size] or None
    target_probs_ptr,             # [num_tokens, vocab_size]
    bonus_token_ids_ptr,
    recovered_token_ids_ptr,      # [num_tokens]
    uniform_probs_ptr,            # [num_tokens]
    is_greedy_ptr,
    max_spec_len,
    vocab_size,
    synthetic_conditional_rates_ptr,
    NO_DRAFT_PROBS: tl.constexpr,
    SYNTHETIC_MODE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        return  # 贪心请求已在greedy kernel中处理
    
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx
    
    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            
            if SYNTHETIC_MODE:
                rate = tl.load(synthetic_conditional_rates_ptr + pos)
                accepted = uniform_prob < rate
            else:
                if NO_DRAFT_PROBS:
                    draft_prob = 1  # ngram模式:无draft概率,总是接受
                else:
                    # 读取draft概率 q(draft_token)
                    draft_prob = tl.load(
                        draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
                    )
                
                # 读取target概率 p(draft_token)
                target_prob = tl.load(
                    target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
                )
                
                # 接受条件: u < p(x)/q(x)
                # 特殊处理: q(x)=0 → 拒绝(避免NaN)
                accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob
            
            if accepted:
                token_id = draft_token_id
            else:
                rejected = True
                # 使用预采样的recovered token
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                token_id,
            )
    
    if not rejected:
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
            bonus_token_id,
        )

标准拒绝采样的数学验证

接受概率 = P(u < p/q) = p/q(当u~Uniform(0,1)时)

  • p(x)/q(x) ≥ 1 时,接受概率为1(target更倾向该token)
  • p(x)/q(x) < 1 时,按概率接受(target不如draft倾向该token)

被拒绝后从 max(p-q, 0) 采样:

  • 这是 pq 之间的差异分布
  • 保证了最终分布与直接从target采样一致

15.5 uniform_probs生成与float64精度

python 复制代码
def generate_uniform_probs(
    num_tokens: int,
    num_draft_tokens: list[int],
    generators: dict[int, torch.Generator],
    device: torch.device,
) -> torch.Tensor:
    """生成均匀随机数用于拒绝采样"""
    
    # NOTE: 使用float64而非float32
    # 原因: float32时uniform_prob可能恰好为0.0
    # 参考: https://github.com/pytorch/pytorch/issues/16706
    # 这会导致 p/q >= 0 恒真,本应拒绝的token被错误接受
    uniform_probs = torch.rand(
        (num_tokens,),
        dtype=torch.float64,  # 关键: 使用double精度
        device=device,
    )
    
    # 对有seed的请求覆盖随机数
    start_idx = 0
    for req_idx, n in enumerate(num_draft_tokens):
        if n == 0:
            continue  # 无draft token → 不生成随机数(复现性重要)
        end_idx = start_idx + n
        generator = generators.get(req_idx)
        if generator is not None:
            uniform_probs[start_idx:end_idx].uniform_(generator=generator)
        start_idx = end_idx
    
    return uniform_probs

float64的必要性

  • float32的范围: [-3.4e38, 3.4e38],有效数字7位
  • torch.rand() 在float32下可能精确产生0.0(概率约 1/2^24 ≈ 6e-8)
  • uniform_prob = 0.0 时,p/q >= 0 恒真,本应拒绝的token被错误接受
  • float64下精确产生0.0的概率约 1/2^53 ≈ 1.1e-16,可忽略

15.6 sample_recovered_tokens() 恢复token采样

python 复制代码
def sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    cu_num_draft_tokens: torch.Tensor,
    draft_token_ids: torch.Tensor,
    draft_probs: torch.Tensor | None,
    target_probs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    """为每个draft位置预采样recovered token
    
    使用Gumbel-Max技巧从 max(p-q, 0) 分布采样
    """
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]
    
    # 生成指数分布噪声
    q = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    q.exponential_()  # 全局Exp(1)
    
    # 对有seed的请求覆盖
    for i, generator in sampling_metadata.generators.items():
        if num_draft_tokens[i] > 0:
            q[i].exponential_(generator=generator)
    
    # 取倒数: 1/q 用于Gumbel-Max
    inv_q = q.reciprocal()
    
    # 预分配输出
    recovered_token_ids = torch.empty_like(draft_token_ids)
    
    # Triton kernel: 对每个(请求, 位置)计算recovered token
    BLOCK_SIZE = 8192
    sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
        recovered_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        inv_q,
        vocab_size,
        BLOCK_SIZE,
        NO_DRAFT_PROBS=draft_probs is None,
    )
    return recovered_token_ids

数学原理

  • 拒绝后的修正分布为 max(p(x) - q(x), 0)
  • Gumbel-Max技巧: argmax_x(prob(x) * inv_q[x]) 等价于从 prob 分布采样
  • 这里 prob = max(target_prob - draft_prob, 0)
  • 由于只关心argmax,不需要归一化 max(p-q, 0)

为什么每个请求只生成一个 inv_q

  • 同一请求的多个draft位置共享一个recovered分布
  • 但实际上每个位置的 pq 不同(不同位置的概率分布不同)
  • kernel内部按位置使用不同的 target_probs[pos]draft_probs[pos]

15.7 apply_sampling_constraints() 约束应用

python 复制代码
def apply_sampling_constraints(
    logits: torch.Tensor,              # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor, # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """对spec decode的logits应用温度/top-k/top-p约束"""
    
    if sampling_metadata.all_greedy:
        return logits  # 全贪心:无需约束
    
    num_tokens = logits.shape[0]
    
    # 温度缩放:展平请求级→token级
    temperature = expand_batch_to_tokens(
        sampling_metadata.temperature,
        cu_num_draft_tokens,
        num_tokens,
        replace_from=GREEDY_TEMPERATURE,  # 0 → 1(避免除以0)
        replace_to=1,
    )
    # 原地除法
    logits.div_(temperature.unsqueeze(-1))
    
    # top-k展平
    top_k = None
    if sampling_metadata.top_k is not None:
        top_k = expand_batch_to_tokens(
            sampling_metadata.top_k,
            cu_num_draft_tokens,
            num_tokens,
        )
    
    # top-p展平
    top_p = None
    if sampling_metadata.top_p is not None:
        top_p = expand_batch_to_tokens(
            sampling_metadata.top_p,
            cu_num_draft_tokens,
            num_tokens,
        )
    
    # 应用top-k/top-p
    return apply_top_k_top_p(logits, top_k, top_p)

15.8 expand_batch_to_tokens() 批次展平

python 复制代码
def expand_batch_to_tokens(
    x: torch.Tensor,              # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size] 累积token数
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    """将请求级张量展平为token级
    
    Example: x = [a, b, c], cu_num_tokens = [2, 5, 6]
    → expanded = [a, a, b, b, b, c]
    """
    batch_size = x.shape[0]
    expanded_x = x.new_empty(num_tokens)
    expand_kernel[(batch_size,)](
        expanded_x, x, cu_num_tokens,
        replace_from, replace_to,
        MAX_NUM_TOKENS=***  # 避免重编译
    )
    return expanded_x

15.9 expand_kernel Triton内核

python 复制代码
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
    output_ptr,          # [num_tokens]
    input_ptr,           # [batch_size]
    cu_num_tokens_ptr,   # [batch_size]
    replace_from,
    replace_to,
    MAX_NUM_TOKENS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    
    # 计算该请求的token范围
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
    num_tokens = end_idx - start_idx
    
    # 读取源值
    src_val = tl.load(input_ptr + req_idx)
    # 可选替换: replace_from → replace_to
    src_val = tl.where(src_val == replace_from, replace_to, src_val)
    
    # 批量写入
    offset = tl.arange(0, MAX_NUM_TOKENS)
    tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)

replace_from/replace_to 的用途

  • 温度展平时: replace_from=0, replace_to=1 --- 将温度0(贪心)替换为1(不做缩放)
  • 避免除以0错误

第十六章 Triton Top-K/Top-P Kernel超深度解析

16.1 算法背景: Qrita论文

topk_topp_triton.py 基于 Qrita 论文 (https://arxiv.org/abs/2602.01518),核心思想是基于枢轴的截断与选择(Pivot-based Truncation and Selection)

传统方法:

  1. 对每行排序 → O(V log V)
  2. 计算累积概率 → O(V)
  3. 应用top-k/top-p阈值 → O(V)

Qrita方法:

  1. 估计枢轴(pivot)→ 通过三元搜索 O(log V × log V)
  2. 直接应用阈值 → O(V)

关键优化:不需要排序整个词表,只需找到一个阈值logit值,低于此值的token被屏蔽为-inf。

16.2 查找表设计

两个查找表在编译时生成,存储在模块级变量中:

python 复制代码
_NORMAL_CDF_TO_SIGMA_TABLE = [...]  # 200个元素
# 用途: 独立Top-P采样
# 给定百分位数p,查表得到正态分布CDF(p)对应的σ值
# 用于估计top-p阈值的初始猜测

_PERCENTILE_TO_STD_TABLE = [...]   # 200个元素
# 用途: Top-K采样
# 给定百分位数k/V,查表得到标准正态分布的z值
# 用于估计top-k阈值的初始猜测

表结构

  • 每个表200个元素,覆盖百分位数 [0.5%, 100%]
  • _NORMAL_CDF_TO_SIGMA_TABLE: 用于top-p,从百分位映射到σ
  • _PERCENTILE_TO_STD_TABLE: 用于top-k,从百分位映射到z-score

使用方式

python 复制代码
# Top-K: 估计k/V百分位对应的z-score
percentile = int(k / VOCAB_SIZE * 200)  # 映射到[0, 199]
sigma = PERCENTILE_TO_STD_TABLE[percentile]
outlier_pivot = avg_logit + std_logit * sigma  # 初始枢轴估计

16.3 _update_min_larger_stats() 辅助函数

python 复制代码
@triton.jit
def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel):
    """跨tile更新"严格大于pivot的最小值"及其计数
    
    用于处理重复logit值:当多个token有相同的logit值时,
    需要知道有多少个token恰好在pivot值上
    
    合并规则:
    - tile_min < running_min → 替换min和count
    - tile_min == running_min → 累加count
    - tile_min > running_min → 保持running值
    """
    # 在above_mask为True的位置找最小值(其他位置设为sentinel=inf)
    tile_min = tl.min(tl.where(above_mask, data, sentinel))
    
    # 计算tile中等于tile_min的元素数
    tile_eq = above_mask & (tl.abs(data - tile_min) < 1e-9)
    tile_cnt = tl.sum(tile_eq)
    
    # 与running状态合并
    is_new = tile_min < min_larger      # 发现更小的值
    is_same = tl.abs(tile_min - min_larger) < 1e-9  # 值相同
    
    # 更新count
    num_min_larger = tl.where(
        is_new, tile_cnt,              # 新值 → 用tile的count
        num_min_larger + tile_cnt * is_same  # 相同 → 累加
    )
    # 更新min
    min_larger = tl.minimum(min_larger, tile_min)
    
    return min_larger, num_min_larger

为什么需要这个函数

Top-K的核心问题是找到第K大的值。但如果有重复值,情况变得复杂:

复制代码
logits = [5.0, 3.0, 3.0, 3.0, 1.0]  k=3

第3大的值是3.0,但有3个token的值都是3.0。是否全部保留?

  • 保留3个3.0 → 实际保留4个token(5.0 + 3×3.0)
  • 只保留0个3.0 → 实际保留1个token(5.0)
  • 保留部分3.0 → 需要随机选择保留哪几个

_update_min_larger_stats 追踪"严格大于pivot的最小值"和"等于该值的token数",使得可以在最终遮罩阶段正确处理重复值。

16.4 _topk_topp_kernel 整体结构

Yes
No
Yes
No
Yes
No
Yes
No
Yes
No
_topk_topp_kernel
TOPK_ENABLED

and k < V?
Top-K路径
TOPP_ENABLED?
Pass 0: 统计采样块

计算avg/std
估计outlier_pivot

avg + std × σ
Pass 1: 全词表扫描

找max/min + 收集outliers
outliers > k?
三元搜索(在buffer中)

2×pivot同时评估
三元搜索(全词表)

2×pivot同时评估
k_pivot确定
TOPP_ENABLED

and finite > k?
Top-K + Top-P联合路径
Top-K only

final_pivot = k_pivot
Pass 3: 计算softmax和

收集outlier概率
Pass 4: 归一化buffer概率
Pass 5: 三元搜索p_pivot
final_pivot = log(p_pivot × sum + max)
Standalone Top-P路径
无过滤
Pass 0: 统计采样块

计算avg/std
Pass 1: 计算softmax和

收集outlier概率
outlier_sum > p?
三元搜索(在buffer中)
三元搜索(全词表)
final_pivot = log(p_pivot × sum + max_sample)
Pass 6: 应用最终遮罩

16.5 Top-K路径详细解析

Pass 0: 统计采样
python 复制代码
# 从第一个BLOCK_SIZE(8192)的采样块计算均值和标准差
offs = tl.arange(0, BLOCK_SIZE)
mask_n = offs < VOCAB_SIZE
logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf"))

# 排除-inf值(如grammar bitmasks产生的)
finite_mask = (logits_blk0 > -float("inf")) & mask_n
num_finite = tl.sum(finite_mask)
finite_logits = tl.where(finite_mask, logits_blk0, 0.0)

avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0)
sq_avg_logit = tl.where(
    num_finite > 0,
    tl.sum(finite_logits * finite_logits) / num_finite,
    0.0,
)
std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0))
# 使用 E[X²] - E[X]² 而非直接计算var,避免二次遍历

为什么只用第一个block

  • 全词表遍历代价高(V=128K需要16个8K blocks)
  • 第一个block的统计量足以估计整体分布
  • 精度要求不高(只需估计outlier阈值的大致位置)
Pivot估计
python 复制代码
# 从百分位查找表估计z-score
percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32)
percentile = tl.minimum(percentile, 199)
sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile)
# 加负偏移: sigma = sigma - 0.15 * |sigma|
# 这使得初始估计偏高,宁可多收集outliers再缩小
sigma = sigma + tl.abs(sigma) * -0.15
outlier_pivot = avg_logit + std_logit * sigma

-0.15偏移的设计

  • 偏高的估计意味着收集的outliers比实际需要的多
  • 多收集无害(后续三元搜索会精确化)
  • 少收集有害(可能遗漏需要保留的token)
Pass 1: 全词表扫描
python 复制代码
num_finite_total = tl.zeros((), dtype=tl.uint32)
for i in range(0, NUM_TILES):  # 遍历所有8K blocks
    offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_n = offs_n < VOCAB_SIZE
    logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf"))
    
    # 更新全局max/min
    max_logit = tl.maximum(max_logit, tl.max(logits_blk))
    finite_blk = tl.where(logits_blk > -float("inf"), logits_blk, float("inf"))
    min_logit = tl.minimum(min_logit, tl.min(finite_blk))
    num_finite_total += tl.sum(finite_blk_mask & mask_n)
    
    # 收集大于outlier_pivot的值到buffer
    outlier_mask = (logits_blk > outlier_pivot) & mask_n
    cumulative_pos = tl.cast(
        tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32
    )
    num_outliers += tl.sum(outlier_mask)
    write_pos = tl.where(outlier_mask, cumulative_pos, -1)
    tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask)

buffer的作用:收集"可能是top-K"的outlier值到连续内存区域,使后续三元搜索只需遍历少量元素而非全词表。

三元搜索(Ternary Search)
python 复制代码
found_pivot = 0
while found_pivot == 0:
    # 同时评估两个pivot点(1/3和2/3位置)
    k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range
    k_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range
    
    # 遍历buffer,统计大于每个pivot的元素数
    for i in range(0, search_iters):
        logits_blk2 = tl.load(BUFFER_ROW + offs_n, ...)
        k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0)
        k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1)
        
        # 同时追踪"严格大于pivot的最小值"
        min_larger_0, num_min_larger_0 = _update_min_larger_stats(...)
        min_larger_1, num_min_larger_1 = _update_min_larger_stats(...)
    
    # 终止条件: count >= k 且 count - num_at_pivot < k
    # 即pivot恰好将前k个元素与剩余元素分开
    if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k:
        found_pivot = 1  # k_pivot_0 是正确的阈值
    
    # 更新搜索范围
    if k_pivots_num_1 > k: min_range = k_pivot_1
    elif k_pivots_num_0 > k: min_range = k_pivot_0
    if k_pivots_num_0 < k: max_range = k_pivot_0
    elif k_pivots_num_1 < k: max_range = k_pivot_1
    
    num_iters += 1
    if num_iters >= 18 or abs(min_range - max_range) < 1e-9:
        k_pivot = (max_range + min_range) / 2.0
        found_pivot = 1

为什么用三元搜索而非二分

  • 每次迭代评估2个pivot(1/3和2/3)
  • 理论收敛速度: 三元搜索每次缩小2/3范围,二分缩小1/2
  • 但2pivot的评估在一次遍历中完成,实际效率更高
  • 最多18次迭代(约2^18 ≈ 262K的精度,足够float32)

终止条件解析

复制代码
k_pivots_num >= k           → 大于pivot的元素 ≥ k个
k_pivots_num - num_min_larger < k  → 大于pivot但不等于pivot最小值的元素 < k个

这意味着:pivot值将"确定保留的元素"和"可能需要部分保留的元素"分开。

16.6 Top-P路径详细解析

Standalone Top-P(无Top-K时)的流程:

Pass 1: 计算softmax和
python 复制代码
# 使用max_sample避免softmax溢出
max_sample = avg_logit + std_logit * 10.0
sum_exp_logits = 0.0

for i in range(0, NUM_TILES):
    probs_blk = tl.exp(logits_blk - max_sample)
    sum_exp_logits += tl.sum(probs_blk)
Pass 2: 收集outlier概率
python 复制代码
# 使用查找表估计概率阈值
idx = int(p * 200)
sigma = NORMAL_CDF_TO_SIGMA_TABLE[idx]
sigma = sigma + abs(sigma) * -0.25  # 偏保守估计
outlier_prob = exp(outlier_pivot - max_sample) / sum_exp_logits

# 收集概率大于outlier_prob的token
for i in range(0, NUM_TILES):
    probs_blk = exp(logits_blk - max_sample) / sum_exp_logits
    outlier_mask = probs_blk > outlier_prob
    # 存储到buffer
Pass 3: 三元搜索p_pivot
python 复制代码
# 在概率空间搜索p_pivot
# p_pivot满足: sum(prob > p_pivot) ≥ p 且 sum(prob > p_pivot) - min_larger × count < p
while found_pivot == 0:
    p_pivot_0 = (max_range - min_range) * 1/3 + min_range
    p_pivot_1 = (max_range - min_range) * 2/3 + min_range
    
    # 统计大于p_pivot的概率和
    p_pivots_sum_0 = sum(probs × (probs > p_pivot_0))
    p_pivots_sum_1 = sum(probs × (probs > p_pivot_1))
    
    # 终止条件
    if p_pivots_sum >= p and p_pivots_sum - min_larger × count < p:
        found_pivot = 1

16.7 组合Top-K+Top-P路径

当同时启用top-k和top-p时:

  1. 先应用Top-K:找到k_pivot,保留top-K个元素
  2. 再应用Top-P:在top-K的输出上应用top-p过滤
python 复制代码
# Top-K结果已收集在buffer中
# 现在需要在buffer中的元素上应用Top-P

# Pass 3: 计算top-K元素的softmax和
sum_exp_logits = 0.0
for i in range(0, search_iters):
    probs_blk = BUFFER_ROW[i]
    
    # 处理重复logit
    if num_keep < num_duplicate_logit:
        duplicate_mask = abs(probs_blk - duplicate_logit) < 1e-9
        duplicate_count = cumsum(duplicate_mask) + num_kept
        duplicate_keep_mask = duplicate_count <= num_keep
        duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask
        outlier_mask = outlier_mask & ~duplicate_remove_mask
    
    probs_blk = where(outlier_mask, probs_blk, -inf)
    probs_blk = probs_blk - max_logit
    probs_blk = exp(probs_blk)
    sum_exp_logits += sum(probs_blk)

# Pass 4: 归一化并存储
for i in range(0, search_iters):
    probs_blk = BUFFER_ROW[i]
    probs_blk = exp(probs_blk - max_logit)
    probs_blk = probs_blk / sum_exp_logits
    store(BUFFER_ROW + offs_n, probs_blk)

# Pass 5: 搜索p_pivot
# (与standalone top-p相同的三元搜索)

16.8 最终遮罩应用

python 复制代码
# Pass 6: 应用最终遮罩
# 如果final_pivot >= max_logit(或NaN),无token需要屏蔽
if not (final_pivot < max_logit):
    final_pivot = -float("inf")  # 不过滤任何token
elif final_pivot != -float("inf"):
    for i in range(0, NUM_TILES):
        logits_blk = load(LOGITS_ROW + offs_n)
        keep_mask = (logits_blk > final_pivot) & mask_n
        
        # 处理重复logit
        if num_keep < num_duplicate_logit:
            duplicate_mask = abs(logits_blk - duplicate_logit) < 1e-9
            duplicate_count = cumsum(duplicate_mask) + num_kept
            duplicate_keep_mask = duplicate_count <= num_keep
            duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask
            num_kept += sum(duplicate_keep_mask)
            keep_mask = keep_mask & ~duplicate_remove_mask
        
        logits_blk = where(keep_mask, logits_blk, MASK_VALUE)  # -inf
        store(LOGITS_ROW + offs_n, logits_blk)

16.9 重复logit处理

重复logit是Top-K/Top-P实现中的关键边界情况:

问题场景

复制代码
logits = [5.0, 3.0, 3.0, 3.0, 1.0]  k=3
k_pivot = 3.0  num_at_pivot = 3

严格 > k_pivot 的元素:1个(5.0)
== k_pivot 的元素:3个(3.0 × 3)

需要保留:k - 1 = 2个3.0

处理方式

  1. num_kept 追踪已保留的重复值计数
  2. duplicate_keep_mask = (cumsum ≤ num_keep) → 保留前num_keep个
  3. duplicate_remove_mask → 移除其余

为什么用cumsum而非随机选择

  • 确定性更好,便于调试和复现
  • 实际效果等价(相同logit的token概率相同,选哪个无所谓)
  • Triton kernel中随机数生成代价高

16.10 apply_top_k_top_p_triton() Python包装

python 复制代码
def apply_top_k_top_p_triton(
    logits: torch.Tensor,       # [batch_size, vocab_size] 修改in-place
    k: torch.Tensor | None,     # [batch_size]
    p: torch.Tensor | None,     # [batch_size]
    mask_value: float = float("-inf"),
) -> torch.Tensor:
    assert logits.ndim == 2
    assert logits.dtype == torch.float32
    
    batch_size, vocab_size = logits.shape
    topk_enabled = k is not None
    topp_enabled = p is not None
    
    if batch_size == 0 or not (topk_enabled or topp_enabled):
        return logits
    
    # 准备k/p指针(不允许的用dummy)
    if k is not None:
        k_ptr = k.to(torch.int32)  # Triton kernel需要int32
    else:
        k_ptr = logits  # dummy,不会被读取
    
    if p is not None:
        p_ptr = p.to(torch.float32)
    else:
        p_ptr = logits  # dummy
    
    # 计算并行program数
    num_sm = num_compute_units(logits.device.index)
    NUM_PROGRAMS = min(num_sm, batch_size)
    
    # 分配/缓存buffer
    buf_key = (logits.device, logits.dtype, vocab_size)
    buffer = _TRITON_BUFFER_CACHE.get(buf_key)
    if buffer is None or buffer.shape[0] < NUM_PROGRAMS:
        size = min(next_power_of_2(NUM_PROGRAMS), num_sm)
        buffer = logits.new_empty((size, vocab_size))
        _TRITON_BUFFER_CACHE[buf_key] = buffer
    if buffer.shape[0] > NUM_PROGRAMS:
        buffer = buffer[:NUM_PROGRAMS]
    
    # 缓存查找表
    tables = _TRITON_TABLE_CACHE.get(logits.device)
    if tables is None:
        normal_cdf_to_sigma_table = logits.new_tensor(_NORMAL_CDF_TO_SIGMA_TABLE)
        percentile_to_std_table = logits.new_tensor(_PERCENTILE_TO_STD_TABLE)
        _TRITON_TABLE_CACHE[logits.device] = (
            normal_cdf_to_sigma_table, percentile_to_std_table
        )
    else:
        normal_cdf_to_sigma_table, percentile_to_std_table = tables
    
    # 启动kernel
    _topk_topp_kernel[(NUM_PROGRAMS,)](
        logits, buffer,
        percentile_to_std_table, normal_cdf_to_sigma_table,
        k_ptr, p_ptr,
        BATCH_SIZE=batch_size,
        MASK_VALUE=mask_value,
        VOCAB_SIZE=vocab_size,
        BLOCK_SIZE=8192,
        BLOCK_SIZE_TRUNC=4096,
        TOPK_ENABLED=topk_enabled,
        TOPP_ENABLED=topp_enabled,
    )
    
    return logits

BLOCK_SIZE / BLOCK_SIZE_TRUNC 的设计

  • BLOCK_SIZE=8192:全词表遍历时的tile大小
  • BLOCK_SIZE_TRUNC=4096:只在buffer/outlier上遍历时的tile大小(更小,因为outlier数量通常远小于V)

16.11 缓冲区与查找表缓存

python 复制代码
_TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {}
# key: (device,) → value: (normal_cdf_to_sigma_table, percentile_to_std_table)
# 每个设备只需创建一次查找表

_TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {}
# key: (device, dtype, vocab_size) → value: buffer [num_programs, vocab_size]
# 按设备+词表大小缓存,避免每次调用重新分配

16.12 reset_buffer_cache() 缓存清理

python 复制代码
def reset_buffer_cache():
    """清理缓存,释放GPU内存"""
    _TRITON_BUFFER_CACHE.clear()
    _TRITON_TABLE_CACHE.clear()
    torch.accelerator.empty_cache()  # 释放未使用的GPU内存

第十七章 sample_recovered_tokens_kernel Triton内核

17.1 算法原理

拒绝采样被拒绝后,需要从修正分布 max(p(x) - q(x), 0) 采样。使用Gumbel-Max技巧:

复制代码
recovered = argmax_x(prob(x) × inv_q[x])

其中:

  • prob(x) = max(target_prob(x) - draft_prob(x), 0)
  • inv_q[x] = 1 / Exp(1)_x

由于 prob(x) 不需要归一化(argmax与归一化无关),避免了sum运算。

17.2 逐行解析

python 复制代码
@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,          # [num_tokens] 输出
    cu_num_draft_tokens_ptr,       # [batch_size]
    draft_token_ids_ptr,           # [num_tokens]
    draft_probs_ptr,               # [num_tokens, vocab_size] or None
    target_probs_ptr,              # [num_tokens, vocab_size]
    inv_q_ptr,                     # [batch_size, vocab_size]
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    NO_DRAFT_PROBS: tl.constexpr,
):
    req_idx = tl.program_id(0)   # 请求索引
    pos = tl.program_id(1)       # draft位置索引
    
    # 读取该请求的token范围
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft = end_idx - start_idx
    
    # 超出范围的位置直接返回
    if pos >= num_draft:
        return
    
    token_idx = start_idx + pos  # 全局token索引
    
    # ngram模式:排除draft_token本身
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
    
    # 分tile遍历词表
    max_val = float("-inf")
    recovered_id = 0
    for v in range(0, vocab_size, BLOCK_SIZE):
        vocab_offset = v + tl.arange(0, BLOCK_SIZE)
        vocab_mask = vocab_offset < vocab_size
        
        if NO_DRAFT_PROBS:
            # ngram: prob = target_prob(排除draft_token)
            prob = tl.load(
                target_probs_ptr + token_idx * vocab_size + vocab_offset,
                mask=(vocab_mask & (vocab_offset != draft_token_id)),
                other=0.0,
            )
        else:
            # 标准模式: prob = max(target - draft, 0)
            draft_prob = tl.load(
                draft_probs_ptr + token_idx * vocab_size + vocab_offset,
                mask=vocab_mask, other=0.0
            )
            target_prob = tl.load(
                target_probs_ptr + token_idx * vocab_size + vocab_offset,
                mask=vocab_mask, other=0.0
            )
            prob = tl.maximum(target_prob - draft_prob, 0.0)
        
        # 读取Gumbel噪声倒数
        inv_q = tl.load(
            inv_q_ptr + req_idx * vocab_size + vocab_offset,
            mask=vocab_mask, other=0.0
        )
        
        # 计算 score = prob × inv_q
        score = prob * inv_q
        
        # 找本tile的max
        local_max, local_id = tl.max(score, axis=0, return_indices=True)
        
        # 更新全局max
        if local_max > max_val:
            max_val = local_max
            recovered_id = v + local_id
    
    # 写入结果
    tl.store(output_token_ids_ptr + token_idx, recovered_id)

17.3 无draft_probs模式(ngram spec decode)

当使用ngram spec decode时,没有draft概率(draft_probs=None)。此时:

  • 修正分布简化为:prob(x) = target_prob(x)(排除draft_token本身)
  • 因为 max(p - q, 0)q=1 时退化为 max(p - 1, 0),这不对
  • 实际处理:在ngram模式下,接受条件变为 uniform < 1(总是接受),所以不会进入recovered token逻辑
  • 但kernel仍需正确处理:使用 prob = target_prob(排除draft token)

附录E 投机解码端到端时序图

sample_recovered_kernel random_kernel greedy_kernel rejection_sample() Sampler RejectionSampler GPUModelRunner sample_recovered_kernel random_kernel greedy_kernel rejection_sample() Sampler RejectionSampler GPUModelRunner 投机解码步骤 Phase 1: Bonus token Phase 2: Target logits处理 Phase 3: 拒绝采样 alt [全贪心] [有随机请求] Phase 4: Logprobs forward(metadata, draft_probs, target_logits, sampling_metadata) sampler(bonus_logits, ...) bonus_token_ids apply_logits_processors(target_logits) apply_sampling_constraints(target_logits) rejection_sample(draft_ids, target_logits, bonus_ids) rejection_greedy_sample_kernel output_token_ids rejection_greedy_sample_kernel (贪心部分) target_probs = softmax(target_logits) sample_recovered_tokens(target_probs, draft_probs) recovered_token_ids rejection_random_sample_kernel (随机部分) output_token_ids output_token_ids [B, max_spec+1] _get_logprobs_tensors(...) SamplerOutput


附录F Triton Kernel计算流程图

Standalone Top-P 路径
Top-K + Top-P 联合路径
Pass 0: 采样块统计

avg, std, outlier_pivot
Pass 1: 全词表扫描

max, min, outlier收集
Ternary Search (在buffer)

找k_pivot
记录duplicate_logit

num_keep
Pass 3: Top-K元素softmax

sum_exp, outlier概率收集
Pass 4: 归一化buffer概率

probs / sum_exp
Ternary Search (在概率空间)

找p_pivot
更新duplicate_logit

num_keep
Pass 6: 应用final_pivot遮罩

  • 重复logit处理
    Pass 0: 采样块统计
    Pass 1: 计算softmax和

收集outlier概率
Ternary Search (在概率空间)

找p_pivot
SP6
返回过滤后的logits


附录G 拒绝采样数学证明

定理:拒绝采样算法的输出分布等价于直接从target模型采样。

证明

设draft分布为 q(x),target分布为 p(x)

步骤1:接受概率

复制代码
P(accept x) = P(draft generates x) × P(accept x | x generated)
            = q(x) × min(p(x)/q(x), 1)
            = min(p(x), q(x))

步骤2:拒绝后恢复分布

复制代码
P(reject) = 1 - Σ_x min(p(x), q(x))
P(recover y | reject) = max(p(y) - q(y), 0) / Σ_x max(p(x) - q(x), 0)

步骤3:最终输出分布

复制代码
P(output = y) = P(accept y) + P(reject) × P(recover y | reject)

对于被接受的token y:

复制代码
= min(p(y), q(y)) + [1 - Σ_x min(p(x), q(x))] × max(p(y) - q(y), 0) / Z

其中 Z = Σ_x max(p(x) - q(x), 0) = 1 - Σ_x min(p(x), q(x))

因此:

复制代码
P(output = y) = min(p(y), q(y)) + max(p(y) - q(y), 0) = p(y)

证毕。 □

bonus token :当所有K个draft token都被接受时,从target分布额外采样一个token。由于此时没有拒绝位置需要恢复,bonus token直接从 p(x) 采样,保证分布一致性。


附录H 术语表补充

术语 含义
PLACEHOLDER_TOKEN_ID (-1) 拒绝采样输出中被拒绝位置的占位符
MAX_SPEC_LEN (128) 单步最大投机token数
GREEDY_TEMPERATURE (0) 温度为0表示贪心采样
cu_num_draft_tokens 累积draft token偏移,类似CSR格式的行偏移
outlier_pivot 基于统计估计的初始枢轴值
k_pivot Top-K阈值:大于此值的logit保留
p_pivot Top-P阈值:概率大于此值的token保留
duplicate_logit 等于pivot值的logit,需特殊处理部分保留
BLOCK_SIZE (8192) Triton kernel的tile大小
BLOCK_SIZE_TRUNC (4096) 在buffer上的tile大小
Gumbel-Max argmax(prob × inv_exp) = Categorical(prob)
Synthetic Mode 使用预设接受率而非p/q比率的拒绝采样
Conditional Rate 在前面token已接受条件下的接受率
recovered_token 拒绝后从max(p-q,0)采样的修正token
Qrita 论文名,Pivot-based Top-K/Top-P算法
Ternary Search 三分搜索,每次评估2个pivot点
Grammar Mask 语法约束产生的-inf logit掩码

附录R rejection_sample() 完整数据流推演

R.1 场景: 3个请求,不同draft数量

复制代码
请求配置:
  请求0: 2个draft tokens, greedy采样
  请求1: 3个draft tokens, random采样(temp=0.8)
  请求2: 1个draft token, random采样(temp=1.0)

输入数据:
  draft_token_ids = [D00, D01, D10, D11, D12, D20]  # [6个]
  num_draft_tokens = [2, 3, 1]
  cu_num_draft_tokens = [2, 5, 6]  # 累积偏移

  bonus_token_ids = [B0, B1, B2]  # [3, 1]

  target_logits shape: [6, 32000]  # 6个token位置 × 词表大小

R.2 贪心请求(请求0)的拒绝采样

复制代码
rejection_greedy_sample_kernel[program_id=0]:

  is_greedy[0] = True  (temperature=0)
  start_idx = 0
  end_idx = 2  (cu_num_draft_tokens[0])
  num_draft_tokens = 2

  pos=0:
    draft_token_id = D00
    target_argmax_id = argmax(target_logits[0])
    # 贪心拒绝逻辑: D00 == argmax?
    如果 D00 == argmax → token_id = argmax, rejected=False
    如果 D00 != argmax → token_id = argmax, rejected=True

  pos=1 (如果rejected=False):
    draft_token_id = D01
    target_argmax_id = argmax(target_logits[1])
    同样检查

  如果全部接受:
    bonus_token_id = B0
    output[0] = [D00_or_argmax, D01_or_argmax, B0]
  
  如果pos=0就拒绝:
    output[0] = [argmax, -1, -1]
    # -1 = PLACEHOLDER_TOKEN_ID

R.3 随机请求(请求1)的拒绝采样

复制代码
rejection_random_sample_kernel[program_id=1]:

  is_greedy[1] = False  (temperature=0.8)
  start_idx = 2
  end_idx = 5  (cu_num_draft_tokens[1])
  num_draft_tokens = 3

  pos=0:
    draft_token_id = D10
    uniform_prob = U10  # 均匀随机数
    draft_prob = draft_probs[2, D10]  # q(D10)
    target_prob = target_probs[2, D10]  # p(D10)
    
    # 接受条件: q(D10) > 0 and p(D10)/q(D10) >= U10
    accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob
    
    如果接受: token_id = D10
    如果拒绝: token_id = recovered_token_ids[2], rejected=True

  pos=1 (如果rejected=False):
    draft_token_id = D11
    同样计算 p/q vs uniform

  pos=2 (如果rejected=False):
    draft_token_id = D12
    同样计算

  如果全部接受:
    bonus_token_id = B1
    output[1] = [D10, D11, D12, B1]

R.4 完整输出矩阵

复制代码
output_token_ids: [batch_size=3, max_spec_len+1=4]

请求0 (greedy, 2 drafts):
  [accepted_0, accepted_1_or_recovered, B0_or_-1, -1]
  # 注意: max_spec_len=3,但请求0只有2个draft
  # 第4列(max_spec_len位置)由max_spec_len决定
  # 实际: 如果2个draft全接受,bonus在第3列(索引2)

请求1 (random, 3 drafts):
  [accepted_0, accepted_1, accepted_2_or_recovered, B1_or_-1]

请求2 (random, 1 draft):
  [accepted_0_or_recovered, B2_or_-1, -1, -1]

解析后(parse_output)
output_token_ids [3, 4]
请求0: [tok, tok, B0, -1]
请求1: [tok, tok, tok, B1]
请求2: [tok, B2, -1, -1]
请求0: [tok, tok, B0]
请求1: [tok, tok, tok, B1]
请求2: [tok, B2]


附录S Triton Kernel 三元搜索算法详解

S.1 标准三分搜索

复制代码
目标: 找到值x,使得 f(x) 满足特定条件

标准三分搜索:
  while range > ε:
    m1 = lo + (hi - lo) / 3
    m2 = lo + 2*(hi - lo) / 3
    
    if f(m1) < f(m2):
      hi = m2  # 去掉右1/3
    else:
      lo = m1  # 去掉左1/3

S.2 Qrita变体: 同时评估两个pivot

复制代码
Qrita的三元搜索:
  while not found:
    pivot_0 = lo + (hi - lo) / 3   # 1/3位置
    pivot_1 = lo + 2*(hi - lo) / 3 # 2/3位置
    
    count_0 = count(logits > pivot_0)  # 大于pivot_0的元素数
    count_1 = count(logits > pivot_1)  # 大于pivot_1的元素数
    
    # 检查是否命中
    if count_0 >= k and count_0 - dup_count_0 < k:
      found! pivot = pivot_0
    if count_1 >= k and count_1 - dup_count_1 < k:
      found! pivot = pivot_1
    
    # 更新范围
    if count_1 > k: lo = pivot_1  # 右pivot仍太多 → 抬高下限
    elif count_0 > k: lo = pivot_0  # 左pivot太多 → 抬高下限
    if count_0 < k: hi = pivot_0  # 左pivot太少 → 降低上限
    elif count_1 < k: hi = pivot_1  # 右pivot太少 → 降低上限

与标准三分搜索的区别:

  1. 不是找单峰函数的极值,而是找满足条件的阈值
  2. 终止条件是"恰好有K个元素大于等于pivot"
  3. 同时评估2个pivot,可能其中一个直接命中
  4. 最多18次迭代(float32精度足够)

S.3 收敛速度分析

复制代码
每次迭代:
  - 最优: pivot_0或pivot_1直接命中 → 1次
  - 最差: 范围缩小1/3 → log_3(V/logit_range) 次
  - 对于V=128000, range~20:
    log_3(128000/20) ≈ log_3(6400) ≈ 7.8 次
  - 18次上限提供充分余量

附录T expand_kernel Triton内核详解

T.1 累积偏移(Cumulative Offset)模式

vLLM投机解码大量使用"累积偏移"模式,类似CSR稀疏矩阵格式的行偏移:

复制代码
num_draft_tokens = [2, 3, 1]
cu_num_draft_tokens = [2, 5, 6]

含义:
  请求0: token 0-1 (start=0, end=2)
  请求1: token 2-4 (start=2, end=5)  
  请求2: token 5   (start=5, end=6)

T.2 expand_kernel 工作原理

复制代码
输入: x = [a, b, c], cu_num_tokens = [2, 5, 6]
输出: expanded = [a, a, b, b, b, c]

Program 0 (req_idx=0):
  start=0, end=2, num=2
  src_val = a
  store(output[0:2], a)  → output = [a, a, _, _, _, _]

Program 1 (req_idx=1):
  start=2, end=5, num=3
  src_val = b
  store(output[2:5], b)  → output = [a, a, b, b, b, _]

Program 2 (req_idx=2):
  start=5, end=6, num=1
  src_val = c
  store(output[5:6], c)  → output = [a, a, b, b, b, c]

T.3 replace_from/replace_to 参数

复制代码
温度展平时的特殊处理:
  temperature = [0, 0.8, 1.0]  # 请求0温度=0(贪心)
  replace_from = 0
  replace_to = 1

  expanded = [1, 1, 0.8, 0.8, 0.8, 1.0]
  # 贪心请求的温度被替换为1.0
  # 因为后续要做 logits /= temperature
  # temperature=0 会导致除零错误
  # 但贪心请求已经走了argmax路径,不需要温度缩放
  # temperature=1.0 等价于不缩放(logits /= 1.0 = logits)

附录U 查找表数据详解

U.1 _NORMAL_CDF_TO_SIGMA_TABLE

这个表将正态分布的百分位数映射到σ值。用于Standalone Top-P采样。

复制代码
索引0 → σ = 3.656  (对应百分位约0.5%,即p≈0.005)
索引1 → σ = 3.650
...
索引100 → σ ≈ 0.0  (对应百分位50%)
...
索引199 → σ = -3.813 (对应百分位约99.5%)

使用方式:
  p = 0.9  (top-p = 0.9)
  idx = int(0.9 * 200) = 180
  sigma = NORMAL_CDF_TO_SIGMA_TABLE[180] ≈ -1.881
  outlier_pivot = avg_logit + std_logit * sigma
  # sigma为负 → outlier_pivot < avg → 收集更多outlier

U.2 _PERCENTILE_TO_STD_TABLE

这个表将百分位数映射到标准正态分布的z值。用于Top-K采样。

复制代码
索引0 → z = 2.576  (对应百分位约0.5%)
索引100 → z ≈ 0.0  (对应百分位50%)
索引199 → z = -3.813 (对应百分位约99.5%)

使用方式:
  k = 50, V = 32000
  percentile = int(50/32000 * 200) = 0
  sigma = PERCENTILE_TO_STD_TABLE[0] = 2.576
  outlier_pivot = avg_logit + std_logit * 2.576
  # 很高的z值 → outlier_pivot远大于avg → 只收集极端outlier

U.3 偏移修正

python 复制代码
# Top-K偏移: sigma = sigma + |sigma| * -0.15
# 使估计偏保守(收集更多outlier)
# -0.15 * |sigma| 当sigma>0时减少sigma值 → 降低阈值 → 收集更多
# -0.15 * |sigma| 当sigma<0时增加sigma值(更负) → 也降低阈值

# Top-P偏移: sigma = sigma + |sigma| * -0.25
# 更激进的偏移(Top-P对遗漏更敏感)
# -0.25 vs -0.15: Top-P需要更保守的估计

为什么需要偏移

  • 估计基于第一个block的统计量,可能不完全代表全词表
  • 偏保守意味着多收集一些outlier → 后续三元搜索会精确化
  • 偏激进意味着可能遗漏 → 导致top-K/Top-P结果不准确
  • "宁可多收集,不可遗漏"的原则

附录V 拒绝采样与Gumbel-Max的统一视角

V.1 Gumbel-Max等价性

定理 : 对于概率分布 p(x) 和独立 G_x ~ Gumbel(0,1)

复制代码
argmax_x (log p(x) + G_x) ~ Categorical(p)

推论 : 对于 q_x ~ Exp(1)1/q_x ~ Gumbel(0,1)

复制代码
argmax_x (p(x) / q_x) ~ Categorical(p)

这就是 random_sample()sample_recovered_tokens_kernel 的数学基础。

V.2 拒绝采样的Gumbel-Max解释

标准拒绝采样可以理解为Gumbel-Max的变体:

复制代码
1. 对draft token x: 计算 r = p(x)/q(x)
2. 生成 u ~ Uniform(0,1)
3. 如果 u < r: 接受x
4. 否则: 从max(p-q,0)采样recovered token

等价于:
1. 对draft token x: 设定 Gumbel增量 g_x = -log(u) 当 u < r
2. 如果 g_x < -log(1-r): 接受
3. 否则: 从修正分布采样

但实际实现不用Gumbel表示,因为:
- p/q比较更直观
- Uniform随机数更易生成
- Triton kernel中Uniform比Gumbel更高效

V.3 sample_recovered_tokens 的Gumbel-Max实现

复制代码
prob(x) = max(target_prob(x) - draft_prob(x), 0)  # 修正分布
inv_q[x] = 1 / Exp(1)  # Gumbel噪声倒数

recovered = argmax_x (prob(x) * inv_q[x])

等价于:
  score(x) = log(prob(x)) + Gumbel(x)  (因为log(1/q) = Gumbel)
  但prob可能为0 → log(0) = -inf → 不影响argmax
  
  实际使用 prob * inv_q 而非 log(prob) + Gumbel:
  因为乘法避免了log(0)的NaN问题
  argmax不变: argmax(a*b) = argmax(log(a)+log(b)) 当a,b>0

附录W 投机解码中Logits处理器的特殊行为

W.1 MinTokens在投机解码中的展开

复制代码
常规采样: 每个请求1行logits
  logits [B, V]
  MinTokens: 对每行屏蔽stop tokens

投机解码: 每个请求N行logits (N=draft tokens数)
  logits [num_tokens, V]
  MinTokens: 对前remaining行屏蔽stop tokens

示例:
  请求0: min_tokens=5, current_output=3, num_draft=2
  remaining = 5 - 3 = 2
  n_mask = min(2, 2) = 2
  
  该请求的2个draft位置都需要屏蔽stop tokens:
  rows [0, 1] × stop_tokens → index_put_(rows, stop_ids, -inf)
  
  请求1: min_tokens=10, current_output=9, num_draft=3
  remaining = 10 - 9 = 1
  n_mask = min(1, 3) = 1
  
  只有第1个draft位置需要屏蔽:
  row [2] × stop_tokens → index_put_(row, stop_ids, -inf)

W.2 LogitBias在投机解码中的展开

复制代码
LogitBias不需要特殊处理:
  biases[req_idx] = {token_id: bias_value}

在投机解码中:
  repeat_indices将请求级参数展平到token级
  bias_tensor[repeat_indices] → 每个draft位置使用相同偏置

原因: LogitBias是全局的token偏好,不依赖于当前位置
  不管在哪个draft位置,对特定token的偏置值相同

W.3 ThinkingBudget在投机解码中的双次调用

复制代码
投机解码中logits处理器被调用两次:

调用1: bonus token logits (predict_bonus_token=True)
  只处理1行per request (bonus位置)
  如果force_index指向bonus → 在此调用中强制
  
调用2: target logits (predict_bonus_token=False)
  处理num_draft_tokens行per request
  如果force_index指向draft位置 → 在此调用中强制

设计原因:
  bonus token和draft token的logits是分开计算的
  需要在各自的logits张量上独立应用处理器
  避免在不同形状的张量上做广播

ThinkingBudget RejectionSampler ThinkingBudget RejectionSampler Phase 1: Bonus token force_index指向bonus? → 强制 force_index指向draft? → 跳过 Phase 2: Target logits force_index指向draft位置 → 强制 逐行设置mask和force_token_ids apply_to_logits(bonus_logits, predict_bonus_token=True) apply_to_logits(target_logits, predict_bonus_token=False)


附录X2 Triton Kernel Memory Access Pattern 分析

X2.1 全局内存访问模式

复制代码
_topk_topp_kernel 的内存访问:

输入: LOGITS [BATCH_SIZE, VOCAB_SIZE]
       每行VOCAB_SIZE=32000~128000个float32元素
       每行大小: 128KB~512KB

Pass 0 (采样块):
  读取: LOGITS_ROW[0:BLOCK_SIZE]  # 8192个float = 32KB
  写入: 无
  访问模式: 连续读取,高效的内存合并

Pass 1 (全词表扫描):
  读取: LOGITS_ROW[NUM_TILES × BLOCK_SIZE]  # 全词表
  写入: BUFFER_ROW[0:num_outliers]  # 稀疏写入
  访问模式: 连续读取,稀疏写入(outlier位置不确定)

三元搜索:
  读取: BUFFER_ROW[0:search_range]  # 只读outlier区域
  写入: 无(只更新搜索范围变量)
  访问模式: 连续读取BUFFER,效率高

Pass 6 (最终遮罩):
  读取: LOGITS_ROW[NUM_TILES × BLOCK_SIZE]  # 全词表
  写入: LOGITS_ROW[NUM_TILES × BLOCK_SIZE]  # 原地修改
  访问模式: 连续读写,最高效

X2.2 共享内存使用

复制代码
Triton kernel的共享内存使用:
  - BLOCK_SIZE=8192: 每个program处理1行
  - 中间变量存储在寄存器中(非共享内存)
  - BUFFER是全局内存(非共享内存)
  - 不使用tl.load到shared memory的模式

原因:
  - 每个program独立处理1行,不需要program间通信
  - BUFFER需要跨pass持久化,必须用全局内存
  - 共享内存大小有限(48KB~96KB),无法存8192个float

X2.3 Warp Divergence分析

复制代码
Kernel中的分支:
  1. TOPK_ENABLED / TOPP_ENABLED: 编译时常量 → 无divergence
  2. k < VOCAB_SIZE: 每行可能不同 → 可能divergence
  3. p < 1.0: 每行可能不同 → 可能divergence
  4. found_pivot: 循环条件 → 可能divergence
  5. duplicate处理: 条件分支 → 可能divergence

缓解策略:
  - 每个program处理1行 → 同一warp内的program处理不同行
  - 不同行的分支独立 → 不影响效率
  - 三元搜索的迭代次数差异 → 少数program可能更早退出
  - 总体: divergence影响较小

附录Y2 rejection_greedy_sample_kernel 完整执行追踪

Y2.1 示例: 5个请求的贪心拒绝采样

复制代码
配置:
  batch_size = 5
  num_draft_tokens = [3, 2, 4, 1, 2]
  cu_num_draft_tokens = [3, 5, 9, 10, 12]
  max_spec_len = 4

  draft_token_ids = [D00, D01, D02, D10, D11, D20, D21, D22, D23, D30, D40, D41]
  target_argmax  = [T00, T01, T02, T10, T11, T20, T21, T22, T23, T30, T40, T41]
  bonus_token_ids = [B0, B1, B2, B3, B4]

  is_greedy = [True, True, True, True, True]  # 全贪心

Kernel启动: 5个program(每个请求1个)

Program 0 (请求0, 3个draft):
  is_greedy[0] = True → 继续执行
  start_idx = 0, end_idx = 3, num_draft = 3
  
  pos=0: D00 vs T00
    如果 D00 == T00: token=D00(或T00), rejected=False
    如果 D00 != T00: token=T00, rejected=True, 后续位置保持-1
  
  假设 D00==T00, D01==T01, D02!=T02:
    output[0] = [T00, T01, T02, -1]
    # pos=0,1接受draft,pos=2拒绝→用target的argmax

Program 1 (请求1, 2个draft):
  start_idx = 3, end_idx = 5, num_draft = 2
  
  假设 D10==T10, D11==T11:
    rejected=False → 追加bonus
    output[1] = [T10, T11, B1, -1]

Program 2 (请求2, 4个draft):
  start_idx = 5, end_idx = 9, num_draft = 4
  
  假设全部接受:
    output[2] = [T20, T21, T22, T23, B2]
    # 注意: max_spec_len=4, 所以最多4个draft+1个bonus=5列
  
  但num_draft=4 == max_spec_len:
    output[2] = [T20, T21, T22, T23, B2]
    # 5列: 4个draft位置 + 1个bonus位置

Program 3 (请求3, 1个draft):
  start_idx = 9, end_idx = 10, num_draft = 1
  
  假设 D30==T30:
    output[3] = [T30, B3, -1, -1, -1]

Program 4 (请求4, 2个draft):
  start_idx = 10, end_idx = 12, num_draft = 2
  
  假设 D40!=T40:
    output[4] = [T40, -1, -1, -1, -1]
    # 第1个draft就拒绝,后续全为-1

最终输出:
  output_token_ids = [
    [T00, T01, T02, -1, -1],     # 请求0: 2接受+1恢复
    [T10, T11, B1,  -1, -1],     # 请求1: 全接受+bonus
    [T20, T21, T22, T23, B2],    # 请求2: 全接受+bonus
    [T30, B3,  -1,  -1,  -1],    # 请求3: 1接受+bonus
    [T40, -1,  -1,  -1,  -1],    # 请求4: 1恢复(0接受)
  ]

parse_output过滤-1后:
  [
    [T00, T01, T02],       # 3个token
    [T10, T11, B1],        # 3个token
    [T20, T21, T22, T23, B2],  # 5个token
    [T30, B3],             # 2个token
    [T40],                 # 1个token
  ]

Y2.2 Synthetic模式执行路径

复制代码
synthetic_conditional_rates = [0.8, 0.7, 0.6, 0.5]
# 位置0接受率0.8, 位置1接受率0.7, 位置2接受率0.6, 位置3接受率0.5

Program 0 (请求0, 3个draft, greedy+synthetic):
  pos=0:
    uniform_prob = U00
    rate = 0.8
    accepted = (U00 < 0.8)
    如果接受: token_id = D00 (draft token)
    如果拒绝: token_id = T00 (target argmax), rejected=True

  pos=1 (如果未拒绝):
    uniform_prob = U01
    rate = 0.7
    accepted = (U01 < 0.7)
    ...

  注意: synthetic模式不比较p/q
  而是使用预设的接受率
  这是因为ngram spec decode没有draft概率
  无法计算p/q比率

附录Z2 generate_uniform_probs float64精度必要性实验

Z2.1 问题复现

python 复制代码
# float32下的问题:
import torch
torch.manual_seed(42)

# 生成100万个float32均匀随机数
u32 = torch.rand(1_000_000, dtype=torch.float32)
count_zero_32 = (u32 == 0.0).sum().item()
# 结果: count_zero_32 ≈ 0~10 (取决于种子)
# 概率: P(u == 0.0) ≈ 2^{-24} ≈ 6e-8
# 期望: 1e6 × 6e-8 ≈ 0.06 → 通常为0或1

# 但在拒绝采样中:
# accepted = target_prob / draft_prob >= uniform_prob
# 如果 uniform_prob == 0.0:
#   任何 target_prob/draft_prob >= 0 都为True
#   → 本应拒绝的token被错误接受
#   → 采样分布偏离target分布

# float64下:
u64 = torch.rand(1_000_000, dtype=torch.float64)
count_zero_64 = (u64 == 0.0).sum().item()
# 结果: count_zero_64 = 0
# 概率: P(u == 0.0) ≈ 2^{-53} ≈ 1.1e-16
# 期望: 1e6 × 1.1e-16 ≈ 1.1e-10 → 实际上不可能为0

Z2.2 精度对采样分布的影响

复制代码
假设: draft_prob(x) = 0.1, target_prob(x) = 0.05
正确行为: p/q = 0.05/0.1 = 0.5 → 接受概率50%

float32 + uniform_prob恰好=0.0:
  0.5 >= 0.0 → True → 接受
  本应50%接受 → 实际100%接受 → 分布偏差

float64 + uniform_prob≈0.0但不精确=0:
  0.5 >= ~1e-16 → True → 接受
  但1e-16的概率几乎不可能 → 对分布影响可忽略

结论: float64消除了"精确零"问题,保证了采样的数学正确性

附录AA2 sample_recovered_tokens_kernel Block分块策略

AA2.1 为什么使用BLOCK_SIZE=8192

复制代码
vocab_size = 128000 (典型LLM词表大小)

如果BLOCK_SIZE太小 (如256):
  需要的迭代次数: 128000 / 256 = 500次
  每次迭代: 1次全局内存读取 + 1次比较 + 1次乘法
  总体: 500次内存读取 → 延迟高

如果BLOCK_SIZE太大 (如65536):
  超过Triton的最大BLOCK_SIZE限制
  寄存器压力过大 → 可能spill到local memory → 性能下降

BLOCK_SIZE=8192:
  迭代次数: 128000 / 8192 ≈ 16次
  每次迭代处理8192个元素
  Triton可高效处理这个大小
  寄存器使用合理

trade-off:
  更大的BLOCK → 更少的迭代 → 更少的kernel启动开销
  但寄存器压力增加 → 可能降低occupancy
  8192是实测的经验最优值

AA2.2 tile内max操作的实现

python 复制代码
# 在每个tile内找到score的最大值
score = prob * inv_q  # [BLOCK_SIZE] 向量

# tl.max(score, axis=0, return_indices=True)
# 返回: (最大值, 最大值的索引)
# axis=0: 在BLOCK_SIZE维度上归约
# return_indices=True: 同时返回索引

local_max, local_id = tl.max(score, axis=0, return_indices=True)

# 跨tile更新全局max
if local_max > max_val:
    max_val = local_max
    recovered_id = v + local_id  # v是tile的起始偏移

为什么不需要归一化prob:

复制代码
prob = max(target_prob - draft_prob, 0)

prob可能有很多0值 → 这些位置不影响argmax
prob不需要归一化 → argmax不关心绝对值,只关心相对大小

数学证明:
  argmax_x(prob[x] * inv_q[x])
  = argmax_x(log(prob[x]) + log(inv_q[x]))  (当prob>0时)
  = argmax_x(log(prob[x]) + Gumbel[x])       (因为inv_q ~ 1/Exp(1) ~ Gumbel)

  不需要prob归一化:
  如果prob[x] *= C (常数),则:
    argmax_x(C*prob[x] * inv_q[x])
    = argmax_x(log(C) + log(prob[x]) + Gumbel[x])
    = argmax_x(log(prob[x]) + Gumbel[x])  (log(C)是常数,不影响argmax)
    = 原始argmax

  所以: argmax(max(p-q,0) * inv_q) = argmax(归一化(max(p-q,0)) * inv_q)
  无需归一化步骤,节省了一次sum操作

附录BB2 SpecDecodeMetadata 关键字段说明

RejectionSampler接收的 SpecDecodeMetadata 包含以下关键字段:

字段 形状 含义
draft_token_ids [num_tokens] draft模型生成的token ids
num_draft_tokens list[int] 每个请求的draft token数量
max_spec_len int 最大spec长度(≤128)
cu_num_draft_tokens [batch_size] 累积draft token偏移
cu_num_sampled_tokens [batch_size] 累积采样token偏移(含bonus)
bonus_logits_indices [batch_size] bonus token在logits中的行索引
target_logits_indices [num_tokens] target token在logits中的行索引

logits的展平布局:

复制代码
logits shape: [num_tokens + batch_size, vocab_size]

对于 num_draft_tokens = [2, 3, 1]:
  num_tokens = 2 + 3 + 1 = 6
  total_rows = 6 + 3 = 9

  行0-1: 请求0的draft位置0-1
  行2-4: 请求1的draft位置0-2
  行5:   请求2的draft位置0
  行6:   请求0的bonus位置
  行7:   请求1的bonus位置
  行8:   请求2的bonus位置

  target_logits_indices = [0, 1, 2, 3, 4, 5]
  bonus_logits_indices = [6, 7, 8]

索引
target: [0,1,2,3,4,5]
bonus: [6,7,8]
logits [9, V]
行0: 请求0 draft-0
行1: 请求0 draft-1
行2: 请求1 draft-0
行3: 请求1 draft-1
行4: 请求1 draft-2
行5: 请求2 draft-0
行6: 请求0 bonus
行7: 请求1 bonus
行8: 请求2 bonus


附录CC2 拒绝采样与标准采样的等价性完整证明

CC2.1 单步拒绝采样

设draft分布为q(x),target分布为p(x)。

接受概率:

复制代码
P(accept x|draft generates x) = min(p(x)/q(x), 1)

p(x) ≥ q(x) 时: 接受概率=1(target比draft更倾向x)

p(x) < q(x) 时: 接受概率=p(x)/q(x)(按概率接受)

联合概率(draft生成x且被接受):

复制代码
P(accept x) = q(x) × min(p(x)/q(x), 1) = min(p(x), q(x))

拒绝概率:

复制代码
P(reject) = 1 - Σ_x min(p(x), q(x)) = Σ_x max(p(x)-q(x), 0)

(利用恒等式: min(a,b) + max(a-b,0) = a 当a≥b; min(a,b) + 0 = b 当a<b)

CC2.2 恢复采样

拒绝后,从修正分布采样:

复制代码
P(recover y|reject) = max(p(y)-q(y), 0) / Σ_x max(p(x)-q(x), 0)

CC2.3 最终输出分布

复制代码
P(output = y) = P(accept y) + P(reject) × P(recover y|reject)

= min(p(y), q(y)) + [Σ_x max(p(x)-q(x),0)] × [max(p(y)-q(y),0) / Σ_x max(p(x)-q(x),0)]

= min(p(y), q(y)) + max(p(y)-q(y), 0)

Case 1: p(y) ≥ q(y)
  = q(y) + (p(y)-q(y)) = p(y) ✓

Case 2: p(y) < q(y)
  = p(y) + 0 = p(y) ✓

因此: P(output = y) = p(y) 对所有y成立。 □

CC2.4 多步拒绝采样

对于K个draft token,按顺序验证:

复制代码
P(accept all K) = Π_{i=1}^{K} min(p_i/q_i, 1)
  = Π_{i=1}^{K} min(p(x_i)/q(x_i), 1)

P(reject at position j) = [Π_{i=1}^{j-1} min(p_i/q_i, 1)] × [1 - min(p_j/q_j, 1)]

由于每步独立(draft token独立生成),乘积分解成立。
被拒绝后只采一个recovered token,后续draft token全部丢弃。
最终输出 = (accepted tokens up to j-1) + (recovered token at j)

数学保证: 对于位置j的最终token,其分布等价于从p采样。
对于位置j+1及之后,重新开始新的投机步骤。
相关推荐
不剪发的Tony老师20 小时前
Databasus:一个免费开源的数据库备份管理平台
数据库
一拳一个娘娘腔20 小时前
【SRC漏洞挖掘系列】第09期:XXE与反序列化 —— 当XML和Java开始“吃”代码
xml·java·安全·web安全·github
努力成为AK大王20 小时前
从前端到数据库:一个 Web 项目的完整通信链路解析
前端·数据库·ajax·jdbc
粉嘟小飞妹儿20 小时前
Java Switch与Break用法详解
java·开发语言
艾莉丝努力练剑20 小时前
【QT】常用控件(三)Qt布局管理器(网格/表单/间隔器)
java·linux·运维·服务器·开发语言·网络·qt
骑士雄师20 小时前
python 的列表和java中的集合有什么区别
java·windows·python
csjane107920 小时前
Redis 分布式锁实战
java·redis
それども20 小时前
redis 集群操作进阶 - hashtag
数据库·redis·缓存
尋找記憶的魚20 小时前
基于langchain4j的ai编程助手项目(完整篇)
java·人工智能·spring boot·langchain·ai编程