vLLM v1 Sample 采样模块超深度分析
分析对象:
vllm/vllm/v1/sample代码规模:14 Python 文件(含2个__init__.py),4,133 行有效代码
分析日期:2026-04-22
vLLM v1 Sample 采样模块超深度逐行分析 --- 第一部分:核心层
一、模块定位
1.1 业务职责
vLLM v1 的 sample 模块是整个推理引擎中从模型输出到最终采样决策的关键桥梁。它的核心业务职责可以用一句话概括:
将模型输出的 logits 张量,经过一系列确定性与随机性变换,转换为每个请求的下一个生成 token。
这看似简单的"从连续分布到离散决策",实际包含了 vLLM 作为生产级推理框架所必须支持的全部采样控制能力------温度调节、Top-K/Top-P 截断、惩罚机制(重复/频率/出现)、坏词排除、白名单约束、自定义 logits 处理器、最小 token 数约束、min-p 过滤、思考 token 预算控制等。
1.2 功能定位
从功能维度看,sample 模块承担以下六项核心功能:
| 功能 | 说明 |
|---|---|
| Logits 预处理 | 精度提升(→float32)、白名单掩码、坏词排除、非 argmax 不变处理器 |
| 惩罚施加 | 重复惩罚、频率惩罚、出现惩罚------三种基于已生成 token 的调节机制 |
| 采样决策 | 贪心采样 vs 随机采样的分支路由,温度缩放,Top-K/Top-P 联合过滤 |
| Logprobs 计算 | 原始 logprobs(用于 top-k 展示)与处理后 logprobs(用于采样后输出)的分离计算 |
| 拒绝采样 | 面向投机解码场景的 draft token 验证与拒绝采样 |
| 批处理编排 | 在单次 forward pass 中同时处理混合的贪心/随机请求,高效利用 GPU 并行性 |
1.3 边界与范围
属于 sample 模块的:
- 从原始 logits 到 sampled token 的全部变换链路
- 采样参数的元数据封装(
SamplingMetadata) - Logits 处理器的抽象接口与内置实现
- Top-K/Top-P 的 Triton 高性能内核
- 投机解码的拒绝采样逻辑
不属于 sample 模块的:
- logits 的产生(属于模型前向传播
model_runner) - 采样参数的解析与构建(属于
sampling_params与scheduler) - 采样结果的后处理与文本解码(属于
detokenizer) - KV cache 管理(属于
kv_cache_manager)
1.4 在 vLLM 架构中的位置
┌─────────────────────────────────────────────────────────────────┐
│ vLLM v1 推理引擎 │
│ │
│ ┌──────────┐ ┌──────────────┐ ┌───────────┐ │
│ │Scheduler │───>│ Model Runner │───>│ Sample │ │
│ │ │ │ (forward) │ │ Module │ │
│ │ 构建请求 │ │ 产生 logits │ │ logits→token│ │
│ └──────────┘ └──────────────┘ └─────┬─────┘ │
│ │ │ │
│ │ ┌──────────────┐ │ │
│ └────────>│ Sampling │<───────────┘ │
│ │ Metadata │ sampled_token_ids │
│ │ 参数打包 │ + logprobs_tensors │
│ └──────────────┘ │ │
│ v │
│ ┌─────────────┐ │
│ │ Scheduler │ │
│ │ 结果收集 │ │
│ └─────────────┘ │
└─────────────────────────────────────────────────────────────────┘
sample 模块位于 Model Runner 的下游,Scheduler 的结果回收环节上游。它是纯计算模块------接收 logits + metadata,输出 sampled tokens + logprobs,不持有任何跨步持久状态(除 logits 处理器外)。
1.5 核心商业价值
- 推理质量保障:9 步采样管线确保模型输出的可控性,从温度到惩罚的每一步都有精确的数学保证
- 性能极致优化:Triton 内核 + FlashInfer 集成 + 批量混合采样,单次 forward 完成异构请求处理
- 扩展性设计:LogitsProcessor 插件体系支持自定义采样策略,通过 entry_points 机制实现零侵入扩展
- 投机解码支持:RejectionSampler 使 vLLM 能在投机解码场景下验证 draft token,大幅提升吞吐
二、模块整体结构
2.1 文件职责矩阵
sample 模块共包含 14 个 Python 文件 (含 2 个 __init__.py),总计 4133 行代码,分布在 3 个层级目录中:
| # | 文件路径 | 行数 | 架构层级 | 核心职责 |
|---|---|---|---|---|
| 1 | sample/__init__.py |
0 | 包初始化 | 空文件,模块包标记 |
| 2 | sample/sampler.py |
410 | 核心层 | Sampler 主类,采样管线编排,9步采样流程入口 |
| 3 | sample/metadata.py |
49 | 核心层 | SamplingMetadata 数据类,采样参数的批量化封装 |
| 4 | sample/rejection_sampler.py |
850 | 投机解码层 | RejectionSampler,draft token 验证与拒绝采样 |
| 5 | sample/ops/__init__.py |
0 | 算子层 | 空文件,ops 子包标记 |
| 6 | sample/ops/bad_words.py |
57 | 算子层 | 坏词排除算子,前缀匹配 + logits 掩码 |
| 7 | sample/ops/logprobs.py |
27 | 算子层 | logprobs 工具函数,批量计数排名 |
| 8 | sample/ops/penalties.py |
57 | 算子层 | 惩罚算子,repetition/frequency/presence 三合一 |
| 9 | sample/ops/topk_topp_sampler.py |
402 | 算子层 | TopK/TopP 采样器,多后端(CUDA/CPU/HIP/FlashInfer) |
| 10 | sample/ops/topk_topp_triton.py |
1057 | 算子层 | TopK/TopP Triton 内核,GPU 高性能实现 |
| 11 | sample/logits_processor/__init__.py |
360 | 处理器层 | LogitsProcessor 加载/构建/插件机制 |
| 12 | sample/logits_processor/interface.py |
106 | 处理器层 | LogitsProcessor 抽象接口与 BatchUpdate 定义 |
| 13 | sample/logits_processor/state.py |
165 | 处理器层 | BatchUpdateBuilder + LogitsProcessors 容器 |
| 14 | sample/logits_processor/builtin.py |
593 | 处理器层 | 4 个内置处理器(MinP/LogitBias/MinTokens/ThinkingTokenBudget) |
2.2 三层架构
┌─────────────────────────────────────────────────────────┐
│ 核心层 (Core) │
│ ┌──────────────┐ ┌───────────────────┐ │
│ │ sampler.py │───────>│ metadata.py │ │
│ │ 采样管线编排 │ <────── │ 参数元数据封装 │ │
│ │ 410 行 │ │ 49 行 │ │
│ └──────┬───────┘ └───────────────────┘ │
│ │ 调用 │
└─────────┼───────────────────────────────────────────────┘
│
v
┌─────────────────────────────────────────────────────────┐
│ 算子层 (Operators) │
│ ┌──────────────┐ ┌────────────┐ ┌──────────────────┐ │
│ │ penalties.py │ │ bad_words │ │ topk_topp │ │
│ │ 惩罚算子 │ │ 坏词排除 │ │ sampler.py │ │
│ │ 57 行 │ │ 57 行 │ │ 402 行 │ │
│ └──────────────┘ └────────────┘ └──────────────────┘ │
│ ┌──────────────┐ ┌────────────────────────────────┐ │
│ │ logprobs.py │ │ topk_topp_triton.py │ │
│ │ 排名计数 │ │ Triton GPU 内核 │ │
│ │ 27 行 │ │ 1057 行 │ │
│ └──────────────┘ └────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
▲
│ 被调用
┌─────────────────────────────────────────────────────────┐
│ 处理器层 (Processors) │
│ ┌──────────────────┐ ┌──────────────────────────────┐ │
│ │ interface.py │ │ state.py │ │
│ │ 抽象接口+BatchUpd│ │ BatchUpdateBuilder+LogitsProc │ │
│ │ 106 行 │ │ 165 行 │ │
│ └──────────────────┘ └──────────────────────────────┘ │
│ ┌──────────────────┐ ┌──────────────────────────────┐ │
│ │ builtin.py │ │ __init__.py │ │
│ │ 4个内置处理器 │ │ 加载/构建/插件机制 │ │
│ │ 593 行 │ │ 360 行 │ │
│ └──────────────────┘ └──────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
‖ 并行
┌─────────────────────────────────────────────────────────┐
│ 投机解码层 (Speculative) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ rejection_sampler.py │ │
│ │ RejectionSampler + 拒绝采样内核 │ │
│ │ 850 行 │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
2.3 类层次结构
nn.Module
├── Sampler [核心] 采样管线主编排器
│ └── 内含 TopKTopPSampler [算子] Top-K/Top-P 采样器
│
└── RejectionSampler [投机解码] draft token 验证
└── 内含 Sampler 复用主采样器
LogitsProcessor (ABC) [处理器] 抽象基类
├── MinPLogitsProcessor [内置] min-p 过滤(argmax不变)
├── LogitBiasLogitsProcessor [内置] logit 偏置(argmax不变)
│ └── 继承 AdapterLogitsProcessor 每请求适配器模式
├── MinTokensLogitsProcessor [内置] 最小token数约束(非argmax不变)
│ └── 继承 AdapterLogitsProcessor
├── ThinkingTokenBudgetLogitsProcessor [内置] 思考token预算(非argmax不变)
│ └── 继承 AdapterLogitsProcessor
└── [用户自定义] 通过 entry_points 插件加载
LogitsProcessors [容器] 聚合所有处理器实例
├── .argmax_invariant: list argmax不变处理器列表
└── .non_argmax_invariant: list 非argmax不变处理器列表
SamplingMetadata (dataclass) [元数据] 采样参数批量化封装
BatchUpdateBuilder [状态] 批更新构建器
BatchUpdate (frozen dataclass) [状态] 批更新数据
2.4 依赖注入关系
sample 模块的依赖注入遵循"构造时注入,运行时消费"的模式:
Sampler.__init__(logprobs_mode)
├── 创建 TopKTopPSampler(logprobs_mode) ← 算子层注入
└── 查询 is_pin_memory_available() ← 平台能力注入
Sampler.forward(logits, sampling_metadata)
├── sampling_metadata ← 由 Scheduler 构建,运行时注入
│ ├── .logitsprocs: LogitsProcessors ← 处理器集合,由 build_logitsprocs() 构建
│ │ ├── .argmax_invariant[] ← 运行时分离的两类处理器
│ │ └── .non_argmax_invariant[]
│ ├── .temperature, .top_k, .top_p ← 采样参数张量
│ ├── .generators ← 随机数生成器
│ ├── .penalties_* ← 惩罚参数
│ └── .bad_words_token_ids ← 坏词列表
└── logits ← 由 ModelRunner 产生,运行时注入
build_logitsprocs(vllm_config, device, ...)
├── BUILTIN_LOGITS_PROCESSORS ← 硬编码内置列表
├── _load_logitsprocs_plugins() ← entry_points 插件发现
└── _load_logitsprocs_by_fqcns() ← FQCN 字符串加载
2.5 核心方法清单
Sampler 类方法
| 方法 | 类型 | 行数范围 | 核心职责 |
|---|---|---|---|
__init__ |
构造 | 55-59 | 初始化 TopKTopPSampler、pin_memory、logprobs_mode |
forward |
主入口 | 61-122 | 9步采样管线编排,logits→SamplerOutput |
gather_specific_token_logprobs |
辅助 | 124-185 | 指定 token ID 的 logprobs 高效计算 |
apply_temperature |
静态 | 187-194 | 温度缩放,in-place 除法 |
greedy_sample |
静态 | 196-197 | argmax 贪心采样 |
sample |
核心 | 199-244 | 采样决策:贪心/随机分支路由 |
compute_logprobs |
静态 | 246-247 | log_softmax 计算 |
gather_logprobs |
静态 | 249-293 | Top-K logprobs 收集与排名计算 |
_combine_outputs_with_spec_tokens |
静态 | 295-303 | 投机解码输出合并 |
apply_logits_processors |
核心 | 305-333 | Logits 预处理管线(白名单→坏词→处理器→惩罚) |
apply_penalties |
静态 | 335-349 | 惩罚施加委托 |
SamplingMetadata 字段
| 字段 | 类型 | 说明 |
|---|---|---|
temperature |
`Tensor | None` |
all_greedy |
bool |
是否全部贪心 |
all_random |
bool |
是否全部随机 |
top_p |
`Tensor | None` |
top_k |
`Tensor | None` |
generators |
dict[int, Generator] |
每请求随机数生成器 |
max_num_logprobs |
`int | None` |
no_penalties |
bool |
是否无惩罚 |
prompt_token_ids |
`Tensor | None` |
frequency_penalties |
Tensor |
频率惩罚 |
presence_penalties |
Tensor |
出现惩罚 |
repetition_penalties |
Tensor |
重复惩罚 |
output_token_ids |
list[list[int]] |
已生成 token |
allowed_token_ids_mask |
`Tensor | None` |
bad_words_token_ids |
dict[int, list[list[int]]] |
坏词列表 |
logitsprocs |
LogitsProcessors |
Logits 处理器集合 |
logprob_token_ids |
`dict[int, list[int]] | None` |
spec_token_ids |
`list[list[int]] | None` |
2.6 数据流出入
输入
logits: Tensor[batch_size, vocab_size] ← 模型输出,可能是 float16/bfloat16
sampling_metadata: SamplingMetadata ← Scheduler 构建的参数包
predict_bonus_token: bool = False ← 投机解码标记
logprobs_mode_override: LogprobsMode | None ← 运行时模式覆盖
输出
SamplerOutput:
├── sampled_token_ids: Tensor[batch_size, 1] ← int32,每请求一个采样 token
└── logprobs_tensors: LogprobsTensors | None ← 可选的 logprobs 信息
├── logprob_token_ids: Tensor ← int32 token ID 索引
├── logprobs: Tensor ← float32 logprobs 值
└── selected_token_ranks: Tensor ← 采样 token 排名
内部数据流(9步管线)
logits ──┐
│ Step 1: 计算原始 logprobs (raw_logprobs/raw_logits)
v
raw_logprobs ──┐
│ Step 2: logits → float32
v
float32_logits ──┐
│ Step 3: 白名单掩码
v
filtered_logits ──┐
│ Step 4: 坏词排除
v
clean_logits ──┐
│ Step 5: 非argmax不变处理器
v
proc_logits ──┐
│ Step 6: 惩罚
v
penalized_logits ──┐
│ Step 7: 采样决策
v
sampled + processed_logprobs ──┐
│ Step 8: gather logprobs
v
logprobs_tensors ──┐
│ Step 9: 构建输出
v
SamplerOutput
三、sampler.py 逐行深度解析
本节对
sampler.py(410 行)进行真正的逐行分析。每一行代码、每一个分支、每一个变量都会被解释。
3.1 文件头与许可证(第 1-3 行)
python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
第 1 行:SPDX 许可证标识符,声明本文件采用 Apache 2.0 许可证。SPDX(Software Package Data Exchange)是标准化的许可证标识方式,被 Linux 基金会推荐使用。
第 2 行:SPDX 版权声明,版权归 vLLM 项目的所有贡献者所有。注意这里使用"contributors to the vLLM project"而非单一实体,体现了开源社区项目性质。
3.2 模块文档字符串(第 3 行)
python
"""A layer that samples the next tokens from the model's outputs."""
模块级文档字符串,一句话定义了本模块的核心功能:从模型输出中采样下一个 token 的层 。注意关键词"a layer"------这里把 Sampler 设计为一个 nn.Module 层,使其能被嵌入 PyTorch 的模型计算图中。
3.3 导入区(第 5-14 行)
python
import torch
import torch.nn as nn
第 5 行 :import torch --- 导入 PyTorch 核心库。用于张量操作、设备管理、自动微分等。在 sampler 中主要使用:张量类型转换(.to()、.long())、张量操作(.argmax()、.topk()、.where()、.gather()、.unsqueeze())、张量掩码(.masked_fill_())、张量拼接(.cat())。
第 6 行 :import torch.nn as nn --- 导入 PyTorch 神经网络模块。Sampler 继承自 nn.Module,因此需要此导入。nn.Module 赋予 Sampler 以下能力:
- 被纳入模型的
.modules()树 - 参数/缓冲区管理(虽然 Sampler 没有可训练参数)
- 设备转移(
.to(device)) - 与
torch.compile兼容
python
from vllm.config.model import LogprobsMode
第 8 行 :从 vllm.config.model 导入 LogprobsMode。这是一个类型别名(Literal 类型),定义了 logprobs 的计算模式。可能的取值包括:
"raw_logprobs":在 logits 经任何处理之前,先用log_softmax计算原始 logprobs"raw_logits":直接克隆原始 logits(转换为 float32)作为 logprobs"processed_logits":使用经过处理(惩罚、温度等)后的 logits"processed_logprobs":使用经过处理后 logits 的 log_softmax
这个模式控制的是 Step 1 中用什么作为最终返回的 logprobs 基底。
python
from vllm.utils.platform_utils import is_pin_memory_available
第 9 行 :从 vllm.utils.platform_utils 导入 is_pin_memory_available。此函数检测当前平台是否支持 CUDA pinned memory(锁页内存)。Pinned memory 允许 CPU→GPU 的 DMA 传输,提高数据搬运速度。在 penalty 计算中,output_token_ids 从 CPU 张量传输到 GPU 时会用到 pin_memory。
python
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
第 10 行 :从 vllm.v1.outputs 导入两个输出数据结构:
LogprobsTensors:封装 logprobs 相关的张量三元组(token_ids、logprobs、ranks)SamplerOutput:Sampler.forward() 的最终返回类型,包含sampled_token_ids和logprobs_tensors
python
from vllm.v1.sample.metadata import SamplingMetadata
第 11 行 :从本模块的 metadata.py 导入 SamplingMetadata。这是采样参数的批量化数据类,将每个请求的采样参数(temperature、top_k、top_p、penalties 等)打包为张量,支持批量并行处理。详见第 2.5 节。
python
from vllm.v1.sample.ops.bad_words import apply_bad_words
第 12 行 :从 ops/bad_words.py 导入 apply_bad_words 函数。此函数在 logits 上执行坏词排除------检查已生成的前缀是否匹配坏词的前缀部分,如果匹配则将坏词最后一个 token 的 logit 设为 -inf。这是 Step 4 的实现。
python
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
第 13 行 :从 ops/logprobs.py 导入 batched_count_greater_than。此函数计算每个 batch 行中有多少元素的值大于等于给定阈值,用于计算采样 token 的排名。在 gather_logprobs 方法中使用。该函数使用 torch.compile 优化,避免内存拷贝问题。
python
from vllm.v1.sample.ops.penalties import apply_all_penalties
第 14 行 (此处指导入区的最后一行):从 ops/penalties.py 导入 apply_all_penalties。此函数一次性应用三种惩罚(repetition、frequency、presence),委托给 vllm.model_executor.layers.utils.apply_penalties 实现。这是 Step 6 的算子实现。
python
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
紧随其后的导入------从 ops/topk_topp_sampler.py 导入 TopKTopPSampler 类。这是 Top-K 和 Top-P 联合采样的核心算子,支持多后端(原生 PyTorch、CUDA、CPU、HIP/AMD、FlashInfer)。在 Sampler 构造时创建实例,作为 Step 7d 的执行者。
python
from vllm.v1.worker.gpu.sample.logprob import compute_token_logprobs
从 v1/worker/gpu/sample/logprob.py 导入 compute_token_logprobs。这是一个 Triton JIT 编译的 fused kernel,将 log_softmax + gather 融合为一次 GPU 计算。在 gather_specific_token_logprobs 方法中使用,用于高效计算指定 token 的 logprobs(避免计算整个词表的 log_softmax)。
3.4 模块常量(第 16 行)
python
_SAMPLING_EPS = 1e-5
第 16 行 :定义采样精度阈值常量 _SAMPLING_EPS = 1e-5(0.00001)。
用途:在两处使用------
apply_temperature:当温度低于此阈值时,视为贪心采样(温度 → 1.0,不缩放)sample:在混合贪心/随机采样时,用temperature < _SAMPLING_EPS判断某个请求是否应走贪心路径
设计考量 :不使用 0 而使用极小正值的原因------浮点数精度问题。temperature 可能是 1e-6 这样接近零但非零的值,数学上等价于贪心采样,但直接除以它会引入数值不稳定。将 < EPS 的温度替换为 1.0,使得 logits / 1.0 = logits(不变),等效于贪心。
命名约定:前导下划线表示模块私有常量。
3.5 Sampler 类定义与文档字符串(第 18-53 行)
python
class Sampler(nn.Module):
第 18 行 :Sampler 继承自 nn.Module。这是一个关键设计决策------将采样器设计为 PyTorch 层而非普通类,原因是:
- 模型图集成 :作为
nn.Module,Sampler 可以被纳入模型的forward计算流,与模型的其他层统一管理 - 设备一致性 :
nn.Module的.to(device)会自动递归应用到子模块(如TopKTopPSampler) - 编译兼容 :
torch.compile能正确处理nn.Module的调用图 - 序列化支持 :虽然 Sampler 没有可训练参数,但
nn.Module的 state_dict 机制确保了检查点兼容性
python
"""
A layer that samples the next tokens from the model's outputs
with the following steps in order:
1. If logprobs are requested:
a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
as the final logprobs to return.
b) If `logprobs_mode` is `raw_logits`, clone the logits
as the final logprobs to return.
2. Convert logits to float32.
3. Apply allowed token ids whitelist.
4. Apply bad words exclusion.
5. Apply logit processors which are not argmax-invariant,
i.e. that can impact greedy sampling.
a) Min tokens processor
b) Logit bias processor
6. Apply penalties
a) Repetition penalty
b) Frequency penalty
c) Presence penalty
7. Sample the next tokens. `sample` method performs the following steps:
a) If not `all_random`, perform greedy sampling. If `all_greedy`,
return the greedily sampled tokens and final logprobs if requested.
b) Apply temperature.
c) Apply logit processors which are argmax-invariant, by default
the min_p processor.
d) Apply top_k and/or top_p.
e) Sample the next tokens with the probability distribution.
f) If `all_random` or temperature >= epsilon (1e-5), return the
randomly sampled tokens and final logprobs if requested. Else,
return the greedily sampled tokens and logprobs if requested.
8. Gather the logprobs of the top `max_num_logprobs` and sampled token
(if requested). Note that if the sampled token is within the top
`max_num_logprobs`, the logprob will be eventually merged in
`LogprobsProcessor` during output processing. Therefore, the
final output may contain either `max_num_logprobs + 1` or
`max_num_logprobs` logprobs.
9. Return the final `SamplerOutput`.
"""
第 19-53 行:类文档字符串,详细描述了 9 步采样管线。这是理解 Sampler 工作原理的最重要的文档。逐步解析:
Step 1 :如果请求了 logprobs(max_num_logprobs is not None),根据 logprobs_mode 决定如何准备原始 logprobs:
raw_logprobs:对原始 logits 执行log_softmax,得到数学上精确的 log 概率raw_logits:直接克隆原始 logits(转为 float32),用于需要原始 logit 值的场景
关键注释 :这里计算的 logprobs 是基于未经惩罚和温度缩放的原始 logits,这与 V0 sampler 的行为不同(V0 使用处理后的 logits)。这是一个有意的 API 变更------原始 logprobs 更能反映模型的"真实信念"。
Step 2 :将 logits 转换为 float32 精度。这是数值稳定性的关键------log_softmax、温度除法、Top-K/Top-P 比较等操作在 float16/bfloat16 下可能溢出或丢失精度。
Step 3 :应用 token 白名单掩码。将不在白名单中的 token logits 设为 -inf,等效于概率 0。通过 masked_fill_ 原地操作。
Step 4 :应用坏词排除。基于前缀匹配检查已生成序列,如果坏词前缀已出现,将坏词最后 token 的 logit 设为 -inf。
Step 5 :应用非 argmax 不变的 logits 处理器。这些处理器会改变 argmax 结果(即贪心采样结果),因此必须在贪心采样之前应用:
- MinTokens 处理器:当已生成 token 数 < 最小要求时,将 EOS token 的 logit 设为
-inf - LogitBias 处理器:对特定 token 加偏置值
Step 6:应用三种惩罚------重复惩罚(按比例缩放已出现 token 的 logits)、频率惩罚(线性递减)、出现惩罚(常数递减)。
Step 7:采样决策,内部又包含 6 个子步骤:
- (a) 贪心采样分支:如果不是全部随机,先做 argmax;如果全部贪心则直接返回
- (b) 温度缩放:
logits /= temperature - © argmax 不变处理器:如 min_p 过滤(不影响贪心结果但影响随机采样分布)
- (d) Top-K/Top-P 过滤
- (e) 从概率分布中随机采样
- (f) 根据温度决定最终使用贪心还是随机结果
Step 8 :收集 Top-K logprobs 和采样 token 的 logprobs。注意最终可能返回 max_num_logprobs + 1 个 logprobs(采样 token 可能不在 Top-K 中)。
Step 9 :构建并返回 SamplerOutput。
3.6 __init__ 方法(第 55-59 行)
python
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
第 55 行 :构造函数签名。接受一个参数 logprobs_mode,默认值为 "raw_logprobs"。
参数 logprobs_mode:LogprobsMode 类型(Literal 字符串类型),控制 logprobs 的计算方式:
"raw_logprobs":使用log_softmax(logits)计算,这是数学上精确的 log 概率"raw_logits":直接使用原始 logit 值(转为 float32)- 默认
"raw_logprobs"保持向后兼容性
默认值设计 :使用 "raw_logprobs" 作为默认值意味着大多数用户场景下返回的是标准化的 log 概率,而非原始 logit 值。
python
super().__init__()
第 56 行 :调用 nn.Module.__init__()。这初始化了 PyTorch 模块的底层机制:
_parameters(有序字典)_buffers(有序字典)_modules(有序字典)_forward_hooks、_forward_pre_hooks等
虽然 Sampler 没有可训练参数,但 super().__init__() 仍然必须调用,否则 nn.Module 的基础设施不会被正确初始化。
python
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
第 57 行 :创建 TopKTopPSampler 实例并赋值给 self.topk_topp_sampler。
作为 nn.Module 属性 :因为 TopKTopPSampler 也是 nn.Module,将其赋值给 self.xxx 会自动注册到 self._modules 中。这意味着:
Sampler.modules()会包含TopKTopPSamplerSampler.to(device)会递归转移TopKTopPSamplerSampler.state_dict()会包含TopKTopPSampler的缓冲区
参数传递 :logprobs_mode 传递给 TopKTopPSampler,因为采样器在随机采样时可能需要计算 processed logprobs。
python
self.pin_memory = is_pin_memory_available()
第 58 行:检测并缓存平台是否支持 pin_memory。
为什么缓存 :is_pin_memory_available() 可能涉及平台检测逻辑(查询 CUDA 是否可用),在每次需要时调用会产生不必要的开销。构造时查询一次,后续直接使用布尔值。
用途 :在 apply_penalties → ops/penalties.py → _convert_to_tensors 中,将 output_token_ids 转为 CPU 张量时使用 pin_memory 加速 CPU→GPU 传输。
python
self.logprobs_mode = logprobs_mode
第 59 行 :将 logprobs_mode 存储为实例属性。
用途 :在 forward() 和 sample() 方法中作为默认 logprobs 模式使用。forward() 允许通过 logprobs_mode_override 参数覆盖此默认值。
3.7 forward 方法(第 61-122 行)
这是 Sampler 的主入口方法,编排了完整的 9 步采样管线。
python
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
predict_bonus_token: bool = False,
logprobs_mode_override: LogprobsMode | None = None,
) -> SamplerOutput:
第 61-66 行:方法签名。
参数 logits :Tensor[batch_size, vocab_size]。模型前向传播的输出,未经任何采样处理。数据类型可能是 float16 或 bfloat16(取决于模型配置和硬件)。
参数 sampling_metadata :SamplingMetadata 实例。由 Scheduler 构建,包含当前 batch 中所有请求的采样参数。注意这是一个 dataclass,不是 nn.Module。
参数 predict_bonus_token :bool,默认 False。用于投机解码场景。当为 True 时,如果存在惩罚或坏词,需要将 speculative token 合并到 output_token_ids 中再计算惩罚(因为惩罚基于已出现的 token 集合)。
参数 logprobs_mode_override :LogprobsMode | None,默认 None。允许在运行时覆盖构造时设置的 logprobs_mode。如果提供,则使用此值;否则使用 self.logprobs_mode。
返回值 :SamplerOutput 实例,包含 sampled_token_ids 和可选的 logprobs_tensors。
python
logprobs_mode = logprobs_mode_override or self.logprobs_mode
第 67 行 :解析 logprobs 模式。优先使用运行时覆盖值,如果没有覆盖则使用构造时的默认值。Python 的 or 语义:如果 logprobs_mode_override 为 None(falsy),则回退到 self.logprobs_mode。
python
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
第 68-71 行:重要注释。由 woosuk 添加,说明了一个关键的 API 变更:
V0 vs V1 行为差异:
- V0:logprobs 基于处理后的 logits(经过惩罚和温度缩放),这意味着返回的 logprobs 反映了采样时的实际概率分布
- V1:logprobs 基于原始 logits(未经处理),这更接近模型的"真实置信度"
设计理由:原始 logprobs 更有解释性------它反映的是模型对 token 的原始预测概率,不受采样参数的影响。用户可能想看到"模型认为这个 token 有多可能"而非"经过各种调节后这个 token 有多可能"。
python
num_logprobs = sampling_metadata.max_num_logprobs
第 72 行 :提取 max_num_logprobs 到局部变量,减少属性访问次数。
max_num_probs 的语义:
None:不需要 logprobs0:仅需要采样 token 的 logprob- 正整数:需要 Top-K logprobs + 采样 token 的 logprob
-1:返回完整的未排序未排名 logprobs 张量
python
if num_logprobs is not None:
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
if logits.dtype == torch.float32:
raw_logprobs = logits.clone()
else:
raw_logprobs = logits.to(torch.float32)
第 73-79 行 :Step 1 --- 计算原始 logprobs。
第 73 行 :if num_logprobs is not None --- 仅当请求了 logprobs 时才计算,避免不必要的计算开销。
第 74-75 行 :如果模式是 "raw_logprobs",调用 self.compute_logprobs(logits),即 logits.log_softmax(dim=-1, dtype=torch.float32)。这计算每个 token 在词表上的 log 概率。注意 log_softmax 在内部先减去最大值(数值稳定技巧),然后取 exp、求和、取 log,最后减去 log_sum_exp。
第 76-79 行 :如果模式是 "raw_logits":
- 如果 logits 已经是 float32,直接
.clone()(避免修改原始张量) - 如果 logits 不是 float32(float16/bfloat16),用
.to(torch.float32)转换
为什么不直接用 logits.clone().to(torch.float32) :当 logits 已经是 float32 时,.clone() 比 .to(torch.float32) 更高效------后者即使是相同类型也会创建新张量,但前者语义更清晰。这里的分支优化避免了不必要的 .to() 调用。
变量 raw_logprobs 的生命周期 :这个变量可能在后续被 processed_logprobs 覆盖(第 97 行),也用于 gather_logprobs 的输入。
python
# Use float32 for the logits.
logits = logits.to(torch.float32)
第 81 行 :Step 2 --- logits 精度提升。
将 logits 转换为 float32。这是一个关键操作,原因:
log_softmax需要高精度以避免数值溢出/下溢- 温度除法
logits / temp在低精度下可能不准确 - Top-K/Top-P 的概率归一化需要足够精度
- 惩罚计算(乘以 >1 的重复惩罚系数)可能导致 float16 溢出
.to(torch.float32) 的行为 :如果 logits 已经是 float32,这是 no-op(返回自身,不创建新张量)。如果是 float16/bfloat16,创建新的 float32 张量。注意这里的 logits 是局部变量,重新绑定不会影响调用方的原始张量。
python
logits = self.apply_logits_processors(
logits, sampling_metadata, predict_bonus_token
)
第 83-85 行 :Steps 3-6 --- Logits 预处理管线。
调用 apply_logits_processors,该方法内部按序执行:
- 应用白名单掩码(Step 3)
- 应用坏词排除(Step 4)
- 应用非 argmax 不变处理器(Step 5)
- 应用惩罚(Step 6)
返回值 :处理后的 logits 张量(可能是原地修改,也可能是新张量)。重新绑定到 logits 变量。
predict_bonus_token 参数 :传递到 apply_logits_processors 中,影响 output_token_ids 的合并逻辑(是否包含 speculative tokens)。
python
# Sample the next token.
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
第 86-87 行 :Step 7 --- 采样决策。
调用 self.sample() 执行采样,返回两个值:
sampled:Tensor[batch_size],采样得到的 token IDprocessed_logprobs:Tensor | None,处理后的 logprobs(仅在特定 logprobs_mode 下返回)
为什么叫 processed_logprobs :与前面的 raw_logprobs 对应。raw_logprobs 是基于原始 logits 计算的,而 processed_logprobs 是基于经过所有处理(惩罚、温度、Top-K/Top-P)后的 logits/logprobs 计算的。
python
if processed_logprobs is not None:
raw_logprobs = processed_logprobs
第 88-89 行 :如果 sample() 返回了 processed_logprobs,则用其覆盖 raw_logprobs。
什么时候 processed_logprobs 不为 None :当 logprobs_mode 为 "processed_logits" 或 "processed_logprobs" 时,sample() 方法内部会计算并返回处理后的 logprobs。此时覆盖 raw_logprobs 使得后续的 gather_logprobs 使用处理后的值。
设计理由 :通过覆盖而非新变量,简化了后续代码的逻辑------gather_logprobs 只需要关注 raw_logprobs 这个变量名,不需要知道它是原始的还是处理后的。
python
# Convert sampled token ids to int64 (long) type to ensure compatibility
# with subsequent operations that may use these values as indices.
# This conversion is necessary because FlashInfer sampling operations
# return int32 (while PyTorch argmax and topk return int64).
sampled = sampled.long()
第 90-94 行:将采样结果转换为 int64。
为什么需要转换:不同的采样后端返回不同的数据类型:
- PyTorch
argmax():返回 int64 - PyTorch
topk():返回 int64 - FlashInfer 采样操作:返回 int32
后续操作(如 logprobs.gather(-1, token_ids))需要索引为 int64,因此统一转换。
.long() :等价于 .to(torch.int64),是 PyTorch 的惯用写法。
python
# Handle logprob_token_ids if specified (more efficient than full vocab)
# This is used by generative_scoring API to get logprobs for specific tokens
logprob_token_ids_tensors = None
if sampling_metadata.logprob_token_ids:
logprob_token_ids_tensors = self.gather_specific_token_logprobs(
logits, sampling_metadata.logprob_token_ids, sampled
)
第 96-101 行:处理指定的 logprob_token_ids------一种高效计算特定 token logprobs 的方式。
logprob_token_ids 的用途:某些 API(如 generative_scoring)只需要特定 token 的 logprobs,而非整个 Top-K。例如,评估模型在某个特定 token 上的置信度。此时计算整个词表的 log_softmax 再取 Top-K 是浪费的。
gather_specific_token_logprobs :使用 Triton fused kernel(log_softmax + gather)高效计算,只需对指定的 token 做 log_softmax + gather,避免计算整个词表。
传入 logits :注意这里传入的是经过所有处理后的 logits(Step 2-6 之后),因此计算出的 logprobs 反映了惩罚等处理的效果。
传入 sampled:已转为 int64 的采样 token ID,用于获取采样 token 的排名。
python
if num_logprobs is None:
logprobs_tensors = logprob_token_ids_tensors
第 103-104 行 :如果没有请求 Top-K logprobs (num_logprobs is None),则 logprobs 输出完全由 logprob_token_ids_tensors 决定:
- 如果有
logprob_token_ids:返回指定 token 的 logprobs - 如果没有:
logprob_token_ids_tensors为None,最终logprobs_tensors也为None
python
elif num_logprobs == -1:
# Return the full unsorted and unranked logprobs.
logprobs_tensors = LogprobsTensors(
torch.empty(0), raw_logprobs, torch.empty(0)
)
第 105-108 行 :如果 num_logprobs == -1,返回完整的未排序未排名 logprobs。
LogprobsTensors(torch.empty(0), raw_logprobs, torch.empty(0)):
- 第一个参数
logprob_token_ids:空张量(没有索引,因为返回全部) - 第二个参数
logprobs:完整的[batch_size, vocab_size]logprobs 张量 - 第三个参数
selected_token_ranks:空张量(没有排名,因为返回全部)
使用 torch.empty(0) 而非 None:保持张量类型一致性,避免下游代码的类型检查问题。
python
else:
# Gather the logprobs and ranks of the topk and sampled token.
logprobs_tensors = self.gather_logprobs(
raw_logprobs, num_logprobs, token_ids=sampled
)
第 109-112 行 :Step 8 --- 正常的 Top-K logprobs 收集。
当 num_logprobs 为 0 或正整数时,调用 gather_logprobs 收集 Top-K logprobs 和采样 token 的 logprob/rank。
传入 raw_logprobs :可能是原始 logprobs("raw_logprobs" 模式)或处理后的 logprobs("processed_*" 模式),取决于之前的覆盖逻辑。
传入 token_ids=sampled:采样 token 的 ID,用于获取该 token 的 logprob 和排名。
python
# If we have both num_logprobs and logprob_token_ids, prefer
# logprob_token_ids as it's more specific
if logprob_token_ids_tensors is not None and num_logprobs is not None:
logprobs_tensors = logprob_token_ids_tensors
第 114-116 行 :当同时存在 Top-K logprobs 和指定 token logprobs 时,优先使用指定 token logprobs。
设计理由 :logprob_token_ids 是更精确的请求------用户明确知道要哪些 token 的 logprobs。Top-K 是更泛化的请求。当两者共存时,精确请求应优先。
注意 :这个覆盖发生在 gather_logprobs 之后,意味着 Top-K 的计算被浪费了。但这种情况(同时请求 Top-K 和指定 token)在实际中极少出现,性能损失可接受。
python
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
第 118-119 行:将采样 token ID 从 int64 转回 int32。
为什么先转 int64 再转 int32 :int64 是中间态,为了 gather_logprobs 中的 torch.topk 和 logprobs.gather 操作(它们需要 int64 索引)。计算完成后转回 int32 以减少传输和存储开销------int32 的张量大小是 int64 的一半。
python
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output
第 121-122 行 (实际跨越第 120-127 行):Step 9 --- 构建并返回 SamplerOutput。
sampled.unsqueeze(-1) :将 [batch_size] 扩展为 [batch_size, 1]。这是因为输出格式约定每个请求可能生成多个 token(投机解码场景),1 表示当前只生成了 1 个 token。
logprobs_tensors :可能是 None(无 logprobs 请求)、LogprobsTensors(Top-K 或指定 token 的 logprobs)。
注释 "These are GPU tensors":提醒下游代码,这些张量位于 GPU 上,如果需要 CPU 访问需要显式传输。
SamplerOutput 的结构:
python
@dataclass
class SamplerOutput:
sampled_token_ids: torch.Tensor # [batch_size, 1], int32
logprobs_tensors: LogprobsTensors | None
3.8 gather_specific_token_logprobs 方法(第 124-185 行)
python
def gather_specific_token_logprobs(
self,
logits: torch.Tensor,
logprob_token_ids: dict[int, list[int]],
sampled: torch.Tensor,
) -> LogprobsTensors | None:
第 124-128 行:方法签名。
参数 logits :[batch_size, vocab_size],经过所有处理的 logits(float32)。
参数 logprob_token_ids :字典,键为请求索引(int),值为该请求需要计算 logprobs 的 token ID 列表。注意这是异构的------不同请求可能有不同数量和不同 ID 的 token。
参数 sampled :[batch_size],已采样的 token ID(int64)。
返回值 :LogprobsTensors | None。如果没有请求需要指定 token logprobs,返回 None。
python
"""Compute logprobs for specific token IDs using Triton kernel.
This method handles heterogeneous token ID lists across requests by
padding shorter lists to max length and using a fused Triton kernel
for efficient log_softmax + gather computation.
Benchmarks show the Triton kernel approach is ~1.4x faster than sparse
gather for batch sizes > 1 due to the fused kernel reducing memory
bandwidth requirements.
Args:
logits: [batch_size, vocab_size] tensor of logits
logprob_token_ids: dict mapping req_index -> list of token IDs
sampled: [batch_size] tensor of sampled token IDs
Returns:
LogprobsTensors with logprobs for the specified tokens, or None
if no requests have logprob_token_ids.
"""
第 129-150 行:方法文档字符串,包含重要的性能注释。
关键性能数据:Triton fused kernel 比稀疏 gather 快约 1.4 倍(batch_size > 1 时),因为融合内核减少了内存带宽需求------不需要先写出完整的 log_softmax 结果再 gather,而是在一次 kernel 调用中完成。
python
if not logprob_token_ids:
return None
第 151-152 行 :空字典快速返回。如果没有任何请求指定了 logprob_token_ids,直接返回 None,避免后续不必要的张量创建。
python
batch_size = logits.shape[0]
device = logits.device
第 154-155 行 :提取 batch 大小和设备信息。device 用于确保所有创建的张量都在同一设备上。
python
# Find max number of tokens across all requests
max_num_tokens = max(len(tids) for tids in logprob_token_ids.values())
第 157-158 行:计算所有请求中最大的 token ID 列表长度。这决定了填充后的张量宽度。由于不同请求可能有不同数量的指定 token,需要填充到统一长度。
python
# Create padded token_ids tensor: [batch_size, max_num_tokens + 1]
# +1 for sampled token in first position
token_ids_tensor = torch.zeros(
batch_size, max_num_tokens + 1, dtype=torch.int64, device=device
)
token_ids_tensor[:, 0] = sampled # First column is sampled token
第 160-165 行:创建填充后的 token_ids 张量。
形状 [batch_size, max_num_tokens + 1]:+1 是因为第一列固定存放采样 token 的 ID。
第 0 列为采样 token:这是一个设计约定------采样 token 总是在第一列,后续列是用户指定的 token。这确保采样 token 的 logprob 总是可获取的。
torch.zeros :初始化为 0,后续填充有效位置。0 在词表中可能对应有效 token,但配合 valid_mask 使用不会引起歧义。
python
# Create mask for valid positions (True = valid, False = padded)
valid_mask = torch.zeros(
batch_size, max_num_tokens + 1, dtype=torch.bool, device=device
)
valid_mask[:, 0] = True # Sampled token is always valid
第 167-171 行:创建有效性掩码。
用途 :填充位置的 logprobs 需要被标记为无效(-inf),否则 0 号 token 的 logprob 会是有效值,造成误导。
第 0 列始终有效:采样 token 的 logprob 总是需要。
python
# Fill in token IDs for each request
for req_idx, token_ids in logprob_token_ids.items():
num_tokens = len(token_ids)
token_ids_tensor[req_idx, 1 : num_tokens + 1] = torch.tensor(
token_ids, dtype=torch.int64, device=device
)
valid_mask[req_idx, 1 : num_tokens + 1] = True
第 173-178 行:填充每个请求的指定 token ID。
req_idx:字典的键,即请求在 batch 中的索引。
1 : num_tokens + 1 :从第 1 列开始填充(第 0 列是采样 token),填充 num_tokens 个元素。
torch.tensor(token_ids, ...) :将 Python list 转为 GPU 张量。注意这是在循环中创建小张量,虽然不是最优的,但因为 logprob_token_ids 通常只有少量请求且每个请求的 token 数很少,性能影响可忽略。
valid_mask[req_idx, 1 : num_tokens + 1] = True:标记这些位置为有效。
python
# Compute logprobs using the fused Triton kernel (log_softmax + gather)
logprobs = compute_token_logprobs(logits, token_ids_tensor)
第 180-181 行:使用 Triton fused kernel 计算 logprobs。
compute_token_logprobs :定义在 v1/worker/gpu/sample/logprob.py 中的 Triton JIT 内核。它融合了 log_softmax + gather 两个操作:
- 对 logits 执行 log_softmax(数值稳定版:减最大值 → exp → 归一化 → log)
- 根据
token_ids_tensor中的索引 gather 出指定位置的 logprobs
优势 :不需要先写出 [batch_size, vocab_size] 的完整 logprobs 再 gather,而是在 kernel 内部直接计算目标位置的值,大幅减少内存写入。
输入 :logits [batch_size, vocab_size],token_ids_tensor [batch_size, max_num_tokens + 1]
输出 :logprobs [batch_size, max_num_tokens + 1]
python
# Mask invalid (padded) positions with -inf
logprobs = logprobs.masked_fill(~valid_mask, float("-inf"))
第 183-184 行 :将填充位置的 logprobs 设为 -inf。
~valid_mask:取反,True 变 False,False 变 True。即标记无效位置。
float("-inf") :负无穷,语义上表示"这个位置的 logprob 无效/不存在"。下游代码可以通过 -inf 识别并跳过这些位置。
python
# Compute ranks for the sampled token
sampled_logits = logits.gather(-1, sampled.unsqueeze(-1))
token_ranks = (logits > sampled_logits).sum(dim=-1)
第 186-187 行:计算采样 token 的排名。
sampled_logits :[batch_size, 1],采样 token 对应的 logit 值。gather 从 logits 的最后一维按索引取值。
token_ranks :[batch_size],每个 batch 行中有多少 token 的 logit 严格大于采样 token 的 logit。排名从 0 开始------如果采样 token 的 logit 是最大的,排名为 0。
注意 :这里使用 >(严格大于)而非 >=。这意味着如果有多个 token 与采样 token 的 logit 相同,它们的排名相同。
python
return LogprobsTensors(
logprob_token_ids=token_ids_tensor.to(torch.int32),
logprobs=logprobs,
selected_token_ranks=token_ranks,
)
第 189-193 行:构建并返回 LogprobsTensors。
token_ids_tensor.to(torch.int32):将 token ID 从 int64 转为 int32 以节省空间。int32 可表示到 2^31 ≈ 21 亿,远超任何词表大小。
logprobs :[batch_size, max_num_tokens + 1],float32,指定 token 的 logprobs(填充位置为 -inf)。
token_ranks :[batch_size],采样 token 的排名。
3.9 apply_temperature 静态方法(第 187-194 行)
python
@staticmethod
def apply_temperature(
logits: torch.Tensor,
temp: torch.Tensor,
all_random: bool,
) -> torch.Tensor:
第 187-191 行:方法签名。
@staticmethod :标记为静态方法,不需要 self。这意味着此方法不访问实例状态,可以独立调用。
参数 logits :[batch_size, vocab_size],float32 logits。
参数 temp :[batch_size],每个请求的温度参数。
参数 all_random :bool,是否所有请求都是随机采样。
python
# Use in-place division to avoid creating a new tensor.
# Avoid division by zero if there are greedy requests.
if not all_random:
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
第 192-194 行:避免除以零。
当 all_random=False:意味着存在贪心请求(温度为 0 或接近 0)。直接除以接近 0 的温度会导致数值爆炸或 NaN。
torch.where:元素级条件------如果温度 < EPS,替换为 1.0;否则保持原值。这样贪心请求的 logits 不变(除以 1.0),随机请求正常缩放。
temp = torch.where(...) :创建新的 temp 张量,不影响原始 sampling_metadata.temperature。
当 all_random=True :所有请求都是随机的(温度 > 0),不需要除零保护,跳过 torch.where 操作节省计算。
python
return logits.div_(temp.unsqueeze(dim=1))
第 195 行 :执行温度缩放,原地除法。
temp.unsqueeze(dim=1) :将 [batch_size] 扩展为 [batch_size, 1],以便与 [batch_size, vocab_size] 的 logits 广播。
.div_():原地除法(注意尾随下划线)。修改 logits 张量本身,不创建新张量。
温度的数学含义 :logits / T 等效于 Boltzmann 分布的温度参数。高温(T > 1)使分布更均匀,低温(T < 1)使分布更尖锐,T → 0 趋近于 argmax。
3.10 greedy_sample 静态方法(第 196-197 行)
python
@staticmethod
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)
第 196-197 行:贪心采样实现。
logits.argmax(dim=-1) :沿最后一维(词表维度)取最大值的索引。返回 [batch_size] 的 int64 张量。
.view(-1) :确保输出为 1D。虽然 argmax(dim=-1) 已经返回 1D,但 view(-1) 是防御性编码,确保即使输入形状变化也不会出错。
数学含义:选择 logit 值最大的 token,等效于概率最大(未加温度时)。
3.11 sample 方法(第 199-244 行)
这是采样决策的核心方法,实现了 Step 7 的全部逻辑。
python
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
logprobs_mode_override: LogprobsMode | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
第 199-207 行:方法签名和文档。
参数 logits :[batch_size, vocab_size],经过 Step 2-6 处理后的 float32 logits。
参数 sampling_metadata:采样参数。
参数 logprobs_mode_override:运行时 logprobs 模式覆盖。
返回值 :tuple[Tensor, Tensor | None]。第一个是采样 token,第二个是处理后的 logprobs(可能为 None)。
文档注释:"may update the logits tensor in-place"------提醒调用方,此方法可能修改传入的 logits 张量。
python
logprobs_mode = logprobs_mode_override or self.logprobs_mode
第 208 行 :与 forward 中相同的模式------优先使用覆盖值。
python
assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
第 209 行:不变量断言------不可能同时 all_greedy 和 all_random。这是逻辑约束:如果所有请求都是贪心的,不可能同时所有请求都是随机的。
为什么用 assert 而非 if:这是编程错误(Scheduler 构建的 metadata 不应违反此约束),而非运行时异常。assert 在生产环境(Python -O 模式)下会被移除,不影响性能。
python
if sampling_metadata.all_random:
greedy_sampled = None
第 210-211 行:如果所有请求都是随机的,不需要贪心采样结果。
greedy_sampled = None :设为 None 而非跳过计算,是为了统一后续的分支逻辑------后续代码通过 greedy_sampled is None 判断是否存在贪心请求。
python
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
processed_logprobs = None
if sampling_metadata.max_num_logprobs is not None:
if logprobs_mode == "processed_logits":
processed_logprobs = logits
elif logprobs_mode == "processed_logprobs":
processed_logprobs = self.compute_logprobs(logits)
return greedy_sampled, processed_logprobs
第 212-220 行:贪心采样分支。
第 213 行 :执行贪心采样 argmax(dim=-1)。
第 214 行 :如果所有请求都是贪心的,直接返回,跳过后续的温度/Top-K/Top-P 等随机采样步骤。这是快速路径优化------全贪心 batch 不需要任何随机采样计算。
第 215-219 行:如果请求了 logprobs,根据模式计算:
"processed_logits":直接使用处理后的 logits(已经是 float32)"processed_logprobs":对处理后的 logits 执行log_softmax
注意:此时 logits 已经经过 Step 2-6 的所有处理(白名单、坏词、处理器、惩罚),所以是"processed"的。
processed_logprobs = logits(非 clone) :因为 logits 在后续不会被修改(直接 return),所以可以直接引用。但如果调用方后续修改 logits 会有问题------不过此方法已 return,不会发生。
python
assert sampling_metadata.temperature is not None
第 221 行:断言温度参数存在。如果不是全贪心,必须有温度参数。
python
# Apply temperature.
logits = self.apply_temperature(
logits, sampling_metadata.temperature, sampling_metadata.all_random
)
第 223-225 行 :Step 7b --- 温度缩放。
注意 logits 被重新绑定。apply_temperature 返回的是原地修改后的 logits(.div_()),所以实际上 logits 对象没变,只是内容被修改了。但写 logits = ... 使得代码意图更清晰。
python
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
第 227-229 行 :Step 7c --- 应用 argmax 不变处理器。
argmax_invariant :不影响 argmax 结果的处理器,仅在随机采样时有意义。典型代表是 MinPLogitsProcessor。
为什么在温度之后应用:argmax 不变处理器作用于温度缩放后的 logits。例如 min-p 需要基于 softmax 概率(即温度后的分布)来设定阈值,因此必须在温度缩放后。
循环遍历 :可能有多个 argmax 不变处理器,按顺序应用。每个处理器的 apply() 返回 logits(可能是原地修改)。
python
# Apply top_k and/or top_p.
random_sampled, processed_logprobs = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
)
第 231-235 行 :Step 7d-e --- Top-K/Top-P 过滤 + 随机采样。
self.topk_topp_sampler() :调用 TopKTopPSampler.__call__()(即 forward()),它会:
- 根据 top_k 和 top_p 参数选择后端(CUDA/FlashInfer/CPU/HIP)
- 执行 Top-K 过滤(保留概率最高的 K 个 token)
- 执行 Top-P 过滤(保留累积概率达到 P 的最小 token 集)
- 从过滤后的分布中随机采样
参数:
logits:温度缩放 + argmax 不变处理后的 logitsgenerators:每请求的随机数生成器(保证可复现性)top_k:Top-K 参数张量top_p:Top-P 参数张量
返回值:
random_sampled:[batch_size],随机采样的 token IDprocessed_logprobs:Tensor | None,TopKTopPSampler 可能计算的处理后 logprobs
python
if greedy_sampled is None:
return random_sampled, processed_logprobs
第 237-238 行 :如果不存在贪心请求(all_random=True),直接返回随机采样结果。
python
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
out=greedy_sampled, # Reuse tensor
)
return sampled, processed_logprobs
第 240-244 行 :Step 7f --- 混合贪心/随机结果。
torch.where(condition, x, y):元素级条件选择:
- 如果
temperature < EPS(贪心请求),使用greedy_sampled - 否则(随机请求),使用
random_sampled
out=greedy_sampled :复用 greedy_sampled 张量的内存,避免分配新张量。这是一个重要的内存优化------torch.where 的结果可以写入 out 指定的张量中。
为什么可以复用 :greedy_sampled 在 where 操作后就不再需要了,所以可以安全地覆盖。
注意 :此时 sampled 和 greedy_sampled 可能指向同一块内存(out 的效果),但这是有意为之。
返回值分析:
- 混合 batch:贪心请求得到 argmax 结果,随机请求得到随机采样结果
processed_logprobs:从topk_topp_sampler返回的,可能为 None
3.12 compute_logprobs 静态方法(第 246-247 行)
python
@staticmethod
def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
第 246-247 行:计算 log softmax。
log_softmax(dim=-1, dtype=torch.float32):沿最后一维(词表维度)计算 log softmax,强制使用 float32 精度。
dtype=torch.float32 :即使 logits 是 float16,计算也在 float32 下进行,保证数值稳定性。PyTorch 的 log_softmax 在内部会先转为 float32 计算,再转回输入类型。显式指定 dtype=torch.float32 确保输出也是 float32。
数学公式 :log_softmax(x_i) = x_i - log(sum(exp(x_j))),等价于 softmax 后取 log,但数值更稳定(直接计算 softmax 再 log 可能导致 log(0) = -inf 的问题)。
用途:
forward中计算raw_logprobs(Step 1a)sample中计算processed_logprobs(全贪心 +processed_logprobs模式)
3.13 gather_logprobs 静态方法(第 249-293 行)
python
@staticmethod
def gather_logprobs(
logprobs: torch.Tensor,
num_logprobs: int,
token_ids: torch.Tensor,
) -> LogprobsTensors:
第 249-253 行:方法签名。
参数 logprobs :[num_tokens, vocab_size],log softmax 后的 logprobs。
参数 num_logprobs:要收集的 Top-K 数量。
参数 token_ids :[num_tokens],采样 token 或 prompt token 的 ID(int64)。
python
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: maximum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
第 254-273 行:详细文档。注意:
- 输出形状
(num tokens) x (num_logprobs + 1)------+1 是因为采样/prompt token 的 logprob 被拼接到 Top-K 之前 token_ids必须是 int64
python
assert token_ids.dtype == torch.int64
第 274 行 :断言 token_ids 为 int64。这是因为 logprobs.gather(-1, index) 要求 index 为 int64,否则 PyTorch 会抛出错误。
python
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
第 276-277 行:执行 Top-K 选择。
torch.topk(input, k, dim):沿指定维度取前 K 大的值和索引。
返回值:
topk_logprobs:[num_tokens, num_logprobs],Top-K 的 logprob 值topk_indices:[num_tokens, num_logprobs],Top-K 对应的 token ID
注意 :topk 返回的是降序排列的------第一列是最大值,最后一列是第 K 大值。
python
# Get with the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
第 279-281 行:获取采样/prompt token 的 logprob。
token_ids.unsqueeze(-1) :从 [num_tokens] 变为 [num_tokens, 1],为 gather 操作准备索引形状。
logprobs.gather(-1, token_ids) :沿最后一维按索引取值。返回 [num_tokens, 1],每个位置是对应 token 的 logprob。
为什么不在 topk 中直接查找:采样 token 可能不在 Top-K 中(如果它的 logprob 较低),因此需要单独 gather。
python
# Compute the ranks of the actual token.
# Avoid 0/1 specialization recompile on the batch dimension
# of the compiled batched_count_greater_than. mark_unbacked makes
# the size fully symbolic so dynamo doesn't specialize when
# batch_size transitions from 1 to >=2.
torch._dynamo.decorators.mark_unbacked(logprobs, 0)
torch._dynamo.decorators.mark_unbacked(token_logprobs, 0)
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
第 283-288 行:计算采样 token 的排名。
mark_unbacked :这是 PyTorch Dynamo(torch.compile 的前端)的内部 API。mark_unbacked(tensor, dim) 将指定维度标记为"非编译时已知",使 Dynamo 不会在 batch_size 从 1 变为 2 时触发重新编译。
为什么需要 :batched_count_greater_than 内部使用 (x >= values).sum(-1)。Dynamo 在编译时会尝试特化(specialize)张量的形状。当 batch_size=1 时编译的 kernel 在 batch_size=2 时不能直接使用。mark_unbacked 使 batch 维度保持符号化,避免重复编译。
batched_count_greater_than(logprobs, token_logprobs) :计算每行中有多少 logprob 值 >= 采样 token 的 logprob。由于 logprobs 是降序排列的,这个计数就是采样 token 的排名(0-indexed:最大 token 的排名为 1,因为自身也满足 >= 条件...等等,实际上需要仔细看)。
排名的精确定义 :(x >= values).sum(-1) 其中 values 是采样 token 的 logprob。这计算的是有多少 token 的 logprob >= 采样 token 的 logprob。如果采样 token 的 logprob 是最大的,只有它自己满足条件,排名为 1。如果它是第 3 大的,排名为 3。
注意 :这是 >= 而非 >,所以如果有多个 token 的 logprob 与采样 token 相同,它们都会被计入。
python
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
第 290-292 行:拼接采样 token 和 Top-K。
拼接顺序:采样/prompt token 在前(第 0 列),Top-K 在后(第 1 到 num_logprobs 列)。
形状 :[num_tokens, num_logprobs + 1]。
可能的重复 :如果采样 token 在 Top-K 中,它会出现两次(第 0 列和 Top-K 中的某个位置)。这由下游的 LogprobsProcessor 去重处理。
python
# Use int32 to reduce the tensor size.
indices = indices.to(torch.int32)
return LogprobsTensors(indices, logprobs, token_ranks)
第 294-297 行:类型转换并返回。
indices.to(torch.int32):将 token ID 从 int64 转为 int32,节省一半存储。
LogprobsTensors(indices, logprobs, token_ranks):
indices:[num_tokens, num_logprobs + 1],int32logprobs:[num_tokens, num_logprobs + 1],float32token_ranks:[num_tokens],int64(由batched_count_greater_than返回)
3.14 _combine_outputs_with_spec_tokens 静态方法(第 295-303 行)
python
@staticmethod
def _combine_outputs_with_spec_tokens(
output_token_ids: list[list[int]],
spec_token_ids: list[list[int]] | None = None,
) -> list[list[int]]:
if spec_token_ids is None:
return output_token_ids
return [
[*out, *spec] if spec else out
for out, spec in zip(output_token_ids, spec_token_ids)
]
第 295-303 行:将投机解码的 draft token 合并到已生成 token 列表中。
前导下划线:表示内部方法,不对外暴露。
参数 output_token_ids :list[list[int]],每个请求已生成的 token ID 列表。
参数 spec_token_ids :list[list[int]] | None,每个请求的投机(draft)token ID 列表。
spec_token_ids is None:如果没有投机 token,直接返回原始列表。
列表推导式 :对每个请求,如果存在投机 token,将其追加到输出 token 后面。[*out, *spec] 是 Python 的解包合并语法。
条件 if spec else out:如果某个请求没有投机 token(空列表),不追加。这处理了异构 batch 的情况------部分请求可能有投机 token,部分没有。
用途:在惩罚计算中,已生成 token 集合应包含投机 token。因为如果模型已经(通过投机解码)生成了某些 token,这些 token 的重复/频率/出现惩罚应该被计算。
3.15 apply_logits_processors 方法(第 305-333 行)
python
def apply_logits_processors(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
predict_bonus_token: bool,
) -> torch.Tensor:
第 305-309 行:方法签名。此方法编排 Steps 3-6 的 logits 预处理管线。
python
bad_words_token_ids = sampling_metadata.bad_words_token_ids
any_penalties_or_bad_words = (
bool(bad_words_token_ids) or not sampling_metadata.no_penalties
)
第 310-313 行:预计算条件标志。
bad_words_token_ids:提取到局部变量,减少属性访问。
any_penalties_or_bad_words:是否存在任何惩罚或坏词。用于决定是否需要合并投机 token。
bool(bad_words_token_ids):空字典为 False,非空为 True。
not sampling_metadata.no_penalties :no_penalties=True 表示没有惩罚,取反表示有惩罚。
python
output_token_ids = sampling_metadata.output_token_ids
if predict_bonus_token and any_penalties_or_bad_words:
# Combine base outputs with spec tokens when speculative decoding
# is enabled.
output_token_ids = self._combine_outputs_with_spec_tokens(
output_token_ids,
sampling_metadata.spec_token_ids,
)
第 315-320 行:投机 token 合并。
条件 predict_bonus_token and any_penalties_or_bad_words:仅在投机解码模式下且存在惩罚/坏词时才合并。如果没有任何惩罚或坏词,投机 token 不影响 logits 处理,无需合并。
优化意义:避免不必要的列表合并操作。
python
# Apply allowed token ids.
if sampling_metadata.allowed_token_ids_mask is not None:
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
第 322-324 行 :Step 3 --- 白名单掩码。
allowed_token_ids_mask :[max_batch_size, vocab_size] 的布尔张量。注意语义------True 表示不允许的 token(需要被掩码的位置),这是反向的(通常 True 表示有效,但这里是 True 表示无效)。这可能与构建掩码时的效率考量有关。
masked_fill_(mask, value) :原地操作,将 mask 为 True 的位置填充为 -inf。
数学效果 :-inf 经过 softmax 后概率为 0,等效于完全排除这些 token。
为什么在坏词之前:白名单是更严格的约束------如果 token 不在白名单中,无论是否是坏词都应该排除。
python
# Apply bad words exclusion.
if bad_words_token_ids:
apply_bad_words(logits, bad_words_token_ids, output_token_ids)
第 326-328 行 :Step 4 --- 坏词排除。
apply_bad_words :定义在 ops/bad_words.py 中,通过前缀匹配实现坏词排除。对于每个坏词,检查已生成 token 序列的末尾是否匹配坏词的前缀部分,如果匹配则将坏词最后一个 token 的 logit 设为 -inf。
output_token_ids:已生成的 token 列表,用于前缀匹配。注意这可能已经合并了投机 token。
python
# Apply logits processors which can impact greedy sampling.
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
第 330-332 行 :Step 5 --- 非 argmax 不变处理器。
non_argmax_invariant:会影响 argmax 结果的处理器,必须在贪心采样之前应用。包括:
MinTokensLogitsProcessor:强制生成至少 N 个 token(将 EOS logit 设为-inf)LogitBiasLogitsProcessor:对特定 token 加偏置(可能改变 argmax)
循环应用:每个处理器的输出作为下一个的输入,形成链式处理。
python
# Apply penalties (e.g., freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
return logits
第 334-336 行 :Step 6 --- 惩罚施加。
委托给 apply_penalties 静态方法,该方法进一步委托给 ops/penalties.apply_all_penalties。
返回 logits:经过 Steps 3-6 完整处理后的 logits。
3.16 apply_penalties 静态方法(第 338-349 行)
python
@staticmethod
def apply_penalties(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
output_token_ids: list[list[int]],
) -> torch.Tensor:
第 338-342 行:方法签名。
参数 output_token_ids :独立传递而非从 sampling_metadata 获取,因为在投机解码场景下可能已经被合并了投机 token。
python
if sampling_metadata.no_penalties:
return logits
第 343-344 行:快速路径------如果没有惩罚,直接返回 logits,跳过所有惩罚计算。
no_penalties :这是一个预计算标志,由 Scheduler 在构建 SamplingMetadata 时设置。如果所有请求的三种惩罚值都为默认值(0),则 no_penalties=True。这是一个重要的性能优化------大多数推理请求不使用惩罚,跳过可以避免不必要的张量操作。
python
assert sampling_metadata.prompt_token_ids is not None
第 345 行:断言 prompt_token_ids 存在。惩罚计算需要区分 prompt token 和生成 token:
- repetition_penalty:对所有已出现的 token(prompt + output)统一缩放
- frequency_penalty:仅对 output token 的出现次数线性递减
- presence_penalty:仅对 output token 是否出现做常数递减
因此 prompt_token_ids 是必须的。
python
return apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
output_token_ids,
)
第 346-352 行 :委托给 ops/penalties.apply_all_penalties。
传入参数:
logits:[batch_size, vocab_size]prompt_token_ids:[batch_size, max_prompt_len],prompt tokenpresence_penalties:[batch_size],出现惩罚系数frequency_penalties:[batch_size],频率惩罚系数repetition_penalties:[batch_size],重复惩罚系数output_token_ids:list[list[int]],已生成 token
apply_all_penalties 内部逻辑 (在 ops/penalties.py 中):
- 将
output_token_ids转为填充张量(make_tensor_with_pad) - 将 -1 占位符替换为
vocab_size(无效 token ID) - 调用
vllm.model_executor.layers.utils.apply_penalties执行三种惩罚
3.17 完整 9 步管线回顾
将 forward 方法的流程与类文档中描述的 9 步对应:
| 步骤 | 对应代码 | 操作 | 输入→输出 |
|---|---|---|---|
| 1 | forward 第73-79行 |
计算原始 logprobs | logits → raw_logprobs |
| 2 | forward 第81行 |
logits → float32 | logits → float32 logits |
| 3 | apply_logits_processors 第322-324行 |
白名单掩码 | logits → filtered logits |
| 4 | apply_logits_processors 第326-328行 |
坏词排除 | logits → clean logits |
| 5 | apply_logits_processors 第330-332行 |
非argmax不变处理器 | logits → processed logits |
| 6 | apply_logits_processors 第334行 → apply_penalties |
惩罚 | logits → penalized logits |
| 7a | sample 第210-220行 |
贪心采样/快速返回 | logits → greedy_sampled |
| 7b | sample 第223-225行 |
温度缩放 | logits → scaled logits |
| 7c | sample 第227-229行 |
argmax不变处理器 | logits → filtered logits |
| 7d-e | sample 第231-235行 |
Top-K/Top-P + 随机采样 | logits → random_sampled |
| 7f | sample 第240-244行 |
混合贪心/随机 | → sampled |
| 8 | forward 第109-112行 |
收集 Top-K logprobs | raw_logprobs → logprobs_tensors |
| 9 | forward 第120-127行 |
构建 SamplerOutput | → SamplerOutput |
3.18 关键设计决策总结
-
原始 vs 处理后 logprobs:V1 选择返回基于原始 logits 的 logprobs(V0 返回处理后的),这更符合"模型真实置信度"的语义
-
贪心/随机混合批处理 :通过
torch.where+out=greedy_sampled实现零额外内存的混合采样 -
两阶段处理器分类:argmax 不变 vs 非不变,确保贪心采样不受 min-p 等处理器影响
-
投机 token 合并的惰性策略:仅在存在惩罚/坏词时才合并投机 token,避免不必要的列表操作
-
int32 ↔ int64 的精确转换时机:仅在需要作为索引时使用 int64,存储和传输时使用 int32
-
原地操作优先 :
div_()、masked_fill_()、out=greedy_sampled等减少内存分配 -
Triton fused kernel :
gather_specific_token_logprobs使用融合的 log_softmax+gather,比分离操作快 1.4 倍 -
Dynamo 编译优化 :
mark_unbacked避免 batch_size 变化时的重复编译
本部分完成。第二部分将覆盖 metadata.py 逐行解析、ops/ 算子层深度分析、以及 logits_processor/ 处理器层详解。