vLLM V1 Sample 模块超深度架构分析 --- Part 3: 投机采样拒绝器与Triton Kernel
分析范围 :
rejection_sampler.py+ops/topk_topp_triton.py
目录
- [第十四章 RejectionSampler 投机解码拒绝采样器](#第十四章 RejectionSampler 投机解码拒绝采样器)
- [14.1 设计背景与算法原理](#14.1 设计背景与算法原理)
- [14.2 类结构与初始化](#14.2 类结构与初始化)
- [14.3 forward() 主流程深度解析](#14.3 forward() 主流程深度解析)
- [14.4 _get_logprobs_tensors() 日志概率重构](#14.4 _get_logprobs_tensors() 日志概率重构)
- [14.5 parse_output() 输出解析](#14.5 parse_output() 输出解析)
- [14.6 apply_logits_processors() 投机解码版](#14.6 apply_logits_processors() 投机解码版)
- [14.7 apply_penalties() 投机解码版](#14.7 apply_penalties() 投机解码版)
- [14.8 _combine_outputs_with_spec_tokens() 输出合并](#14.8 _combine_outputs_with_spec_tokens() 输出合并)
- [第十五章 rejection_sample() 核心拒绝采样函数](#第十五章 rejection_sample() 核心拒绝采样函数)
- [15.1 函数签名与输入验证](#15.1 函数签名与输入验证)
- [15.2 输出缓冲区初始化](#15.2 输出缓冲区初始化)
- [15.3 贪心路径: rejection_greedy_sample_kernel](#15.3 贪心路径: rejection_greedy_sample_kernel)
- [15.4 随机路径: rejection_random_sample_kernel](#15.4 随机路径: rejection_random_sample_kernel)
- [15.5 uniform_probs生成与float64精度](#15.5 uniform_probs生成与float64精度)
- [15.6 sample_recovered_tokens() 恢复token采样](#15.6 sample_recovered_tokens() 恢复token采样)
- [15.7 apply_sampling_constraints() 约束应用](#15.7 apply_sampling_constraints() 约束应用)
- [15.8 expand_batch_to_tokens() 批次展平](#15.8 expand_batch_to_tokens() 批次展平)
- [15.9 expand_kernel Triton内核](#15.9 expand_kernel Triton内核)
- [第十六章 Triton Top-K/Top-P Kernel超深度解析](#第十六章 Triton Top-K/Top-P Kernel超深度解析)
- [16.1 算法背景: Qrita论文](#16.1 算法背景: Qrita论文)
- [16.2 查找表设计](#16.2 查找表设计)
- [16.3 _update_min_larger_stats() 辅助函数](#16.3 _update_min_larger_stats() 辅助函数)
- [16.4 _topk_topp_kernel 整体结构](#16.4 _topk_topp_kernel 整体结构)
- [16.5 Top-K路径详细解析](#16.5 Top-K路径详细解析)
- [16.6 Top-P路径详细解析](#16.6 Top-P路径详细解析)
- [16.7 组合Top-K+Top-P路径](#16.7 组合Top-K+Top-P路径)
- [16.8 最终遮罩应用](#16.8 最终遮罩应用)
- [16.9 重复logit处理](#16.9 重复logit处理)
- [16.10 apply_top_k_top_p_triton() Python包装](#16.10 apply_top_k_top_p_triton() Python包装)
- [16.11 缓冲区与查找表缓存](#16.11 缓冲区与查找表缓存)
- [16.12 reset_buffer_cache() 缓存清理](#16.12 reset_buffer_cache() 缓存清理)
- [第十七章 sample_recovered_tokens_kernel Triton内核](#第十七章 sample_recovered_tokens_kernel Triton内核)
- [17.1 算法原理](#17.1 算法原理)
- [17.2 逐行解析](#17.2 逐行解析)
- [17.3 无draft_probs模式(ngram spec decode)](#17.3 无draft_probs模式(ngram spec decode))
- [附录E 投机解码端到端时序图](#附录E 投机解码端到端时序图)
- [附录F Triton Kernel计算流程图](#附录F Triton Kernel计算流程图)
- [附录G 拒绝采样数学证明](#附录G 拒绝采样数学证明)
- [附录H 术语表补充](#附录H 术语表补充)
第十四章 RejectionSampler 投机解码拒绝采样器
14.1 设计背景与算法原理
投机解码(Speculative Decoding) 是一种加速LLM推理的技术:
- 用小型draft模型快速生成K个候选token
- 用大型target模型并行验证这些token
- 接受与target分布一致的token,拒绝不一致的
- 对被拒绝的位置,从修正后的分布中重新采样
拒绝采样(Rejection Sampling) 的数学原理来自论文 Fast Inference from Transformers via Speculative Decoding:
对于draft概率 q(x) 和target概率 p(x):
- 接受条件 :
u < p(x)/q(x),其中u ~ Uniform(0,1) - 接受时:使用draft token
- 拒绝时 :从
max(p - q, 0)分布中采样恢复token - 全部接受时:额外采样一个bonus token
术语对照表:
| 术语 | 含义 |
|---|---|
| accepted tokens | 通过 p/q ≥ u 验证的draft token |
| recovered tokens | 拒绝后从 max(p-q,0) 采样的修正token |
| bonus tokens | 全部接受后额外采样的token |
| output tokens | accepted + recovered + bonus |
Synthetic模式 :一种特殊的拒绝采样变体,使用预设的接受率而非 p/q 比率。用于ngram spec decode等不提供draft概率的场景。
Yes
No
Yes
No
Draft模型生成
K个候选token
Target模型并行推理
验证K+1个位置
对比draft和target概率
p(x)/q(x) ≥ u?
接受draft token
继续检查下一个
拒绝→从max(p-q,0)采样
recovered token
还有更多draft?
全部接受→采样bonus token
output = accepted + recovered
output = accepted + bonus
14.2 类结构与初始化
python
class RejectionSampler(nn.Module):
def __init__(
self,
sampler: Sampler, # 内部使用的标准采样器
spec_config: SpeculativeConfig | None, # 投机解码配置
device: torch.device | None,
):
super().__init__()
self.sampler = sampler
# 确定logprobs模式
logprobs_mode = self.sampler.logprobs_mode
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
# Synthetic模式:预设接受率
self.synthetic_conditional_rates: torch.Tensor | None = None
if spec_config is not None and spec_config.rejection_sample_method == "synthetic":
assert spec_config.synthetic_acceptance_rates is not None
# 将无条件接受率转为条件接受率
self.synthetic_conditional_rates = torch.tensor(
unconditional_to_conditional_rates(
spec_config.synthetic_acceptance_rates
),
dtype=torch.float32,
device=device,
)
self.synthetic_mode = self.synthetic_conditional_rates is not None
无条件→条件接受率转换:
设 α_1, α_2, ..., α_K 为无条件接受率(每个位置独立接受的概率),则条件接受率为:
β_1 = α_1
β_2 = α_2 / (1 - α_1) # 在第1个已接受的前提下
β_3 = α_3 / ((1 - α_1)(1 - α_2)) # 在前2个已接受的前提下
...
这确保了联合接受概率与无条件设定一致。
RejectionSampler
+sampler: Sampler
+is_processed_logprobs_mode: bool
+is_logits_logprobs_mode: bool
+synthetic_conditional_rates: Tensor|None
+synthetic_mode: bool
+forward(metadata, draft_probs, logits, sampling_metadata) : SamplerOutput
+_get_logprobs_tensors(max_num, metadata, logits, target_logits, bonus_logits, sampled) : LogprobsTensors
+parse_output(output_token_ids, vocab_size, discard_req_indices, logprobs_tensors) : tuple
+apply_logits_processors(logits, sampling_metadata, metadata) : Tensor
+apply_penalties(logits, sampling_metadata, metadata, repeat_indices, output_token_ids) : Tensor
+_combine_outputs_with_spec_tokens(output, spec) : list
Sampler
+forward(logits, metadata) : SamplerOutput
14.3 forward() 主流程深度解析
python
def forward(
self,
metadata: SpecDecodeMetadata, # 投机解码元数据
draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] draft概率
logits: torch.Tensor, # [num_tokens + batch_size, vocab_size] target logits
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
逐行解析:
python
assert metadata.max_spec_len <= MAX_SPEC_LEN # 最大128个spec tokens
# ===== Phase 1: 采样bonus token =====
bonus_logits_indices = metadata.bonus_logits_indices
target_logits_indices = metadata.target_logits_indices
# NOTE: PyTorch索引创建新张量,与原logits存储分离
# 因此对bonus_logits的原地操作不影响原始logits
assert logits is not None
bonus_logits = logits[bonus_logits_indices] # [batch_size, vocab_size]
# 使用标准Sampler采样bonus token
# 特殊设置:
# - max_num_logprobs=-1: 特殊模式(spec decode需要所有logits用于后续计算)
# - predict_bonus_token=True: 标记这是bonus token
# - logprobs_mode_override: 使用processed_logits或raw_logits模式
# 因为bonus logits需要用于后续计算accepted token的logprobs
bonus_sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=replace(
sampling_metadata,
max_num_logprobs=-1,
),
predict_bonus_token=True,
logprobs_mode_override=(
"processed_logits" if self.is_processed_logprobs_mode else "raw_logits"
),
)
bonus_token_ids = bonus_sampler_output.sampled_token_ids # [batch_size, 1]
为什么bonus token需要override logprobs_mode:
- 在spec decode中,bonus token的logits需要用于计算accepted token的logprobs
- 如果用
raw_logprobs模式,计算的是log(softmax(raw_logits)),但accepted token的logprobs应该基于target logits(经过处理后的) - 因此强制使用
processed_logits或raw_logits模式,保存处理后的logits供后续使用
python
# ===== Phase 2: 处理target logits =====
raw_target_logits = logits[target_logits_indices] # [num_tokens, vocab_size]
raw_target_logits = raw_target_logits.to(torch.float32)
target_logits = raw_target_logits
if not self.is_processed_logprobs_mode:
# 非processed模式:需要clone保留原始logits
# 因为apply_logits_processors会原地修改
target_logits = target_logits.clone()
# 应用logits处理器(包含penalties、bad_words、thinking budget等)
target_logits = self.apply_logits_processors(
target_logits, sampling_metadata, metadata
)
# 应用采样约束(温度、top-k、top-p)
target_logits = apply_sampling_constraints(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
)
python
# ===== Phase 3: 拒绝采样 =====
output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_logits,
bonus_token_ids,
sampling_metadata,
synthetic_mode=self.synthetic_mode,
synthetic_conditional_rates=self.synthetic_conditional_rates,
)
# output_token_ids: [batch_size, max_spec_len + 1]
# 被拒绝的位置填入PLACEHOLDER_TOKEN_ID (-1)
# ===== Phase 4: 计算logprobs =====
logprobs_tensors = None
if sampling_metadata.max_num_logprobs is not None:
logprobs_tensors = self._get_logprobs_tensors(
sampling_metadata.max_num_logprobs,
metadata,
logits,
target_logits if self.is_processed_logprobs_mode else raw_target_logits,
bonus_sampler_output.logprobs_tensors.logprobs,
output_token_ids,
)
return SamplerOutput(
sampled_token_ids=output_token_ids,
logprobs_tensors=logprobs_tensors,
)
Yes
No
RejectionSampler.forward()
Phase 1: 采样bonus token
sampler(bonus_logits)
bonus_token_ids [B, 1]
Phase 2: 处理target logits
clone(raw_target_logits)
apply_logits_processors()
apply_sampling_constraints()
(temp, top-k, top-p)
processed target_logits
Phase 3: rejection_sample()
output_token_ids [B, max_spec+1]
rejected → -1
max_num_logprobs≠None?
_get_logprobs_tensors()
logprobs = None
SamplerOutput
14.4 _get_logprobs_tensors() 日志概率重构
python
def _get_logprobs_tensors(
self,
max_num_logprobs: int,
metadata: SpecDecodeMetadata,
logits: torch.Tensor, # 原始全量logits [num_tokens+B, V]
target_logits: torch.Tensor, # 处理后的target logits [num_tokens, V]
bonus_logits: torch.Tensor, # bonus token的logits
sampled_token_ids: torch.Tensor, # [B, max_spec+1]
) -> LogprobsTensors:
# 计算每个请求的起始token偏移
cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
# cu_num_sampled_tokens[i] = 第i个请求在全局token中的起始索引
# 合并target和bonus的logits
bonus_logits_indices = metadata.bonus_logits_indices
target_logits_indices = metadata.target_logits_indices
final_logits = torch.zeros_like(logits, dtype=torch.float32)
final_logits[target_logits_indices] = target_logits.to(torch.float32)
final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
# 计算每个采样token在final_logits中的行索引
# 注意:包括被拒绝的token(后续在parse_output中过滤)
logit_start_indices = cu_num_sampled_tokens
offsets = torch.arange(
sampled_token_ids.shape[-1],
device=logit_start_indices.device,
dtype=logit_start_indices.dtype,
)
accepted_logit_indices = (
logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
).flatten()
accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)
# 替换被拒绝token的id为0(避免gather越界)
accepted_tokens = sampled_token_ids.clone().flatten()
accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0
# 提取accepted token对应的logits
accepted_logits = final_logits[accepted_logit_indices]
# 计算logprobs
accepted_logprobs = (
accepted_logits
if self.is_logits_logprobs_mode
else self.sampler.compute_logprobs(accepted_logits)
)
# Gather top-N logprobs
return self.sampler.gather_logprobs(
accepted_logprobs,
max_num_logprobs,
accepted_tokens.to(torch.int64),
)
设计要点:
- 被拒绝的token也会计算logit索引(虽然后续会被过滤),避免CPU-GPU同步
accepted_tokens[==PLACEHOLDER] = 0:将-1替换为0,确保gather不越界;但logprob值无意义,会在parse_output中过滤
14.5 parse_output() 输出解析
python
@staticmethod
def parse_output(
output_token_ids: torch.Tensor, # [B, max_spec+1]
vocab_size: int,
discard_req_indices: Sequence[int] = (), # 需要丢弃的请求索引
logprobs_tensors: LogprobsTensors | None = None,
) -> tuple[list[list[int]], LogprobsLists | None]:
"""解析拒绝采样输出
将PLACEHOLDER_TOKEN_ID替换的行过滤掉,
保留实际接受的token序列
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# 创建有效token掩码
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
# valid_mask[i, j] = True → 第i个请求的第j个token有效
# 处理logprobs
output_logprobs = None
if logprobs_tensors is not None:
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
output_logprobs = filtered_tensors.tolists(cu_num_tokens)
# 丢弃指定请求(如被调度器标记为preempted的)
if len(discard_req_indices) > 0:
valid_mask[discard_req_indices] = False
# 提取有效token
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
]
return outputs, output_logprobs
14.6 apply_logits_processors() 投机解码版
python
def apply_logits_processors(
self,
logits: torch.Tensor, # [num_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
) -> torch.Tensor:
has_penalties = not sampling_metadata.no_penalties
any_penalties_or_bad_words = (
sampling_metadata.bad_words_token_ids or has_penalties
)
holder = sampling_metadata.thinking_budget_state_holder
needs_thinking = holder is not None and holder.has_tracked_requests()
output_token_ids = sampling_metadata.output_token_ids
# 如果有惩罚/bad_words/thinking,需要合并output和spec tokens
if any_penalties_or_bad_words or needs_thinking:
output_token_ids = self._combine_outputs_with_spec_tokens(
output_token_ids,
sampling_metadata.spec_token_ids,
)
# 合并原因:惩罚需要考虑draft tokens已生成的部分
# ===== 计算repeat_indices =====
# repeat_indices: 将请求级参数展平到token级
repeat_indices: torch.Tensor | None = None
need_repeat_indices = (
sampling_metadata.allowed_token_ids_mask is not None
or has_penalties
or needs_thinking
)
if need_repeat_indices:
num_requests = len(metadata.num_draft_tokens)
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
original_indices = torch.arange(num_requests, device="cpu")
# repeat_interleave: 将每个请求索引重复其draft token数次
# 例如: [0, 1, 2] with [2, 3, 1] → [0, 0, 1, 1, 1, 2]
repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
repeat_indices = repeat_indices_cpu.to(device=logits.device, non_blocking=True)
# 应用惩罚
logits = self.apply_penalties(
logits, sampling_metadata, metadata, repeat_indices, output_token_ids
)
# 应用token白名单
if sampling_metadata.allowed_token_ids_mask is not None:
token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
logits.masked_fill_(token_mask, float("-inf"))
# ===== 应用bad_words =====
if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
apply_bad_words_with_drafts(
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
)
# ===== 应用MinTokens(spec decode版本)=====
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
if isinstance(processor, MinTokensLogitsProcessor):
logits = processor.apply_with_spec_decode(
logits, metadata.num_draft_tokens
)
# ===== 应用thinking budget =====
if holder is not None and holder.has_tracked_requests():
logits = holder.apply_to_logits(
logits,
predict_bonus_token=False,
spec_token_ids=sampling_metadata.spec_token_ids,
)
return logits
14.7 apply_penalties() 投机解码版
python
@staticmethod
def apply_penalties(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
repeat_indices: torch.Tensor,
output_token_ids: list[list[int]],
) -> torch.Tensor:
if sampling_metadata.no_penalties:
return logits
assert sampling_metadata.prompt_token_ids is not None
# 使用repeat_indices将请求级参数展平到token级
prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]
return apply_all_penalties(
logits,
prompt_token_ids,
presence_penalties,
frequency_penalties,
repetition_penalties,
output_token_ids,
)
14.8 _combine_outputs_with_spec_tokens() 输出合并
python
@staticmethod
def _combine_outputs_with_spec_tokens(
output_token_ids: list[list[int]],
spec_token_ids: list[list[int]] | None = None,
) -> list[list[int]]:
"""将output tokens与spec tokens合并
用于惩罚计算:每个draft位置需要知道"到目前为止已生成的所有token"
Example:
output = [[10, 20]] # 原始输出
spec = [[30, 40]] # 2个draft tokens
result = [[10, 20], # draft位置0: output + spec[0]
[10, 20, 30]] # draft位置1: output + spec[0] + spec[1]
"""
if spec_token_ids is None:
return output_token_ids
result = []
for out, spec in zip(output_token_ids, spec_token_ids):
if len(spec) == 0:
continue
result.append(out) # 第一个draft位置:只有原始output
for i in range(len(spec) - 1):
# 后续位置:追加之前的spec tokens
result.append([*result[-1], spec[i]])
return result
output
input
output = [[10, 20]]
spec = [[30, 40]]
result[0] = [10, 20]
result[1] = [10, 20, 30]
第十五章 rejection_sample() 核心拒绝采样函数
15.1 函数签名与输入验证
python
def rejection_sample(
draft_token_ids: torch.Tensor, # [num_tokens] draft token id
num_draft_tokens: list[int], # [batch_size] 每请求draft数
max_spec_len: int, # 最大spec长度
cu_num_draft_tokens: torch.Tensor, # [batch_size] 累积draft token偏移
draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] draft概率
target_logits: torch.Tensor, # [num_tokens, vocab_size] 处理后target logits
bonus_token_ids: torch.Tensor, # [batch_size, 1] bonus token id
sampling_metadata: SamplingMetadata,
synthetic_mode: bool = False,
synthetic_conditional_rates: torch.Tensor | None = None,
) -> torch.Tensor:
"""执行拒绝采样,返回 [batch_size, max_spec_len + 1] 的输出token矩阵"""
# ===== 输入验证 =====
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_logits.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_logits.shape[-1]
device = target_logits.device
# 连续性检查(Triton kernel要求连续内存)
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_logits.shape == (num_tokens, vocab_size)
15.2 输出缓冲区初始化
python
# 创建输出缓冲区
output_token_ids = torch.full(
(batch_size, max_spec_len + 1),
PLACEHOLDER_TOKEN_ID, # -1,表示"未填充"
dtype=torch.int32, # 与SamplerOutput.sampled_token_ids一致
device=device,
)
# 形状 [batch_size, max_spec_len + 1]
# "+1" 是给bonus token的位置
15.3 贪心路径: rejection_greedy_sample_kernel
python
# ===== 判断贪心/随机 =====
if sampling_metadata.all_greedy:
is_greedy = None # 全贪心,不需要逐请求判断
else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
# [batch_size] bool,True=贪心,False=随机
# ===== 生成均匀随机数 =====
uniform_probs: torch.Tensor | None = None
if synthetic_mode or not sampling_metadata.all_greedy:
# synthetic模式或非全贪心:需要均匀随机数
uniform_probs = generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# ===== 贪心请求的拒绝采样 =====
if not sampling_metadata.all_random:
target_argmax = target_logits.argmax(dim=-1) # [num_tokens]
rejection_greedy_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
uniform_probs,
synthetic_conditional_rates,
SYNTHETIC_MODE=synthetic_mode,
)
if sampling_metadata.all_greedy:
return output_token_ids # 全贪心:直接返回
rejection_greedy_sample_kernel 逐行解析:
python
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
uniform_probs_ptr, # [num_tokens] (synthetic mode)
synthetic_conditional_rates_ptr, # [num_spec_tokens] (synthetic mode)
SYNTHETIC_MODE: tl.constexpr,
):
req_idx = tl.program_id(0) # 每个program处理一个请求
# 读取is_greedy标记
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
if not is_greedy:
return # 非贪心请求:跳过(由random kernel处理)
# 读取该请求的draft token范围
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# 逐个验证draft token
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos).to(tl.int32)
if SYNTHETIC_MODE:
# Synthetic模式:用预设接受率判断
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
rate = tl.load(synthetic_conditional_rates_ptr + pos)
accepted = uniform_prob < rate
token_id = draft_token_id if accepted else target_argmax_id
rejected = not accepted
else:
# 标准模式:draft == target_argmax 即接受
token_id = target_argmax_id
rejected = draft_token_id != target_argmax_id
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id,
)
# 如果全部接受,追加bonus token
if not rejected:
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)
贪心模式的拒绝逻辑:
- 贪心采样时,target模型的输出是确定的(argmax)
- 因此只需检查draft token是否与target的argmax一致
- 不一致即拒绝,直接使用target的argmax作为recovered token
- 无需计算概率比率或采样
No
Yes
Yes
No
Yes
No
Accept
Reject
Yes
No
rejection_greedy_sample_kernel
is_greedy?
return (由random kernel处理)
遍历draft tokens
SYNTHETIC_MODE?
uniform_prob < rate?
接受: draft_token
拒绝: target_argmax
draft == target_argmax?
接受draft token
拒绝→用target_argmax
rejected=True
用draft_token
用target_argmax
rejected=True
还有draft?
后续位置保持PLACEHOLDER
追加bonus_token
15.4 随机路径: rejection_random_sample_kernel
python
# ===== 计算target概率分布 =====
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous()
# ===== 采样recovered tokens =====
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
# ===== 随机请求的拒绝采样 =====
assert uniform_probs is not None
rejection_random_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
synthetic_conditional_rates,
NO_DRAFT_PROBS=draft_probs is None,
SYNTHETIC_MODE=synthetic_mode,
)
return output_token_ids
rejection_random_sample_kernel 逐行解析:
python
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr,
cu_num_draft_tokens_ptr,
draft_token_ids_ptr,
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr,
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr,
max_spec_len,
vocab_size,
synthetic_conditional_rates_ptr,
NO_DRAFT_PROBS: tl.constexpr,
SYNTHETIC_MODE: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
return # 贪心请求已在greedy kernel中处理
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
if SYNTHETIC_MODE:
rate = tl.load(synthetic_conditional_rates_ptr + pos)
accepted = uniform_prob < rate
else:
if NO_DRAFT_PROBS:
draft_prob = 1 # ngram模式:无draft概率,总是接受
else:
# 读取draft概率 q(draft_token)
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
# 读取target概率 p(draft_token)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
# 接受条件: u < p(x)/q(x)
# 特殊处理: q(x)=0 → 拒绝(避免NaN)
accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob
if accepted:
token_id = draft_token_id
else:
rejected = True
# 使用预采样的recovered token
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id,
)
if not rejected:
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)
标准拒绝采样的数学验证:
接受概率 = P(u < p/q) = p/q(当u~Uniform(0,1)时)
- 当
p(x)/q(x) ≥ 1时,接受概率为1(target更倾向该token) - 当
p(x)/q(x) < 1时,按概率接受(target不如draft倾向该token)
被拒绝后从 max(p-q, 0) 采样:
- 这是
p和q之间的差异分布 - 保证了最终分布与直接从target采样一致
15.5 uniform_probs生成与float64精度
python
def generate_uniform_probs(
num_tokens: int,
num_draft_tokens: list[int],
generators: dict[int, torch.Generator],
device: torch.device,
) -> torch.Tensor:
"""生成均匀随机数用于拒绝采样"""
# NOTE: 使用float64而非float32
# 原因: float32时uniform_prob可能恰好为0.0
# 参考: https://github.com/pytorch/pytorch/issues/16706
# 这会导致 p/q >= 0 恒真,本应拒绝的token被错误接受
uniform_probs = torch.rand(
(num_tokens,),
dtype=torch.float64, # 关键: 使用double精度
device=device,
)
# 对有seed的请求覆盖随机数
start_idx = 0
for req_idx, n in enumerate(num_draft_tokens):
if n == 0:
continue # 无draft token → 不生成随机数(复现性重要)
end_idx = start_idx + n
generator = generators.get(req_idx)
if generator is not None:
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
start_idx = end_idx
return uniform_probs
float64的必要性:
- float32的范围:
[-3.4e38, 3.4e38],有效数字7位 torch.rand()在float32下可能精确产生0.0(概率约1/2^24≈ 6e-8)- 当
uniform_prob = 0.0时,p/q >= 0恒真,本应拒绝的token被错误接受 - float64下精确产生0.0的概率约
1/2^53≈ 1.1e-16,可忽略
15.6 sample_recovered_tokens() 恢复token采样
python
def sample_recovered_tokens(
max_spec_len: int,
num_draft_tokens: list[int],
cu_num_draft_tokens: torch.Tensor,
draft_token_ids: torch.Tensor,
draft_probs: torch.Tensor | None,
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
"""为每个draft位置预采样recovered token
使用Gumbel-Max技巧从 max(p-q, 0) 分布采样
"""
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
# 生成指数分布噪声
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
device=device,
)
q.exponential_() # 全局Exp(1)
# 对有seed的请求覆盖
for i, generator in sampling_metadata.generators.items():
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
# 取倒数: 1/q 用于Gumbel-Max
inv_q = q.reciprocal()
# 预分配输出
recovered_token_ids = torch.empty_like(draft_token_ids)
# Triton kernel: 对每个(请求, 位置)计算recovered token
BLOCK_SIZE = 8192
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
inv_q,
vocab_size,
BLOCK_SIZE,
NO_DRAFT_PROBS=draft_probs is None,
)
return recovered_token_ids
数学原理:
- 拒绝后的修正分布为
max(p(x) - q(x), 0) - Gumbel-Max技巧:
argmax_x(prob(x) * inv_q[x])等价于从prob分布采样 - 这里
prob = max(target_prob - draft_prob, 0) - 由于只关心argmax,不需要归一化
max(p-q, 0)
为什么每个请求只生成一个 inv_q:
- 同一请求的多个draft位置共享一个recovered分布
- 但实际上每个位置的
p和q不同(不同位置的概率分布不同) - kernel内部按位置使用不同的
target_probs[pos]和draft_probs[pos]
15.7 apply_sampling_constraints() 约束应用
python
def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size]
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""对spec decode的logits应用温度/top-k/top-p约束"""
if sampling_metadata.all_greedy:
return logits # 全贪心:无需约束
num_tokens = logits.shape[0]
# 温度缩放:展平请求级→token级
temperature = expand_batch_to_tokens(
sampling_metadata.temperature,
cu_num_draft_tokens,
num_tokens,
replace_from=GREEDY_TEMPERATURE, # 0 → 1(避免除以0)
replace_to=1,
)
# 原地除法
logits.div_(temperature.unsqueeze(-1))
# top-k展平
top_k = None
if sampling_metadata.top_k is not None:
top_k = expand_batch_to_tokens(
sampling_metadata.top_k,
cu_num_draft_tokens,
num_tokens,
)
# top-p展平
top_p = None
if sampling_metadata.top_p is not None:
top_p = expand_batch_to_tokens(
sampling_metadata.top_p,
cu_num_draft_tokens,
num_tokens,
)
# 应用top-k/top-p
return apply_top_k_top_p(logits, top_k, top_p)
15.8 expand_batch_to_tokens() 批次展平
python
def expand_batch_to_tokens(
x: torch.Tensor, # [batch_size]
cu_num_tokens: torch.Tensor, # [batch_size] 累积token数
num_tokens: int,
replace_from: int = 0,
replace_to: int = 0,
) -> torch.Tensor:
"""将请求级张量展平为token级
Example: x = [a, b, c], cu_num_tokens = [2, 5, 6]
→ expanded = [a, a, b, b, b, c]
"""
batch_size = x.shape[0]
expanded_x = x.new_empty(num_tokens)
expand_kernel[(batch_size,)](
expanded_x, x, cu_num_tokens,
replace_from, replace_to,
MAX_NUM_TOKENS=*** # 避免重编译
)
return expanded_x
15.9 expand_kernel Triton内核
python
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
MAX_NUM_TOKENS: tl.constexpr,
):
req_idx = tl.program_id(0)
# 计算该请求的token范围
start_idx = 0 if req_idx == 0 else tl.load(cu_num_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
num_tokens = end_idx - start_idx
# 读取源值
src_val = tl.load(input_ptr + req_idx)
# 可选替换: replace_from → replace_to
src_val = tl.where(src_val == replace_from, replace_to, src_val)
# 批量写入
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
replace_from/replace_to 的用途:
- 温度展平时:
replace_from=0, replace_to=1--- 将温度0(贪心)替换为1(不做缩放) - 避免除以0错误
第十六章 Triton Top-K/Top-P Kernel超深度解析
16.1 算法背景: Qrita论文
topk_topp_triton.py 基于 Qrita 论文 (https://arxiv.org/abs/2602.01518),核心思想是基于枢轴的截断与选择(Pivot-based Truncation and Selection)。
传统方法:
- 对每行排序 → O(V log V)
- 计算累积概率 → O(V)
- 应用top-k/top-p阈值 → O(V)
Qrita方法:
- 估计枢轴(pivot)→ 通过三元搜索 O(log V × log V)
- 直接应用阈值 → O(V)
关键优化:不需要排序整个词表,只需找到一个阈值logit值,低于此值的token被屏蔽为-inf。
16.2 查找表设计
两个查找表在编译时生成,存储在模块级变量中:
python
_NORMAL_CDF_TO_SIGMA_TABLE = [...] # 200个元素
# 用途: 独立Top-P采样
# 给定百分位数p,查表得到正态分布CDF(p)对应的σ值
# 用于估计top-p阈值的初始猜测
_PERCENTILE_TO_STD_TABLE = [...] # 200个元素
# 用途: Top-K采样
# 给定百分位数k/V,查表得到标准正态分布的z值
# 用于估计top-k阈值的初始猜测
表结构:
- 每个表200个元素,覆盖百分位数 [0.5%, 100%]
_NORMAL_CDF_TO_SIGMA_TABLE: 用于top-p,从百分位映射到σ_PERCENTILE_TO_STD_TABLE: 用于top-k,从百分位映射到z-score
使用方式:
python
# Top-K: 估计k/V百分位对应的z-score
percentile = int(k / VOCAB_SIZE * 200) # 映射到[0, 199]
sigma = PERCENTILE_TO_STD_TABLE[percentile]
outlier_pivot = avg_logit + std_logit * sigma # 初始枢轴估计
16.3 _update_min_larger_stats() 辅助函数
python
@triton.jit
def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel):
"""跨tile更新"严格大于pivot的最小值"及其计数
用于处理重复logit值:当多个token有相同的logit值时,
需要知道有多少个token恰好在pivot值上
合并规则:
- tile_min < running_min → 替换min和count
- tile_min == running_min → 累加count
- tile_min > running_min → 保持running值
"""
# 在above_mask为True的位置找最小值(其他位置设为sentinel=inf)
tile_min = tl.min(tl.where(above_mask, data, sentinel))
# 计算tile中等于tile_min的元素数
tile_eq = above_mask & (tl.abs(data - tile_min) < 1e-9)
tile_cnt = tl.sum(tile_eq)
# 与running状态合并
is_new = tile_min < min_larger # 发现更小的值
is_same = tl.abs(tile_min - min_larger) < 1e-9 # 值相同
# 更新count
num_min_larger = tl.where(
is_new, tile_cnt, # 新值 → 用tile的count
num_min_larger + tile_cnt * is_same # 相同 → 累加
)
# 更新min
min_larger = tl.minimum(min_larger, tile_min)
return min_larger, num_min_larger
为什么需要这个函数:
Top-K的核心问题是找到第K大的值。但如果有重复值,情况变得复杂:
logits = [5.0, 3.0, 3.0, 3.0, 1.0] k=3
第3大的值是3.0,但有3个token的值都是3.0。是否全部保留?
- 保留3个3.0 → 实际保留4个token(5.0 + 3×3.0)
- 只保留0个3.0 → 实际保留1个token(5.0)
- 保留部分3.0 → 需要随机选择保留哪几个
_update_min_larger_stats 追踪"严格大于pivot的最小值"和"等于该值的token数",使得可以在最终遮罩阶段正确处理重复值。
16.4 _topk_topp_kernel 整体结构
Yes
No
Yes
No
Yes
No
Yes
No
Yes
No
_topk_topp_kernel
TOPK_ENABLED
and k < V?
Top-K路径
TOPP_ENABLED?
Pass 0: 统计采样块
计算avg/std
估计outlier_pivot
avg + std × σ
Pass 1: 全词表扫描
找max/min + 收集outliers
outliers > k?
三元搜索(在buffer中)
2×pivot同时评估
三元搜索(全词表)
2×pivot同时评估
k_pivot确定
TOPP_ENABLED
and finite > k?
Top-K + Top-P联合路径
Top-K only
final_pivot = k_pivot
Pass 3: 计算softmax和
收集outlier概率
Pass 4: 归一化buffer概率
Pass 5: 三元搜索p_pivot
final_pivot = log(p_pivot × sum + max)
Standalone Top-P路径
无过滤
Pass 0: 统计采样块
计算avg/std
Pass 1: 计算softmax和
收集outlier概率
outlier_sum > p?
三元搜索(在buffer中)
三元搜索(全词表)
final_pivot = log(p_pivot × sum + max_sample)
Pass 6: 应用最终遮罩
16.5 Top-K路径详细解析
Pass 0: 统计采样
python
# 从第一个BLOCK_SIZE(8192)的采样块计算均值和标准差
offs = tl.arange(0, BLOCK_SIZE)
mask_n = offs < VOCAB_SIZE
logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf"))
# 排除-inf值(如grammar bitmasks产生的)
finite_mask = (logits_blk0 > -float("inf")) & mask_n
num_finite = tl.sum(finite_mask)
finite_logits = tl.where(finite_mask, logits_blk0, 0.0)
avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0)
sq_avg_logit = tl.where(
num_finite > 0,
tl.sum(finite_logits * finite_logits) / num_finite,
0.0,
)
std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0))
# 使用 E[X²] - E[X]² 而非直接计算var,避免二次遍历
为什么只用第一个block:
- 全词表遍历代价高(V=128K需要16个8K blocks)
- 第一个block的统计量足以估计整体分布
- 精度要求不高(只需估计outlier阈值的大致位置)
Pivot估计
python
# 从百分位查找表估计z-score
percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32)
percentile = tl.minimum(percentile, 199)
sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile)
# 加负偏移: sigma = sigma - 0.15 * |sigma|
# 这使得初始估计偏高,宁可多收集outliers再缩小
sigma = sigma + tl.abs(sigma) * -0.15
outlier_pivot = avg_logit + std_logit * sigma
-0.15偏移的设计:
- 偏高的估计意味着收集的outliers比实际需要的多
- 多收集无害(后续三元搜索会精确化)
- 少收集有害(可能遗漏需要保留的token)
Pass 1: 全词表扫描
python
num_finite_total = tl.zeros((), dtype=tl.uint32)
for i in range(0, NUM_TILES): # 遍历所有8K blocks
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask_n = offs_n < VOCAB_SIZE
logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf"))
# 更新全局max/min
max_logit = tl.maximum(max_logit, tl.max(logits_blk))
finite_blk = tl.where(logits_blk > -float("inf"), logits_blk, float("inf"))
min_logit = tl.minimum(min_logit, tl.min(finite_blk))
num_finite_total += tl.sum(finite_blk_mask & mask_n)
# 收集大于outlier_pivot的值到buffer
outlier_mask = (logits_blk > outlier_pivot) & mask_n
cumulative_pos = tl.cast(
tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32
)
num_outliers += tl.sum(outlier_mask)
write_pos = tl.where(outlier_mask, cumulative_pos, -1)
tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask)
buffer的作用:收集"可能是top-K"的outlier值到连续内存区域,使后续三元搜索只需遍历少量元素而非全词表。
三元搜索(Ternary Search)
python
found_pivot = 0
while found_pivot == 0:
# 同时评估两个pivot点(1/3和2/3位置)
k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range
k_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range
# 遍历buffer,统计大于每个pivot的元素数
for i in range(0, search_iters):
logits_blk2 = tl.load(BUFFER_ROW + offs_n, ...)
k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0)
k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1)
# 同时追踪"严格大于pivot的最小值"
min_larger_0, num_min_larger_0 = _update_min_larger_stats(...)
min_larger_1, num_min_larger_1 = _update_min_larger_stats(...)
# 终止条件: count >= k 且 count - num_at_pivot < k
# 即pivot恰好将前k个元素与剩余元素分开
if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k:
found_pivot = 1 # k_pivot_0 是正确的阈值
# 更新搜索范围
if k_pivots_num_1 > k: min_range = k_pivot_1
elif k_pivots_num_0 > k: min_range = k_pivot_0
if k_pivots_num_0 < k: max_range = k_pivot_0
elif k_pivots_num_1 < k: max_range = k_pivot_1
num_iters += 1
if num_iters >= 18 or abs(min_range - max_range) < 1e-9:
k_pivot = (max_range + min_range) / 2.0
found_pivot = 1
为什么用三元搜索而非二分:
- 每次迭代评估2个pivot(1/3和2/3)
- 理论收敛速度: 三元搜索每次缩小2/3范围,二分缩小1/2
- 但2pivot的评估在一次遍历中完成,实际效率更高
- 最多18次迭代(约2^18 ≈ 262K的精度,足够float32)
终止条件解析:
k_pivots_num >= k → 大于pivot的元素 ≥ k个
k_pivots_num - num_min_larger < k → 大于pivot但不等于pivot最小值的元素 < k个
这意味着:pivot值将"确定保留的元素"和"可能需要部分保留的元素"分开。
16.6 Top-P路径详细解析
Standalone Top-P(无Top-K时)的流程:
Pass 1: 计算softmax和
python
# 使用max_sample避免softmax溢出
max_sample = avg_logit + std_logit * 10.0
sum_exp_logits = 0.0
for i in range(0, NUM_TILES):
probs_blk = tl.exp(logits_blk - max_sample)
sum_exp_logits += tl.sum(probs_blk)
Pass 2: 收集outlier概率
python
# 使用查找表估计概率阈值
idx = int(p * 200)
sigma = NORMAL_CDF_TO_SIGMA_TABLE[idx]
sigma = sigma + abs(sigma) * -0.25 # 偏保守估计
outlier_prob = exp(outlier_pivot - max_sample) / sum_exp_logits
# 收集概率大于outlier_prob的token
for i in range(0, NUM_TILES):
probs_blk = exp(logits_blk - max_sample) / sum_exp_logits
outlier_mask = probs_blk > outlier_prob
# 存储到buffer
Pass 3: 三元搜索p_pivot
python
# 在概率空间搜索p_pivot
# p_pivot满足: sum(prob > p_pivot) ≥ p 且 sum(prob > p_pivot) - min_larger × count < p
while found_pivot == 0:
p_pivot_0 = (max_range - min_range) * 1/3 + min_range
p_pivot_1 = (max_range - min_range) * 2/3 + min_range
# 统计大于p_pivot的概率和
p_pivots_sum_0 = sum(probs × (probs > p_pivot_0))
p_pivots_sum_1 = sum(probs × (probs > p_pivot_1))
# 终止条件
if p_pivots_sum >= p and p_pivots_sum - min_larger × count < p:
found_pivot = 1
16.7 组合Top-K+Top-P路径
当同时启用top-k和top-p时:
- 先应用Top-K:找到k_pivot,保留top-K个元素
- 再应用Top-P:在top-K的输出上应用top-p过滤
python
# Top-K结果已收集在buffer中
# 现在需要在buffer中的元素上应用Top-P
# Pass 3: 计算top-K元素的softmax和
sum_exp_logits = 0.0
for i in range(0, search_iters):
probs_blk = BUFFER_ROW[i]
# 处理重复logit
if num_keep < num_duplicate_logit:
duplicate_mask = abs(probs_blk - duplicate_logit) < 1e-9
duplicate_count = cumsum(duplicate_mask) + num_kept
duplicate_keep_mask = duplicate_count <= num_keep
duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask
outlier_mask = outlier_mask & ~duplicate_remove_mask
probs_blk = where(outlier_mask, probs_blk, -inf)
probs_blk = probs_blk - max_logit
probs_blk = exp(probs_blk)
sum_exp_logits += sum(probs_blk)
# Pass 4: 归一化并存储
for i in range(0, search_iters):
probs_blk = BUFFER_ROW[i]
probs_blk = exp(probs_blk - max_logit)
probs_blk = probs_blk / sum_exp_logits
store(BUFFER_ROW + offs_n, probs_blk)
# Pass 5: 搜索p_pivot
# (与standalone top-p相同的三元搜索)
16.8 最终遮罩应用
python
# Pass 6: 应用最终遮罩
# 如果final_pivot >= max_logit(或NaN),无token需要屏蔽
if not (final_pivot < max_logit):
final_pivot = -float("inf") # 不过滤任何token
elif final_pivot != -float("inf"):
for i in range(0, NUM_TILES):
logits_blk = load(LOGITS_ROW + offs_n)
keep_mask = (logits_blk > final_pivot) & mask_n
# 处理重复logit
if num_keep < num_duplicate_logit:
duplicate_mask = abs(logits_blk - duplicate_logit) < 1e-9
duplicate_count = cumsum(duplicate_mask) + num_kept
duplicate_keep_mask = duplicate_count <= num_keep
duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask
num_kept += sum(duplicate_keep_mask)
keep_mask = keep_mask & ~duplicate_remove_mask
logits_blk = where(keep_mask, logits_blk, MASK_VALUE) # -inf
store(LOGITS_ROW + offs_n, logits_blk)
16.9 重复logit处理
重复logit是Top-K/Top-P实现中的关键边界情况:
问题场景:
logits = [5.0, 3.0, 3.0, 3.0, 1.0] k=3
k_pivot = 3.0 num_at_pivot = 3
严格 > k_pivot 的元素:1个(5.0)
== k_pivot 的元素:3个(3.0 × 3)
需要保留:k - 1 = 2个3.0
处理方式:
num_kept追踪已保留的重复值计数duplicate_keep_mask = (cumsum ≤ num_keep)→ 保留前num_keep个duplicate_remove_mask→ 移除其余
为什么用cumsum而非随机选择:
- 确定性更好,便于调试和复现
- 实际效果等价(相同logit的token概率相同,选哪个无所谓)
- Triton kernel中随机数生成代价高
16.10 apply_top_k_top_p_triton() Python包装
python
def apply_top_k_top_p_triton(
logits: torch.Tensor, # [batch_size, vocab_size] 修改in-place
k: torch.Tensor | None, # [batch_size]
p: torch.Tensor | None, # [batch_size]
mask_value: float = float("-inf"),
) -> torch.Tensor:
assert logits.ndim == 2
assert logits.dtype == torch.float32
batch_size, vocab_size = logits.shape
topk_enabled = k is not None
topp_enabled = p is not None
if batch_size == 0 or not (topk_enabled or topp_enabled):
return logits
# 准备k/p指针(不允许的用dummy)
if k is not None:
k_ptr = k.to(torch.int32) # Triton kernel需要int32
else:
k_ptr = logits # dummy,不会被读取
if p is not None:
p_ptr = p.to(torch.float32)
else:
p_ptr = logits # dummy
# 计算并行program数
num_sm = num_compute_units(logits.device.index)
NUM_PROGRAMS = min(num_sm, batch_size)
# 分配/缓存buffer
buf_key = (logits.device, logits.dtype, vocab_size)
buffer = _TRITON_BUFFER_CACHE.get(buf_key)
if buffer is None or buffer.shape[0] < NUM_PROGRAMS:
size = min(next_power_of_2(NUM_PROGRAMS), num_sm)
buffer = logits.new_empty((size, vocab_size))
_TRITON_BUFFER_CACHE[buf_key] = buffer
if buffer.shape[0] > NUM_PROGRAMS:
buffer = buffer[:NUM_PROGRAMS]
# 缓存查找表
tables = _TRITON_TABLE_CACHE.get(logits.device)
if tables is None:
normal_cdf_to_sigma_table = logits.new_tensor(_NORMAL_CDF_TO_SIGMA_TABLE)
percentile_to_std_table = logits.new_tensor(_PERCENTILE_TO_STD_TABLE)
_TRITON_TABLE_CACHE[logits.device] = (
normal_cdf_to_sigma_table, percentile_to_std_table
)
else:
normal_cdf_to_sigma_table, percentile_to_std_table = tables
# 启动kernel
_topk_topp_kernel[(NUM_PROGRAMS,)](
logits, buffer,
percentile_to_std_table, normal_cdf_to_sigma_table,
k_ptr, p_ptr,
BATCH_SIZE=batch_size,
MASK_VALUE=mask_value,
VOCAB_SIZE=vocab_size,
BLOCK_SIZE=8192,
BLOCK_SIZE_TRUNC=4096,
TOPK_ENABLED=topk_enabled,
TOPP_ENABLED=topp_enabled,
)
return logits
BLOCK_SIZE / BLOCK_SIZE_TRUNC 的设计:
BLOCK_SIZE=8192:全词表遍历时的tile大小BLOCK_SIZE_TRUNC=4096:只在buffer/outlier上遍历时的tile大小(更小,因为outlier数量通常远小于V)
16.11 缓冲区与查找表缓存
python
_TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {}
# key: (device,) → value: (normal_cdf_to_sigma_table, percentile_to_std_table)
# 每个设备只需创建一次查找表
_TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {}
# key: (device, dtype, vocab_size) → value: buffer [num_programs, vocab_size]
# 按设备+词表大小缓存,避免每次调用重新分配
16.12 reset_buffer_cache() 缓存清理
python
def reset_buffer_cache():
"""清理缓存,释放GPU内存"""
_TRITON_BUFFER_CACHE.clear()
_TRITON_TABLE_CACHE.clear()
torch.accelerator.empty_cache() # 释放未使用的GPU内存
第十七章 sample_recovered_tokens_kernel Triton内核
17.1 算法原理
拒绝采样被拒绝后,需要从修正分布 max(p(x) - q(x), 0) 采样。使用Gumbel-Max技巧:
recovered = argmax_x(prob(x) × inv_q[x])
其中:
prob(x) = max(target_prob(x) - draft_prob(x), 0)inv_q[x] = 1 / Exp(1)_x
由于 prob(x) 不需要归一化(argmax与归一化无关),避免了sum运算。
17.2 逐行解析
python
@triton.jit
def sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens] 输出
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
inv_q_ptr, # [batch_size, vocab_size]
vocab_size,
BLOCK_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0) # 请求索引
pos = tl.program_id(1) # draft位置索引
# 读取该请求的token范围
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft = end_idx - start_idx
# 超出范围的位置直接返回
if pos >= num_draft:
return
token_idx = start_idx + pos # 全局token索引
# ngram模式:排除draft_token本身
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
# 分tile遍历词表
max_val = float("-inf")
recovered_id = 0
for v in range(0, vocab_size, BLOCK_SIZE):
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
vocab_mask = vocab_offset < vocab_size
if NO_DRAFT_PROBS:
# ngram: prob = target_prob(排除draft_token)
prob = tl.load(
target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=(vocab_mask & (vocab_offset != draft_token_id)),
other=0.0,
)
else:
# 标准模式: prob = max(target - draft, 0)
draft_prob = tl.load(
draft_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask, other=0.0
)
target_prob = tl.load(
target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask, other=0.0
)
prob = tl.maximum(target_prob - draft_prob, 0.0)
# 读取Gumbel噪声倒数
inv_q = tl.load(
inv_q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_mask, other=0.0
)
# 计算 score = prob × inv_q
score = prob * inv_q
# 找本tile的max
local_max, local_id = tl.max(score, axis=0, return_indices=True)
# 更新全局max
if local_max > max_val:
max_val = local_max
recovered_id = v + local_id
# 写入结果
tl.store(output_token_ids_ptr + token_idx, recovered_id)
17.3 无draft_probs模式(ngram spec decode)
当使用ngram spec decode时,没有draft概率(draft_probs=None)。此时:
- 修正分布简化为:
prob(x) = target_prob(x)(排除draft_token本身) - 因为
max(p - q, 0)在q=1时退化为max(p - 1, 0),这不对 - 实际处理:在ngram模式下,接受条件变为
uniform < 1(总是接受),所以不会进入recovered token逻辑 - 但kernel仍需正确处理:使用
prob = target_prob(排除draft token)
附录E 投机解码端到端时序图
sample_recovered_kernel random_kernel greedy_kernel rejection_sample() Sampler RejectionSampler GPUModelRunner sample_recovered_kernel random_kernel greedy_kernel rejection_sample() Sampler RejectionSampler GPUModelRunner 投机解码步骤 Phase 1: Bonus token Phase 2: Target logits处理 Phase 3: 拒绝采样 alt [全贪心] [有随机请求] Phase 4: Logprobs forward(metadata, draft_probs, target_logits, sampling_metadata) sampler(bonus_logits, ...) bonus_token_ids apply_logits_processors(target_logits) apply_sampling_constraints(target_logits) rejection_sample(draft_ids, target_logits, bonus_ids) rejection_greedy_sample_kernel output_token_ids rejection_greedy_sample_kernel (贪心部分) target_probs = softmax(target_logits) sample_recovered_tokens(target_probs, draft_probs) recovered_token_ids rejection_random_sample_kernel (随机部分) output_token_ids output_token_ids [B, max_spec+1] _get_logprobs_tensors(...) SamplerOutput
附录F Triton Kernel计算流程图
Standalone Top-P 路径
Top-K + Top-P 联合路径
Pass 0: 采样块统计
avg, std, outlier_pivot
Pass 1: 全词表扫描
max, min, outlier收集
Ternary Search (在buffer)
找k_pivot
记录duplicate_logit
num_keep
Pass 3: Top-K元素softmax
sum_exp, outlier概率收集
Pass 4: 归一化buffer概率
probs / sum_exp
Ternary Search (在概率空间)
找p_pivot
更新duplicate_logit
num_keep
Pass 6: 应用final_pivot遮罩
- 重复logit处理
Pass 0: 采样块统计
Pass 1: 计算softmax和
收集outlier概率
Ternary Search (在概率空间)
找p_pivot
SP6
返回过滤后的logits
附录G 拒绝采样数学证明
定理:拒绝采样算法的输出分布等价于直接从target模型采样。
证明:
设draft分布为 q(x),target分布为 p(x)。
步骤1:接受概率
P(accept x) = P(draft generates x) × P(accept x | x generated)
= q(x) × min(p(x)/q(x), 1)
= min(p(x), q(x))
步骤2:拒绝后恢复分布
P(reject) = 1 - Σ_x min(p(x), q(x))
P(recover y | reject) = max(p(y) - q(y), 0) / Σ_x max(p(x) - q(x), 0)
步骤3:最终输出分布
P(output = y) = P(accept y) + P(reject) × P(recover y | reject)
对于被接受的token y:
= min(p(y), q(y)) + [1 - Σ_x min(p(x), q(x))] × max(p(y) - q(y), 0) / Z
其中 Z = Σ_x max(p(x) - q(x), 0) = 1 - Σ_x min(p(x), q(x))
因此:
P(output = y) = min(p(y), q(y)) + max(p(y) - q(y), 0) = p(y)
证毕。 □
bonus token :当所有K个draft token都被接受时,从target分布额外采样一个token。由于此时没有拒绝位置需要恢复,bonus token直接从 p(x) 采样,保证分布一致性。
附录H 术语表补充
| 术语 | 含义 |
|---|---|
| PLACEHOLDER_TOKEN_ID (-1) | 拒绝采样输出中被拒绝位置的占位符 |
| MAX_SPEC_LEN (128) | 单步最大投机token数 |
| GREEDY_TEMPERATURE (0) | 温度为0表示贪心采样 |
| cu_num_draft_tokens | 累积draft token偏移,类似CSR格式的行偏移 |
| outlier_pivot | 基于统计估计的初始枢轴值 |
| k_pivot | Top-K阈值:大于此值的logit保留 |
| p_pivot | Top-P阈值:概率大于此值的token保留 |
| duplicate_logit | 等于pivot值的logit,需特殊处理部分保留 |
| BLOCK_SIZE (8192) | Triton kernel的tile大小 |
| BLOCK_SIZE_TRUNC (4096) | 在buffer上的tile大小 |
| Gumbel-Max | argmax(prob × inv_exp) = Categorical(prob) |
| Synthetic Mode | 使用预设接受率而非p/q比率的拒绝采样 |
| Conditional Rate | 在前面token已接受条件下的接受率 |
| recovered_token | 拒绝后从max(p-q,0)采样的修正token |
| Qrita | 论文名,Pivot-based Top-K/Top-P算法 |
| Ternary Search | 三分搜索,每次评估2个pivot点 |
| Grammar Mask | 语法约束产生的-inf logit掩码 |
附录R rejection_sample() 完整数据流推演
R.1 场景: 3个请求,不同draft数量
请求配置:
请求0: 2个draft tokens, greedy采样
请求1: 3个draft tokens, random采样(temp=0.8)
请求2: 1个draft token, random采样(temp=1.0)
输入数据:
draft_token_ids = [D00, D01, D10, D11, D12, D20] # [6个]
num_draft_tokens = [2, 3, 1]
cu_num_draft_tokens = [2, 5, 6] # 累积偏移
bonus_token_ids = [B0, B1, B2] # [3, 1]
target_logits shape: [6, 32000] # 6个token位置 × 词表大小
R.2 贪心请求(请求0)的拒绝采样
rejection_greedy_sample_kernel[program_id=0]:
is_greedy[0] = True (temperature=0)
start_idx = 0
end_idx = 2 (cu_num_draft_tokens[0])
num_draft_tokens = 2
pos=0:
draft_token_id = D00
target_argmax_id = argmax(target_logits[0])
# 贪心拒绝逻辑: D00 == argmax?
如果 D00 == argmax → token_id = argmax, rejected=False
如果 D00 != argmax → token_id = argmax, rejected=True
pos=1 (如果rejected=False):
draft_token_id = D01
target_argmax_id = argmax(target_logits[1])
同样检查
如果全部接受:
bonus_token_id = B0
output[0] = [D00_or_argmax, D01_or_argmax, B0]
如果pos=0就拒绝:
output[0] = [argmax, -1, -1]
# -1 = PLACEHOLDER_TOKEN_ID
R.3 随机请求(请求1)的拒绝采样
rejection_random_sample_kernel[program_id=1]:
is_greedy[1] = False (temperature=0.8)
start_idx = 2
end_idx = 5 (cu_num_draft_tokens[1])
num_draft_tokens = 3
pos=0:
draft_token_id = D10
uniform_prob = U10 # 均匀随机数
draft_prob = draft_probs[2, D10] # q(D10)
target_prob = target_probs[2, D10] # p(D10)
# 接受条件: q(D10) > 0 and p(D10)/q(D10) >= U10
accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob
如果接受: token_id = D10
如果拒绝: token_id = recovered_token_ids[2], rejected=True
pos=1 (如果rejected=False):
draft_token_id = D11
同样计算 p/q vs uniform
pos=2 (如果rejected=False):
draft_token_id = D12
同样计算
如果全部接受:
bonus_token_id = B1
output[1] = [D10, D11, D12, B1]
R.4 完整输出矩阵
output_token_ids: [batch_size=3, max_spec_len+1=4]
请求0 (greedy, 2 drafts):
[accepted_0, accepted_1_or_recovered, B0_or_-1, -1]
# 注意: max_spec_len=3,但请求0只有2个draft
# 第4列(max_spec_len位置)由max_spec_len决定
# 实际: 如果2个draft全接受,bonus在第3列(索引2)
请求1 (random, 3 drafts):
[accepted_0, accepted_1, accepted_2_or_recovered, B1_or_-1]
请求2 (random, 1 draft):
[accepted_0_or_recovered, B2_or_-1, -1, -1]
解析后(parse_output)
output_token_ids [3, 4]
请求0: [tok, tok, B0, -1]
请求1: [tok, tok, tok, B1]
请求2: [tok, B2, -1, -1]
请求0: [tok, tok, B0]
请求1: [tok, tok, tok, B1]
请求2: [tok, B2]
附录S Triton Kernel 三元搜索算法详解
S.1 标准三分搜索
目标: 找到值x,使得 f(x) 满足特定条件
标准三分搜索:
while range > ε:
m1 = lo + (hi - lo) / 3
m2 = lo + 2*(hi - lo) / 3
if f(m1) < f(m2):
hi = m2 # 去掉右1/3
else:
lo = m1 # 去掉左1/3
S.2 Qrita变体: 同时评估两个pivot
Qrita的三元搜索:
while not found:
pivot_0 = lo + (hi - lo) / 3 # 1/3位置
pivot_1 = lo + 2*(hi - lo) / 3 # 2/3位置
count_0 = count(logits > pivot_0) # 大于pivot_0的元素数
count_1 = count(logits > pivot_1) # 大于pivot_1的元素数
# 检查是否命中
if count_0 >= k and count_0 - dup_count_0 < k:
found! pivot = pivot_0
if count_1 >= k and count_1 - dup_count_1 < k:
found! pivot = pivot_1
# 更新范围
if count_1 > k: lo = pivot_1 # 右pivot仍太多 → 抬高下限
elif count_0 > k: lo = pivot_0 # 左pivot太多 → 抬高下限
if count_0 < k: hi = pivot_0 # 左pivot太少 → 降低上限
elif count_1 < k: hi = pivot_1 # 右pivot太少 → 降低上限
与标准三分搜索的区别:
- 不是找单峰函数的极值,而是找满足条件的阈值
- 终止条件是"恰好有K个元素大于等于pivot"
- 同时评估2个pivot,可能其中一个直接命中
- 最多18次迭代(float32精度足够)
S.3 收敛速度分析
每次迭代:
- 最优: pivot_0或pivot_1直接命中 → 1次
- 最差: 范围缩小1/3 → log_3(V/logit_range) 次
- 对于V=128000, range~20:
log_3(128000/20) ≈ log_3(6400) ≈ 7.8 次
- 18次上限提供充分余量
附录T expand_kernel Triton内核详解
T.1 累积偏移(Cumulative Offset)模式
vLLM投机解码大量使用"累积偏移"模式,类似CSR稀疏矩阵格式的行偏移:
num_draft_tokens = [2, 3, 1]
cu_num_draft_tokens = [2, 5, 6]
含义:
请求0: token 0-1 (start=0, end=2)
请求1: token 2-4 (start=2, end=5)
请求2: token 5 (start=5, end=6)
T.2 expand_kernel 工作原理
输入: x = [a, b, c], cu_num_tokens = [2, 5, 6]
输出: expanded = [a, a, b, b, b, c]
Program 0 (req_idx=0):
start=0, end=2, num=2
src_val = a
store(output[0:2], a) → output = [a, a, _, _, _, _]
Program 1 (req_idx=1):
start=2, end=5, num=3
src_val = b
store(output[2:5], b) → output = [a, a, b, b, b, _]
Program 2 (req_idx=2):
start=5, end=6, num=1
src_val = c
store(output[5:6], c) → output = [a, a, b, b, b, c]
T.3 replace_from/replace_to 参数
温度展平时的特殊处理:
temperature = [0, 0.8, 1.0] # 请求0温度=0(贪心)
replace_from = 0
replace_to = 1
expanded = [1, 1, 0.8, 0.8, 0.8, 1.0]
# 贪心请求的温度被替换为1.0
# 因为后续要做 logits /= temperature
# temperature=0 会导致除零错误
# 但贪心请求已经走了argmax路径,不需要温度缩放
# temperature=1.0 等价于不缩放(logits /= 1.0 = logits)
附录U 查找表数据详解
U.1 _NORMAL_CDF_TO_SIGMA_TABLE
这个表将正态分布的百分位数映射到σ值。用于Standalone Top-P采样。
索引0 → σ = 3.656 (对应百分位约0.5%,即p≈0.005)
索引1 → σ = 3.650
...
索引100 → σ ≈ 0.0 (对应百分位50%)
...
索引199 → σ = -3.813 (对应百分位约99.5%)
使用方式:
p = 0.9 (top-p = 0.9)
idx = int(0.9 * 200) = 180
sigma = NORMAL_CDF_TO_SIGMA_TABLE[180] ≈ -1.881
outlier_pivot = avg_logit + std_logit * sigma
# sigma为负 → outlier_pivot < avg → 收集更多outlier
U.2 _PERCENTILE_TO_STD_TABLE
这个表将百分位数映射到标准正态分布的z值。用于Top-K采样。
索引0 → z = 2.576 (对应百分位约0.5%)
索引100 → z ≈ 0.0 (对应百分位50%)
索引199 → z = -3.813 (对应百分位约99.5%)
使用方式:
k = 50, V = 32000
percentile = int(50/32000 * 200) = 0
sigma = PERCENTILE_TO_STD_TABLE[0] = 2.576
outlier_pivot = avg_logit + std_logit * 2.576
# 很高的z值 → outlier_pivot远大于avg → 只收集极端outlier
U.3 偏移修正
python
# Top-K偏移: sigma = sigma + |sigma| * -0.15
# 使估计偏保守(收集更多outlier)
# -0.15 * |sigma| 当sigma>0时减少sigma值 → 降低阈值 → 收集更多
# -0.15 * |sigma| 当sigma<0时增加sigma值(更负) → 也降低阈值
# Top-P偏移: sigma = sigma + |sigma| * -0.25
# 更激进的偏移(Top-P对遗漏更敏感)
# -0.25 vs -0.15: Top-P需要更保守的估计
为什么需要偏移:
- 估计基于第一个block的统计量,可能不完全代表全词表
- 偏保守意味着多收集一些outlier → 后续三元搜索会精确化
- 偏激进意味着可能遗漏 → 导致top-K/Top-P结果不准确
- "宁可多收集,不可遗漏"的原则
附录V 拒绝采样与Gumbel-Max的统一视角
V.1 Gumbel-Max等价性
定理 : 对于概率分布 p(x) 和独立 G_x ~ Gumbel(0,1):
argmax_x (log p(x) + G_x) ~ Categorical(p)
推论 : 对于 q_x ~ Exp(1),1/q_x ~ Gumbel(0,1):
argmax_x (p(x) / q_x) ~ Categorical(p)
这就是 random_sample() 和 sample_recovered_tokens_kernel 的数学基础。
V.2 拒绝采样的Gumbel-Max解释
标准拒绝采样可以理解为Gumbel-Max的变体:
1. 对draft token x: 计算 r = p(x)/q(x)
2. 生成 u ~ Uniform(0,1)
3. 如果 u < r: 接受x
4. 否则: 从max(p-q,0)采样recovered token
等价于:
1. 对draft token x: 设定 Gumbel增量 g_x = -log(u) 当 u < r
2. 如果 g_x < -log(1-r): 接受
3. 否则: 从修正分布采样
但实际实现不用Gumbel表示,因为:
- p/q比较更直观
- Uniform随机数更易生成
- Triton kernel中Uniform比Gumbel更高效
V.3 sample_recovered_tokens 的Gumbel-Max实现
prob(x) = max(target_prob(x) - draft_prob(x), 0) # 修正分布
inv_q[x] = 1 / Exp(1) # Gumbel噪声倒数
recovered = argmax_x (prob(x) * inv_q[x])
等价于:
score(x) = log(prob(x)) + Gumbel(x) (因为log(1/q) = Gumbel)
但prob可能为0 → log(0) = -inf → 不影响argmax
实际使用 prob * inv_q 而非 log(prob) + Gumbel:
因为乘法避免了log(0)的NaN问题
argmax不变: argmax(a*b) = argmax(log(a)+log(b)) 当a,b>0
附录W 投机解码中Logits处理器的特殊行为
W.1 MinTokens在投机解码中的展开
常规采样: 每个请求1行logits
logits [B, V]
MinTokens: 对每行屏蔽stop tokens
投机解码: 每个请求N行logits (N=draft tokens数)
logits [num_tokens, V]
MinTokens: 对前remaining行屏蔽stop tokens
示例:
请求0: min_tokens=5, current_output=3, num_draft=2
remaining = 5 - 3 = 2
n_mask = min(2, 2) = 2
该请求的2个draft位置都需要屏蔽stop tokens:
rows [0, 1] × stop_tokens → index_put_(rows, stop_ids, -inf)
请求1: min_tokens=10, current_output=9, num_draft=3
remaining = 10 - 9 = 1
n_mask = min(1, 3) = 1
只有第1个draft位置需要屏蔽:
row [2] × stop_tokens → index_put_(row, stop_ids, -inf)
W.2 LogitBias在投机解码中的展开
LogitBias不需要特殊处理:
biases[req_idx] = {token_id: bias_value}
在投机解码中:
repeat_indices将请求级参数展平到token级
bias_tensor[repeat_indices] → 每个draft位置使用相同偏置
原因: LogitBias是全局的token偏好,不依赖于当前位置
不管在哪个draft位置,对特定token的偏置值相同
W.3 ThinkingBudget在投机解码中的双次调用
投机解码中logits处理器被调用两次:
调用1: bonus token logits (predict_bonus_token=True)
只处理1行per request (bonus位置)
如果force_index指向bonus → 在此调用中强制
调用2: target logits (predict_bonus_token=False)
处理num_draft_tokens行per request
如果force_index指向draft位置 → 在此调用中强制
设计原因:
bonus token和draft token的logits是分开计算的
需要在各自的logits张量上独立应用处理器
避免在不同形状的张量上做广播
ThinkingBudget RejectionSampler ThinkingBudget RejectionSampler Phase 1: Bonus token force_index指向bonus? → 强制 force_index指向draft? → 跳过 Phase 2: Target logits force_index指向draft位置 → 强制 逐行设置mask和force_token_ids apply_to_logits(bonus_logits, predict_bonus_token=True) apply_to_logits(target_logits, predict_bonus_token=False)
附录X2 Triton Kernel Memory Access Pattern 分析
X2.1 全局内存访问模式
_topk_topp_kernel 的内存访问:
输入: LOGITS [BATCH_SIZE, VOCAB_SIZE]
每行VOCAB_SIZE=32000~128000个float32元素
每行大小: 128KB~512KB
Pass 0 (采样块):
读取: LOGITS_ROW[0:BLOCK_SIZE] # 8192个float = 32KB
写入: 无
访问模式: 连续读取,高效的内存合并
Pass 1 (全词表扫描):
读取: LOGITS_ROW[NUM_TILES × BLOCK_SIZE] # 全词表
写入: BUFFER_ROW[0:num_outliers] # 稀疏写入
访问模式: 连续读取,稀疏写入(outlier位置不确定)
三元搜索:
读取: BUFFER_ROW[0:search_range] # 只读outlier区域
写入: 无(只更新搜索范围变量)
访问模式: 连续读取BUFFER,效率高
Pass 6 (最终遮罩):
读取: LOGITS_ROW[NUM_TILES × BLOCK_SIZE] # 全词表
写入: LOGITS_ROW[NUM_TILES × BLOCK_SIZE] # 原地修改
访问模式: 连续读写,最高效
X2.2 共享内存使用
Triton kernel的共享内存使用:
- BLOCK_SIZE=8192: 每个program处理1行
- 中间变量存储在寄存器中(非共享内存)
- BUFFER是全局内存(非共享内存)
- 不使用tl.load到shared memory的模式
原因:
- 每个program独立处理1行,不需要program间通信
- BUFFER需要跨pass持久化,必须用全局内存
- 共享内存大小有限(48KB~96KB),无法存8192个float
X2.3 Warp Divergence分析
Kernel中的分支:
1. TOPK_ENABLED / TOPP_ENABLED: 编译时常量 → 无divergence
2. k < VOCAB_SIZE: 每行可能不同 → 可能divergence
3. p < 1.0: 每行可能不同 → 可能divergence
4. found_pivot: 循环条件 → 可能divergence
5. duplicate处理: 条件分支 → 可能divergence
缓解策略:
- 每个program处理1行 → 同一warp内的program处理不同行
- 不同行的分支独立 → 不影响效率
- 三元搜索的迭代次数差异 → 少数program可能更早退出
- 总体: divergence影响较小
附录Y2 rejection_greedy_sample_kernel 完整执行追踪
Y2.1 示例: 5个请求的贪心拒绝采样
配置:
batch_size = 5
num_draft_tokens = [3, 2, 4, 1, 2]
cu_num_draft_tokens = [3, 5, 9, 10, 12]
max_spec_len = 4
draft_token_ids = [D00, D01, D02, D10, D11, D20, D21, D22, D23, D30, D40, D41]
target_argmax = [T00, T01, T02, T10, T11, T20, T21, T22, T23, T30, T40, T41]
bonus_token_ids = [B0, B1, B2, B3, B4]
is_greedy = [True, True, True, True, True] # 全贪心
Kernel启动: 5个program(每个请求1个)
Program 0 (请求0, 3个draft):
is_greedy[0] = True → 继续执行
start_idx = 0, end_idx = 3, num_draft = 3
pos=0: D00 vs T00
如果 D00 == T00: token=D00(或T00), rejected=False
如果 D00 != T00: token=T00, rejected=True, 后续位置保持-1
假设 D00==T00, D01==T01, D02!=T02:
output[0] = [T00, T01, T02, -1]
# pos=0,1接受draft,pos=2拒绝→用target的argmax
Program 1 (请求1, 2个draft):
start_idx = 3, end_idx = 5, num_draft = 2
假设 D10==T10, D11==T11:
rejected=False → 追加bonus
output[1] = [T10, T11, B1, -1]
Program 2 (请求2, 4个draft):
start_idx = 5, end_idx = 9, num_draft = 4
假设全部接受:
output[2] = [T20, T21, T22, T23, B2]
# 注意: max_spec_len=4, 所以最多4个draft+1个bonus=5列
但num_draft=4 == max_spec_len:
output[2] = [T20, T21, T22, T23, B2]
# 5列: 4个draft位置 + 1个bonus位置
Program 3 (请求3, 1个draft):
start_idx = 9, end_idx = 10, num_draft = 1
假设 D30==T30:
output[3] = [T30, B3, -1, -1, -1]
Program 4 (请求4, 2个draft):
start_idx = 10, end_idx = 12, num_draft = 2
假设 D40!=T40:
output[4] = [T40, -1, -1, -1, -1]
# 第1个draft就拒绝,后续全为-1
最终输出:
output_token_ids = [
[T00, T01, T02, -1, -1], # 请求0: 2接受+1恢复
[T10, T11, B1, -1, -1], # 请求1: 全接受+bonus
[T20, T21, T22, T23, B2], # 请求2: 全接受+bonus
[T30, B3, -1, -1, -1], # 请求3: 1接受+bonus
[T40, -1, -1, -1, -1], # 请求4: 1恢复(0接受)
]
parse_output过滤-1后:
[
[T00, T01, T02], # 3个token
[T10, T11, B1], # 3个token
[T20, T21, T22, T23, B2], # 5个token
[T30, B3], # 2个token
[T40], # 1个token
]
Y2.2 Synthetic模式执行路径
synthetic_conditional_rates = [0.8, 0.7, 0.6, 0.5]
# 位置0接受率0.8, 位置1接受率0.7, 位置2接受率0.6, 位置3接受率0.5
Program 0 (请求0, 3个draft, greedy+synthetic):
pos=0:
uniform_prob = U00
rate = 0.8
accepted = (U00 < 0.8)
如果接受: token_id = D00 (draft token)
如果拒绝: token_id = T00 (target argmax), rejected=True
pos=1 (如果未拒绝):
uniform_prob = U01
rate = 0.7
accepted = (U01 < 0.7)
...
注意: synthetic模式不比较p/q
而是使用预设的接受率
这是因为ngram spec decode没有draft概率
无法计算p/q比率
附录Z2 generate_uniform_probs float64精度必要性实验
Z2.1 问题复现
python
# float32下的问题:
import torch
torch.manual_seed(42)
# 生成100万个float32均匀随机数
u32 = torch.rand(1_000_000, dtype=torch.float32)
count_zero_32 = (u32 == 0.0).sum().item()
# 结果: count_zero_32 ≈ 0~10 (取决于种子)
# 概率: P(u == 0.0) ≈ 2^{-24} ≈ 6e-8
# 期望: 1e6 × 6e-8 ≈ 0.06 → 通常为0或1
# 但在拒绝采样中:
# accepted = target_prob / draft_prob >= uniform_prob
# 如果 uniform_prob == 0.0:
# 任何 target_prob/draft_prob >= 0 都为True
# → 本应拒绝的token被错误接受
# → 采样分布偏离target分布
# float64下:
u64 = torch.rand(1_000_000, dtype=torch.float64)
count_zero_64 = (u64 == 0.0).sum().item()
# 结果: count_zero_64 = 0
# 概率: P(u == 0.0) ≈ 2^{-53} ≈ 1.1e-16
# 期望: 1e6 × 1.1e-16 ≈ 1.1e-10 → 实际上不可能为0
Z2.2 精度对采样分布的影响
假设: draft_prob(x) = 0.1, target_prob(x) = 0.05
正确行为: p/q = 0.05/0.1 = 0.5 → 接受概率50%
float32 + uniform_prob恰好=0.0:
0.5 >= 0.0 → True → 接受
本应50%接受 → 实际100%接受 → 分布偏差
float64 + uniform_prob≈0.0但不精确=0:
0.5 >= ~1e-16 → True → 接受
但1e-16的概率几乎不可能 → 对分布影响可忽略
结论: float64消除了"精确零"问题,保证了采样的数学正确性
附录AA2 sample_recovered_tokens_kernel Block分块策略
AA2.1 为什么使用BLOCK_SIZE=8192
vocab_size = 128000 (典型LLM词表大小)
如果BLOCK_SIZE太小 (如256):
需要的迭代次数: 128000 / 256 = 500次
每次迭代: 1次全局内存读取 + 1次比较 + 1次乘法
总体: 500次内存读取 → 延迟高
如果BLOCK_SIZE太大 (如65536):
超过Triton的最大BLOCK_SIZE限制
寄存器压力过大 → 可能spill到local memory → 性能下降
BLOCK_SIZE=8192:
迭代次数: 128000 / 8192 ≈ 16次
每次迭代处理8192个元素
Triton可高效处理这个大小
寄存器使用合理
trade-off:
更大的BLOCK → 更少的迭代 → 更少的kernel启动开销
但寄存器压力增加 → 可能降低occupancy
8192是实测的经验最优值
AA2.2 tile内max操作的实现
python
# 在每个tile内找到score的最大值
score = prob * inv_q # [BLOCK_SIZE] 向量
# tl.max(score, axis=0, return_indices=True)
# 返回: (最大值, 最大值的索引)
# axis=0: 在BLOCK_SIZE维度上归约
# return_indices=True: 同时返回索引
local_max, local_id = tl.max(score, axis=0, return_indices=True)
# 跨tile更新全局max
if local_max > max_val:
max_val = local_max
recovered_id = v + local_id # v是tile的起始偏移
为什么不需要归一化prob:
prob = max(target_prob - draft_prob, 0)
prob可能有很多0值 → 这些位置不影响argmax
prob不需要归一化 → argmax不关心绝对值,只关心相对大小
数学证明:
argmax_x(prob[x] * inv_q[x])
= argmax_x(log(prob[x]) + log(inv_q[x])) (当prob>0时)
= argmax_x(log(prob[x]) + Gumbel[x]) (因为inv_q ~ 1/Exp(1) ~ Gumbel)
不需要prob归一化:
如果prob[x] *= C (常数),则:
argmax_x(C*prob[x] * inv_q[x])
= argmax_x(log(C) + log(prob[x]) + Gumbel[x])
= argmax_x(log(prob[x]) + Gumbel[x]) (log(C)是常数,不影响argmax)
= 原始argmax
所以: argmax(max(p-q,0) * inv_q) = argmax(归一化(max(p-q,0)) * inv_q)
无需归一化步骤,节省了一次sum操作
附录BB2 SpecDecodeMetadata 关键字段说明
RejectionSampler接收的 SpecDecodeMetadata 包含以下关键字段:
| 字段 | 形状 | 含义 |
|---|---|---|
draft_token_ids |
[num_tokens] | draft模型生成的token ids |
num_draft_tokens |
list[int] | 每个请求的draft token数量 |
max_spec_len |
int | 最大spec长度(≤128) |
cu_num_draft_tokens |
[batch_size] | 累积draft token偏移 |
cu_num_sampled_tokens |
[batch_size] | 累积采样token偏移(含bonus) |
bonus_logits_indices |
[batch_size] | bonus token在logits中的行索引 |
target_logits_indices |
[num_tokens] | target token在logits中的行索引 |
logits的展平布局:
logits shape: [num_tokens + batch_size, vocab_size]
对于 num_draft_tokens = [2, 3, 1]:
num_tokens = 2 + 3 + 1 = 6
total_rows = 6 + 3 = 9
行0-1: 请求0的draft位置0-1
行2-4: 请求1的draft位置0-2
行5: 请求2的draft位置0
行6: 请求0的bonus位置
行7: 请求1的bonus位置
行8: 请求2的bonus位置
target_logits_indices = [0, 1, 2, 3, 4, 5]
bonus_logits_indices = [6, 7, 8]
索引
target: [0,1,2,3,4,5]
bonus: [6,7,8]
logits [9, V]
行0: 请求0 draft-0
行1: 请求0 draft-1
行2: 请求1 draft-0
行3: 请求1 draft-1
行4: 请求1 draft-2
行5: 请求2 draft-0
行6: 请求0 bonus
行7: 请求1 bonus
行8: 请求2 bonus
附录CC2 拒绝采样与标准采样的等价性完整证明
CC2.1 单步拒绝采样
设draft分布为q(x),target分布为p(x)。
接受概率:
P(accept x|draft generates x) = min(p(x)/q(x), 1)
当 p(x) ≥ q(x) 时: 接受概率=1(target比draft更倾向x)
当 p(x) < q(x) 时: 接受概率=p(x)/q(x)(按概率接受)
联合概率(draft生成x且被接受):
P(accept x) = q(x) × min(p(x)/q(x), 1) = min(p(x), q(x))
拒绝概率:
P(reject) = 1 - Σ_x min(p(x), q(x)) = Σ_x max(p(x)-q(x), 0)
(利用恒等式: min(a,b) + max(a-b,0) = a 当a≥b; min(a,b) + 0 = b 当a<b)
CC2.2 恢复采样
拒绝后,从修正分布采样:
P(recover y|reject) = max(p(y)-q(y), 0) / Σ_x max(p(x)-q(x), 0)
CC2.3 最终输出分布
P(output = y) = P(accept y) + P(reject) × P(recover y|reject)
= min(p(y), q(y)) + [Σ_x max(p(x)-q(x),0)] × [max(p(y)-q(y),0) / Σ_x max(p(x)-q(x),0)]
= min(p(y), q(y)) + max(p(y)-q(y), 0)
Case 1: p(y) ≥ q(y)
= q(y) + (p(y)-q(y)) = p(y) ✓
Case 2: p(y) < q(y)
= p(y) + 0 = p(y) ✓
因此: P(output = y) = p(y) 对所有y成立。 □
CC2.4 多步拒绝采样
对于K个draft token,按顺序验证:
P(accept all K) = Π_{i=1}^{K} min(p_i/q_i, 1)
= Π_{i=1}^{K} min(p(x_i)/q(x_i), 1)
P(reject at position j) = [Π_{i=1}^{j-1} min(p_i/q_i, 1)] × [1 - min(p_j/q_j, 1)]
由于每步独立(draft token独立生成),乘积分解成立。
被拒绝后只采一个recovered token,后续draft token全部丢弃。
最终输出 = (accepted tokens up to j-1) + (recovered token at j)
数学保证: 对于位置j的最终token,其分布等价于从p采样。
对于位置j+1及之后,重新开始新的投机步骤。