【vllm】(六)vLLM v1 Sample — 模块超深度分析之一

vLLM v1 Sample 采样模块超深度分析

分析对象:vllm/vllm/v1/sample

代码规模:14 Python 文件(含2个__init__.py),4,133 行有效代码

分析日期:2026-04-22


vLLM v1 Sample 采样模块超深度逐行分析 --- 第一部分:核心层


一、模块定位

1.1 业务职责

vLLM v1 的 sample 模块是整个推理引擎中从模型输出到最终采样决策的关键桥梁。它的核心业务职责可以用一句话概括:

将模型输出的 logits 张量,经过一系列确定性与随机性变换,转换为每个请求的下一个生成 token。

这看似简单的"从连续分布到离散决策",实际包含了 vLLM 作为生产级推理框架所必须支持的全部采样控制能力------温度调节、Top-K/Top-P 截断、惩罚机制(重复/频率/出现)、坏词排除、白名单约束、自定义 logits 处理器、最小 token 数约束、min-p 过滤、思考 token 预算控制等。

1.2 功能定位

从功能维度看,sample 模块承担以下六项核心功能:

功能 说明
Logits 预处理 精度提升(→float32)、白名单掩码、坏词排除、非 argmax 不变处理器
惩罚施加 重复惩罚、频率惩罚、出现惩罚------三种基于已生成 token 的调节机制
采样决策 贪心采样 vs 随机采样的分支路由,温度缩放,Top-K/Top-P 联合过滤
Logprobs 计算 原始 logprobs(用于 top-k 展示)与处理后 logprobs(用于采样后输出)的分离计算
拒绝采样 面向投机解码场景的 draft token 验证与拒绝采样
批处理编排 在单次 forward pass 中同时处理混合的贪心/随机请求,高效利用 GPU 并行性

1.3 边界与范围

属于 sample 模块的:

  • 从原始 logits 到 sampled token 的全部变换链路
  • 采样参数的元数据封装(SamplingMetadata
  • Logits 处理器的抽象接口与内置实现
  • Top-K/Top-P 的 Triton 高性能内核
  • 投机解码的拒绝采样逻辑

不属于 sample 模块的:

  • logits 的产生(属于模型前向传播 model_runner
  • 采样参数的解析与构建(属于 sampling_paramsscheduler
  • 采样结果的后处理与文本解码(属于 detokenizer
  • KV cache 管理(属于 kv_cache_manager

1.4 在 vLLM 架构中的位置

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                    vLLM v1 推理引擎                              │
│                                                                 │
│  ┌──────────┐    ┌──────────────┐    ┌───────────┐              │
│  │Scheduler │───>│ Model Runner │───>│  Sample   │              │
│  │          │    │  (forward)   │    │  Module   │              │
│  │ 构建请求  │    │  产生 logits │    │ logits→token│             │
│  └──────────┘    └──────────────┘    └─────┬─────┘              │
│       │                                     │                    │
│       │         ┌──────────────┐            │                    │
│       └────────>│  Sampling    │<───────────┘                    │
│                 │  Metadata    │  sampled_token_ids              │
│                 │  参数打包     │  + logprobs_tensors             │
│                 └──────────────┘            │                    │
│                                             v                    │
│                                      ┌─────────────┐            │
│                                      │  Scheduler   │            │
│                                      │  结果收集     │            │
│                                      └─────────────┘            │
└─────────────────────────────────────────────────────────────────┘

sample 模块位于 Model Runner 的下游,Scheduler 的结果回收环节上游。它是纯计算模块------接收 logits + metadata,输出 sampled tokens + logprobs,不持有任何跨步持久状态(除 logits 处理器外)。

1.5 核心商业价值

  1. 推理质量保障:9 步采样管线确保模型输出的可控性,从温度到惩罚的每一步都有精确的数学保证
  2. 性能极致优化:Triton 内核 + FlashInfer 集成 + 批量混合采样,单次 forward 完成异构请求处理
  3. 扩展性设计:LogitsProcessor 插件体系支持自定义采样策略,通过 entry_points 机制实现零侵入扩展
  4. 投机解码支持:RejectionSampler 使 vLLM 能在投机解码场景下验证 draft token,大幅提升吞吐

二、模块整体结构

2.1 文件职责矩阵

sample 模块共包含 14 个 Python 文件 (含 2 个 __init__.py),总计 4133 行代码,分布在 3 个层级目录中:

# 文件路径 行数 架构层级 核心职责
1 sample/__init__.py 0 包初始化 空文件,模块包标记
2 sample/sampler.py 410 核心层 Sampler 主类,采样管线编排,9步采样流程入口
3 sample/metadata.py 49 核心层 SamplingMetadata 数据类,采样参数的批量化封装
4 sample/rejection_sampler.py 850 投机解码层 RejectionSampler,draft token 验证与拒绝采样
5 sample/ops/__init__.py 0 算子层 空文件,ops 子包标记
6 sample/ops/bad_words.py 57 算子层 坏词排除算子,前缀匹配 + logits 掩码
7 sample/ops/logprobs.py 27 算子层 logprobs 工具函数,批量计数排名
8 sample/ops/penalties.py 57 算子层 惩罚算子,repetition/frequency/presence 三合一
9 sample/ops/topk_topp_sampler.py 402 算子层 TopK/TopP 采样器,多后端(CUDA/CPU/HIP/FlashInfer)
10 sample/ops/topk_topp_triton.py 1057 算子层 TopK/TopP Triton 内核,GPU 高性能实现
11 sample/logits_processor/__init__.py 360 处理器层 LogitsProcessor 加载/构建/插件机制
12 sample/logits_processor/interface.py 106 处理器层 LogitsProcessor 抽象接口与 BatchUpdate 定义
13 sample/logits_processor/state.py 165 处理器层 BatchUpdateBuilder + LogitsProcessors 容器
14 sample/logits_processor/builtin.py 593 处理器层 4 个内置处理器(MinP/LogitBias/MinTokens/ThinkingTokenBudget)

2.2 三层架构

复制代码
┌─────────────────────────────────────────────────────────┐
│                    核心层 (Core)                          │
│  ┌──────────────┐        ┌───────────────────┐          │
│  │  sampler.py  │───────>│   metadata.py     │          │
│  │  采样管线编排  │ <────── │  参数元数据封装    │          │
│  │  410 行       │        │  49 行             │          │
│  └──────┬───────┘        └───────────────────┘          │
│         │ 调用                                           │
└─────────┼───────────────────────────────────────────────┘
          │
          v
┌─────────────────────────────────────────────────────────┐
│                  算子层 (Operators)                        │
│  ┌──────────────┐ ┌────────────┐ ┌──────────────────┐  │
│  │ penalties.py │ │ bad_words  │ │ topk_topp        │  │
│  │ 惩罚算子      │ │ 坏词排除    │ │ sampler.py       │  │
│  │ 57 行        │ │ 57 行      │ │ 402 行           │  │
│  └──────────────┘ └────────────┘ └──────────────────┘  │
│  ┌──────────────┐ ┌────────────────────────────────┐   │
│  │ logprobs.py  │ │ topk_topp_triton.py            │   │
│  │ 排名计数      │ │ Triton GPU 内核                │   │
│  │ 27 行        │ │ 1057 行                        │   │
│  └──────────────┘ └────────────────────────────────┘   │
└─────────────────────────────────────────────────────────┘
          ▲
          │ 被调用
┌─────────────────────────────────────────────────────────┐
│                处理器层 (Processors)                       │
│  ┌──────────────────┐ ┌──────────────────────────────┐  │
│  │ interface.py     │ │ state.py                     │  │
│  │ 抽象接口+BatchUpd│ │ BatchUpdateBuilder+LogitsProc │  │
│  │ 106 行           │ │ 165 行                       │  │
│  └──────────────────┘ └──────────────────────────────┘  │
│  ┌──────────────────┐ ┌──────────────────────────────┐  │
│  │ builtin.py       │ │ __init__.py                  │  │
│  │ 4个内置处理器      │ │ 加载/构建/插件机制            │  │
│  │ 593 行           │ │ 360 行                       │  │
│  └──────────────────┘ └──────────────────────────────┘  │
└─────────────────────────────────────────────────────────┘
          ‖ 并行
┌─────────────────────────────────────────────────────────┐
│               投机解码层 (Speculative)                    │
│  ┌───────────────────────────────────────────────────┐  │
│  │ rejection_sampler.py                              │  │
│  │ RejectionSampler + 拒绝采样内核                    │  │
│  │ 850 行                                            │  │
│  └───────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────┘

2.3 类层次结构

复制代码
nn.Module
├── Sampler                          [核心] 采样管线主编排器
│   └── 内含 TopKTopPSampler         [算子] Top-K/Top-P 采样器
│
└── RejectionSampler                 [投机解码] draft token 验证
    └── 内含 Sampler                  复用主采样器

LogitsProcessor (ABC)                [处理器] 抽象基类
├── MinPLogitsProcessor              [内置] min-p 过滤(argmax不变)
├── LogitBiasLogitsProcessor         [内置] logit 偏置(argmax不变)
│   └── 继承 AdapterLogitsProcessor  每请求适配器模式
├── MinTokensLogitsProcessor         [内置] 最小token数约束(非argmax不变)
│   └── 继承 AdapterLogitsProcessor
├── ThinkingTokenBudgetLogitsProcessor [内置] 思考token预算(非argmax不变)
│   └── 继承 AdapterLogitsProcessor
└── [用户自定义] 通过 entry_points 插件加载

LogitsProcessors                     [容器] 聚合所有处理器实例
├── .argmax_invariant: list          argmax不变处理器列表
└── .non_argmax_invariant: list      非argmax不变处理器列表

SamplingMetadata (dataclass)         [元数据] 采样参数批量化封装

BatchUpdateBuilder                   [状态] 批更新构建器
BatchUpdate (frozen dataclass)       [状态] 批更新数据

2.4 依赖注入关系

sample 模块的依赖注入遵循"构造时注入,运行时消费"的模式:

复制代码
Sampler.__init__(logprobs_mode)
  ├── 创建 TopKTopPSampler(logprobs_mode)     ← 算子层注入
  └── 查询 is_pin_memory_available()           ← 平台能力注入

Sampler.forward(logits, sampling_metadata)
  ├── sampling_metadata                        ← 由 Scheduler 构建,运行时注入
  │   ├── .logitsprocs: LogitsProcessors       ← 处理器集合,由 build_logitsprocs() 构建
  │   │   ├── .argmax_invariant[]              ← 运行时分离的两类处理器
  │   │   └── .non_argmax_invariant[]
  │   ├── .temperature, .top_k, .top_p         ← 采样参数张量
  │   ├── .generators                          ← 随机数生成器
  │   ├── .penalties_*                         ← 惩罚参数
  │   └── .bad_words_token_ids                 ← 坏词列表
  └── logits                                   ← 由 ModelRunner 产生,运行时注入

build_logitsprocs(vllm_config, device, ...)
  ├── BUILTIN_LOGITS_PROCESSORS                ← 硬编码内置列表
  ├── _load_logitsprocs_plugins()              ← entry_points 插件发现
  └── _load_logitsprocs_by_fqcns()             ← FQCN 字符串加载

2.5 核心方法清单

Sampler 类方法
方法 类型 行数范围 核心职责
__init__ 构造 55-59 初始化 TopKTopPSampler、pin_memory、logprobs_mode
forward 主入口 61-122 9步采样管线编排,logits→SamplerOutput
gather_specific_token_logprobs 辅助 124-185 指定 token ID 的 logprobs 高效计算
apply_temperature 静态 187-194 温度缩放,in-place 除法
greedy_sample 静态 196-197 argmax 贪心采样
sample 核心 199-244 采样决策:贪心/随机分支路由
compute_logprobs 静态 246-247 log_softmax 计算
gather_logprobs 静态 249-293 Top-K logprobs 收集与排名计算
_combine_outputs_with_spec_tokens 静态 295-303 投机解码输出合并
apply_logits_processors 核心 305-333 Logits 预处理管线(白名单→坏词→处理器→惩罚)
apply_penalties 静态 335-349 惩罚施加委托
SamplingMetadata 字段
字段 类型 说明
temperature `Tensor None`
all_greedy bool 是否全部贪心
all_random bool 是否全部随机
top_p `Tensor None`
top_k `Tensor None`
generators dict[int, Generator] 每请求随机数生成器
max_num_logprobs `int None`
no_penalties bool 是否无惩罚
prompt_token_ids `Tensor None`
frequency_penalties Tensor 频率惩罚
presence_penalties Tensor 出现惩罚
repetition_penalties Tensor 重复惩罚
output_token_ids list[list[int]] 已生成 token
allowed_token_ids_mask `Tensor None`
bad_words_token_ids dict[int, list[list[int]]] 坏词列表
logitsprocs LogitsProcessors Logits 处理器集合
logprob_token_ids `dict[int, list[int]] None`
spec_token_ids `list[list[int]] None`

2.6 数据流出入

输入
复制代码
logits: Tensor[batch_size, vocab_size]        ← 模型输出,可能是 float16/bfloat16
sampling_metadata: SamplingMetadata            ← Scheduler 构建的参数包
predict_bonus_token: bool = False              ← 投机解码标记
logprobs_mode_override: LogprobsMode | None    ← 运行时模式覆盖
输出
复制代码
SamplerOutput:
  ├── sampled_token_ids: Tensor[batch_size, 1]   ← int32,每请求一个采样 token
  └── logprobs_tensors: LogprobsTensors | None    ← 可选的 logprobs 信息
        ├── logprob_token_ids: Tensor             ← int32 token ID 索引
        ├── logprobs: Tensor                      ← float32 logprobs 值
        └── selected_token_ranks: Tensor          ← 采样 token 排名
内部数据流(9步管线)
复制代码
logits ──┐
         │ Step 1: 计算原始 logprobs (raw_logprobs/raw_logits)
         v
    raw_logprobs ──┐
                   │ Step 2: logits → float32
                   v
              float32_logits ──┐
                               │ Step 3: 白名单掩码
                               v
                        filtered_logits ──┐
                                          │ Step 4: 坏词排除
                                          v
                                   clean_logits ──┐
                                                   │ Step 5: 非argmax不变处理器
                                                   v
                                            proc_logits ──┐
                                                           │ Step 6: 惩罚
                                                           v
                                                    penalized_logits ──┐
                                                                       │ Step 7: 采样决策
                                                                       v
                                                    sampled + processed_logprobs ──┐
                                                                                    │ Step 8: gather logprobs
                                                                                    v
                                                                             logprobs_tensors ──┐
                                                                                                │ Step 9: 构建输出
                                                                                                v
                                                                                          SamplerOutput

三、sampler.py 逐行深度解析

本节对 sampler.py(410 行)进行真正的逐行分析。每一行代码、每一个分支、每一个变量都会被解释。


3.1 文件头与许可证(第 1-3 行)

python 复制代码
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

第 1 行:SPDX 许可证标识符,声明本文件采用 Apache 2.0 许可证。SPDX(Software Package Data Exchange)是标准化的许可证标识方式,被 Linux 基金会推荐使用。

第 2 行:SPDX 版权声明,版权归 vLLM 项目的所有贡献者所有。注意这里使用"contributors to the vLLM project"而非单一实体,体现了开源社区项目性质。


3.2 模块文档字符串(第 3 行)

python 复制代码
"""A layer that samples the next tokens from the model's outputs."""

模块级文档字符串,一句话定义了本模块的核心功能:从模型输出中采样下一个 token 的层 。注意关键词"a layer"------这里把 Sampler 设计为一个 nn.Module 层,使其能被嵌入 PyTorch 的模型计算图中。


3.3 导入区(第 5-14 行)

python 复制代码
import torch
import torch.nn as nn

第 5 行import torch --- 导入 PyTorch 核心库。用于张量操作、设备管理、自动微分等。在 sampler 中主要使用:张量类型转换(.to().long())、张量操作(.argmax().topk().where().gather().unsqueeze())、张量掩码(.masked_fill_())、张量拼接(.cat())。

第 6 行import torch.nn as nn --- 导入 PyTorch 神经网络模块。Sampler 继承自 nn.Module,因此需要此导入。nn.Module 赋予 Sampler 以下能力:

  • 被纳入模型的 .modules()
  • 参数/缓冲区管理(虽然 Sampler 没有可训练参数)
  • 设备转移(.to(device)
  • torch.compile 兼容
python 复制代码
from vllm.config.model import LogprobsMode

第 8 行 :从 vllm.config.model 导入 LogprobsMode。这是一个类型别名(Literal 类型),定义了 logprobs 的计算模式。可能的取值包括:

  • "raw_logprobs":在 logits 经任何处理之前,先用 log_softmax 计算原始 logprobs
  • "raw_logits":直接克隆原始 logits(转换为 float32)作为 logprobs
  • "processed_logits":使用经过处理(惩罚、温度等)后的 logits
  • "processed_logprobs":使用经过处理后 logits 的 log_softmax

这个模式控制的是 Step 1 中用什么作为最终返回的 logprobs 基底。

python 复制代码
from vllm.utils.platform_utils import is_pin_memory_available

第 9 行 :从 vllm.utils.platform_utils 导入 is_pin_memory_available。此函数检测当前平台是否支持 CUDA pinned memory(锁页内存)。Pinned memory 允许 CPU→GPU 的 DMA 传输,提高数据搬运速度。在 penalty 计算中,output_token_ids 从 CPU 张量传输到 GPU 时会用到 pin_memory。

python 复制代码
from vllm.v1.outputs import LogprobsTensors, SamplerOutput

第 10 行 :从 vllm.v1.outputs 导入两个输出数据结构:

  • LogprobsTensors:封装 logprobs 相关的张量三元组(token_ids、logprobs、ranks)
  • SamplerOutput:Sampler.forward() 的最终返回类型,包含 sampled_token_idslogprobs_tensors
python 复制代码
from vllm.v1.sample.metadata import SamplingMetadata

第 11 行 :从本模块的 metadata.py 导入 SamplingMetadata。这是采样参数的批量化数据类,将每个请求的采样参数(temperature、top_k、top_p、penalties 等)打包为张量,支持批量并行处理。详见第 2.5 节。

python 复制代码
from vllm.v1.sample.ops.bad_words import apply_bad_words

第 12 行 :从 ops/bad_words.py 导入 apply_bad_words 函数。此函数在 logits 上执行坏词排除------检查已生成的前缀是否匹配坏词的前缀部分,如果匹配则将坏词最后一个 token 的 logit 设为 -inf。这是 Step 4 的实现。

python 复制代码
from vllm.v1.sample.ops.logprobs import batched_count_greater_than

第 13 行 :从 ops/logprobs.py 导入 batched_count_greater_than。此函数计算每个 batch 行中有多少元素的值大于等于给定阈值,用于计算采样 token 的排名。在 gather_logprobs 方法中使用。该函数使用 torch.compile 优化,避免内存拷贝问题。

python 复制代码
from vllm.v1.sample.ops.penalties import apply_all_penalties

第 14 行 (此处指导入区的最后一行):从 ops/penalties.py 导入 apply_all_penalties。此函数一次性应用三种惩罚(repetition、frequency、presence),委托给 vllm.model_executor.layers.utils.apply_penalties 实现。这是 Step 6 的算子实现。

python 复制代码
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

紧随其后的导入------从 ops/topk_topp_sampler.py 导入 TopKTopPSampler 类。这是 Top-K 和 Top-P 联合采样的核心算子,支持多后端(原生 PyTorch、CUDA、CPU、HIP/AMD、FlashInfer)。在 Sampler 构造时创建实例,作为 Step 7d 的执行者。

python 复制代码
from vllm.v1.worker.gpu.sample.logprob import compute_token_logprobs

v1/worker/gpu/sample/logprob.py 导入 compute_token_logprobs。这是一个 Triton JIT 编译的 fused kernel,将 log_softmax + gather 融合为一次 GPU 计算。在 gather_specific_token_logprobs 方法中使用,用于高效计算指定 token 的 logprobs(避免计算整个词表的 log_softmax)。


3.4 模块常量(第 16 行)

python 复制代码
_SAMPLING_EPS = 1e-5

第 16 行 :定义采样精度阈值常量 _SAMPLING_EPS = 1e-5(0.00001)。

用途:在两处使用------

  1. apply_temperature:当温度低于此阈值时,视为贪心采样(温度 → 1.0,不缩放)
  2. sample:在混合贪心/随机采样时,用 temperature < _SAMPLING_EPS 判断某个请求是否应走贪心路径

设计考量 :不使用 0 而使用极小正值的原因------浮点数精度问题。temperature 可能是 1e-6 这样接近零但非零的值,数学上等价于贪心采样,但直接除以它会引入数值不稳定。将 < EPS 的温度替换为 1.0,使得 logits / 1.0 = logits(不变),等效于贪心。

命名约定:前导下划线表示模块私有常量。


3.5 Sampler 类定义与文档字符串(第 18-53 行)

python 复制代码
class Sampler(nn.Module):

第 18 行 :Sampler 继承自 nn.Module。这是一个关键设计决策------将采样器设计为 PyTorch 层而非普通类,原因是:

  1. 模型图集成 :作为 nn.Module,Sampler 可以被纳入模型的 forward 计算流,与模型的其他层统一管理
  2. 设备一致性nn.Module.to(device) 会自动递归应用到子模块(如 TopKTopPSampler
  3. 编译兼容torch.compile 能正确处理 nn.Module 的调用图
  4. 序列化支持 :虽然 Sampler 没有可训练参数,但 nn.Module 的 state_dict 机制确保了检查点兼容性
python 复制代码
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

第 19-53 行:类文档字符串,详细描述了 9 步采样管线。这是理解 Sampler 工作原理的最重要的文档。逐步解析:

Step 1 :如果请求了 logprobs(max_num_logprobs is not None),根据 logprobs_mode 决定如何准备原始 logprobs:

  • raw_logprobs:对原始 logits 执行 log_softmax,得到数学上精确的 log 概率
  • raw_logits:直接克隆原始 logits(转为 float32),用于需要原始 logit 值的场景

关键注释 :这里计算的 logprobs 是基于未经惩罚和温度缩放的原始 logits,这与 V0 sampler 的行为不同(V0 使用处理后的 logits)。这是一个有意的 API 变更------原始 logprobs 更能反映模型的"真实信念"。

Step 2 :将 logits 转换为 float32 精度。这是数值稳定性的关键------log_softmax、温度除法、Top-K/Top-P 比较等操作在 float16/bfloat16 下可能溢出或丢失精度。

Step 3 :应用 token 白名单掩码。将不在白名单中的 token logits 设为 -inf,等效于概率 0。通过 masked_fill_ 原地操作。

Step 4 :应用坏词排除。基于前缀匹配检查已生成序列,如果坏词前缀已出现,将坏词最后 token 的 logit 设为 -inf

Step 5 :应用非 argmax 不变的 logits 处理器。这些处理器会改变 argmax 结果(即贪心采样结果),因此必须在贪心采样之前应用:

  • MinTokens 处理器:当已生成 token 数 < 最小要求时,将 EOS token 的 logit 设为 -inf
  • LogitBias 处理器:对特定 token 加偏置值

Step 6:应用三种惩罚------重复惩罚(按比例缩放已出现 token 的 logits)、频率惩罚(线性递减)、出现惩罚(常数递减)。

Step 7:采样决策,内部又包含 6 个子步骤:

  • (a) 贪心采样分支:如果不是全部随机,先做 argmax;如果全部贪心则直接返回
  • (b) 温度缩放:logits /= temperature
  • © argmax 不变处理器:如 min_p 过滤(不影响贪心结果但影响随机采样分布)
  • (d) Top-K/Top-P 过滤
  • (e) 从概率分布中随机采样
  • (f) 根据温度决定最终使用贪心还是随机结果

Step 8 :收集 Top-K logprobs 和采样 token 的 logprobs。注意最终可能返回 max_num_logprobs + 1 个 logprobs(采样 token 可能不在 Top-K 中)。

Step 9 :构建并返回 SamplerOutput


3.6 __init__ 方法(第 55-59 行)

python 复制代码
    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):

第 55 行 :构造函数签名。接受一个参数 logprobs_mode,默认值为 "raw_logprobs"

参数 logprobs_mode:LogprobsMode 类型(Literal 字符串类型),控制 logprobs 的计算方式:

  • "raw_logprobs":使用 log_softmax(logits) 计算,这是数学上精确的 log 概率
  • "raw_logits":直接使用原始 logit 值(转为 float32)
  • 默认 "raw_logprobs" 保持向后兼容性

默认值设计 :使用 "raw_logprobs" 作为默认值意味着大多数用户场景下返回的是标准化的 log 概率,而非原始 logit 值。

python 复制代码
        super().__init__()

第 56 行 :调用 nn.Module.__init__()。这初始化了 PyTorch 模块的底层机制:

  • _parameters(有序字典)
  • _buffers(有序字典)
  • _modules(有序字典)
  • _forward_hooks_forward_pre_hooks

虽然 Sampler 没有可训练参数,但 super().__init__() 仍然必须调用,否则 nn.Module 的基础设施不会被正确初始化。

python 复制代码
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)

第 57 行 :创建 TopKTopPSampler 实例并赋值给 self.topk_topp_sampler

作为 nn.Module 属性 :因为 TopKTopPSampler 也是 nn.Module,将其赋值给 self.xxx 会自动注册到 self._modules 中。这意味着:

  • Sampler.modules() 会包含 TopKTopPSampler
  • Sampler.to(device) 会递归转移 TopKTopPSampler
  • Sampler.state_dict() 会包含 TopKTopPSampler 的缓冲区

参数传递logprobs_mode 传递给 TopKTopPSampler,因为采样器在随机采样时可能需要计算 processed logprobs。

python 复制代码
        self.pin_memory = is_pin_memory_available()

第 58 行:检测并缓存平台是否支持 pin_memory。

为什么缓存is_pin_memory_available() 可能涉及平台检测逻辑(查询 CUDA 是否可用),在每次需要时调用会产生不必要的开销。构造时查询一次,后续直接使用布尔值。

用途 :在 apply_penaltiesops/penalties.py_convert_to_tensors 中,将 output_token_ids 转为 CPU 张量时使用 pin_memory 加速 CPU→GPU 传输。

python 复制代码
        self.logprobs_mode = logprobs_mode

第 59 行 :将 logprobs_mode 存储为实例属性。

用途 :在 forward()sample() 方法中作为默认 logprobs 模式使用。forward() 允许通过 logprobs_mode_override 参数覆盖此默认值。


3.7 forward 方法(第 61-122 行)

这是 Sampler 的主入口方法,编排了完整的 9 步采样管线。

python 复制代码
    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> SamplerOutput:

第 61-66 行:方法签名。

参数 logitsTensor[batch_size, vocab_size]。模型前向传播的输出,未经任何采样处理。数据类型可能是 float16 或 bfloat16(取决于模型配置和硬件)。

参数 sampling_metadataSamplingMetadata 实例。由 Scheduler 构建,包含当前 batch 中所有请求的采样参数。注意这是一个 dataclass,不是 nn.Module。

参数 predict_bonus_tokenbool,默认 False。用于投机解码场景。当为 True 时,如果存在惩罚或坏词,需要将 speculative token 合并到 output_token_ids 中再计算惩罚(因为惩罚基于已出现的 token 集合)。

参数 logprobs_mode_overrideLogprobsMode | None,默认 None。允许在运行时覆盖构造时设置的 logprobs_mode。如果提供,则使用此值;否则使用 self.logprobs_mode

返回值SamplerOutput 实例,包含 sampled_token_ids 和可选的 logprobs_tensors

python 复制代码
        logprobs_mode = logprobs_mode_override or self.logprobs_mode

第 67 行 :解析 logprobs 模式。优先使用运行时覆盖值,如果没有覆盖则使用构造时的默认值。Python 的 or 语义:如果 logprobs_mode_overrideNone(falsy),则回退到 self.logprobs_mode

python 复制代码
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).

第 68-71 行:重要注释。由 woosuk 添加,说明了一个关键的 API 变更:

V0 vs V1 行为差异

  • V0:logprobs 基于处理后的 logits(经过惩罚和温度缩放),这意味着返回的 logprobs 反映了采样时的实际概率分布
  • V1:logprobs 基于原始 logits(未经处理),这更接近模型的"真实置信度"

设计理由:原始 logprobs 更有解释性------它反映的是模型对 token 的原始预测概率,不受采样参数的影响。用户可能想看到"模型认为这个 token 有多可能"而非"经过各种调节后这个 token 有多可能"。

python 复制代码
        num_logprobs = sampling_metadata.max_num_logprobs

第 72 行 :提取 max_num_logprobs 到局部变量,减少属性访问次数。

max_num_probs 的语义

  • None:不需要 logprobs
  • 0:仅需要采样 token 的 logprob
  • 正整数:需要 Top-K logprobs + 采样 token 的 logprob
  • -1:返回完整的未排序未排名 logprobs 张量
python 复制代码
        if num_logprobs is not None:
            if logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif logprobs_mode == "raw_logits":
                if logits.dtype == torch.float32:
                    raw_logprobs = logits.clone()
                else:
                    raw_logprobs = logits.to(torch.float32)

第 73-79 行Step 1 --- 计算原始 logprobs

第 73 行if num_logprobs is not None --- 仅当请求了 logprobs 时才计算,避免不必要的计算开销。

第 74-75 行 :如果模式是 "raw_logprobs",调用 self.compute_logprobs(logits),即 logits.log_softmax(dim=-1, dtype=torch.float32)。这计算每个 token 在词表上的 log 概率。注意 log_softmax 在内部先减去最大值(数值稳定技巧),然后取 exp、求和、取 log,最后减去 log_sum_exp。

第 76-79 行 :如果模式是 "raw_logits"

  • 如果 logits 已经是 float32,直接 .clone()(避免修改原始张量)
  • 如果 logits 不是 float32(float16/bfloat16),用 .to(torch.float32) 转换

为什么不直接用 logits.clone().to(torch.float32) :当 logits 已经是 float32 时,.clone().to(torch.float32) 更高效------后者即使是相同类型也会创建新张量,但前者语义更清晰。这里的分支优化避免了不必要的 .to() 调用。

变量 raw_logprobs 的生命周期 :这个变量可能在后续被 processed_logprobs 覆盖(第 97 行),也用于 gather_logprobs 的输入。

python 复制代码
        # Use float32 for the logits.
        logits = logits.to(torch.float32)

第 81 行Step 2 --- logits 精度提升

将 logits 转换为 float32。这是一个关键操作,原因:

  1. log_softmax 需要高精度以避免数值溢出/下溢
  2. 温度除法 logits / temp 在低精度下可能不准确
  3. Top-K/Top-P 的概率归一化需要足够精度
  4. 惩罚计算(乘以 >1 的重复惩罚系数)可能导致 float16 溢出

.to(torch.float32) 的行为 :如果 logits 已经是 float32,这是 no-op(返回自身,不创建新张量)。如果是 float16/bfloat16,创建新的 float32 张量。注意这里的 logits 是局部变量,重新绑定不会影响调用方的原始张量。

python 复制代码
        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )

第 83-85 行Steps 3-6 --- Logits 预处理管线

调用 apply_logits_processors,该方法内部按序执行:

  1. 应用白名单掩码(Step 3)
  2. 应用坏词排除(Step 4)
  3. 应用非 argmax 不变处理器(Step 5)
  4. 应用惩罚(Step 6)

返回值 :处理后的 logits 张量(可能是原地修改,也可能是新张量)。重新绑定到 logits 变量。

predict_bonus_token 参数 :传递到 apply_logits_processors 中,影响 output_token_ids 的合并逻辑(是否包含 speculative tokens)。

python 复制代码
        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)

第 86-87 行Step 7 --- 采样决策

调用 self.sample() 执行采样,返回两个值:

  • sampledTensor[batch_size],采样得到的 token ID
  • processed_logprobsTensor | None,处理后的 logprobs(仅在特定 logprobs_mode 下返回)

为什么叫 processed_logprobs :与前面的 raw_logprobs 对应。raw_logprobs 是基于原始 logits 计算的,而 processed_logprobs 是基于经过所有处理(惩罚、温度、Top-K/Top-P)后的 logits/logprobs 计算的。

python 复制代码
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs

第 88-89 行 :如果 sample() 返回了 processed_logprobs,则用其覆盖 raw_logprobs

什么时候 processed_logprobs 不为 None :当 logprobs_mode"processed_logits""processed_logprobs" 时,sample() 方法内部会计算并返回处理后的 logprobs。此时覆盖 raw_logprobs 使得后续的 gather_logprobs 使用处理后的值。

设计理由 :通过覆盖而非新变量,简化了后续代码的逻辑------gather_logprobs 只需要关注 raw_logprobs 这个变量名,不需要知道它是原始的还是处理后的。

python 复制代码
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

第 90-94 行:将采样结果转换为 int64。

为什么需要转换:不同的采样后端返回不同的数据类型:

  • PyTorch argmax():返回 int64
  • PyTorch topk():返回 int64
  • FlashInfer 采样操作:返回 int32

后续操作(如 logprobs.gather(-1, token_ids))需要索引为 int64,因此统一转换。

.long() :等价于 .to(torch.int64),是 PyTorch 的惯用写法。

python 复制代码
        # Handle logprob_token_ids if specified (more efficient than full vocab)
        # This is used by generative_scoring API to get logprobs for specific tokens
        logprob_token_ids_tensors = None
        if sampling_metadata.logprob_token_ids:
            logprob_token_ids_tensors = self.gather_specific_token_logprobs(
                logits, sampling_metadata.logprob_token_ids, sampled
            )

第 96-101 行:处理指定的 logprob_token_ids------一种高效计算特定 token logprobs 的方式。

logprob_token_ids 的用途:某些 API(如 generative_scoring)只需要特定 token 的 logprobs,而非整个 Top-K。例如,评估模型在某个特定 token 上的置信度。此时计算整个词表的 log_softmax 再取 Top-K 是浪费的。

gather_specific_token_logprobs :使用 Triton fused kernel(log_softmax + gather)高效计算,只需对指定的 token 做 log_softmax + gather,避免计算整个词表。

传入 logits :注意这里传入的是经过所有处理后的 logits(Step 2-6 之后),因此计算出的 logprobs 反映了惩罚等处理的效果。

传入 sampled:已转为 int64 的采样 token ID,用于获取采样 token 的排名。

python 复制代码
        if num_logprobs is None:
            logprobs_tensors = logprob_token_ids_tensors

第 103-104 行如果没有请求 Top-K logprobsnum_logprobs is None),则 logprobs 输出完全由 logprob_token_ids_tensors 决定:

  • 如果有 logprob_token_ids:返回指定 token 的 logprobs
  • 如果没有:logprob_token_ids_tensorsNone,最终 logprobs_tensors 也为 None
python 复制代码
        elif num_logprobs == -1:
            # Return the full unsorted and unranked logprobs.
            logprobs_tensors = LogprobsTensors(
                torch.empty(0), raw_logprobs, torch.empty(0)
            )

第 105-108 行如果 num_logprobs == -1,返回完整的未排序未排名 logprobs。

LogprobsTensors(torch.empty(0), raw_logprobs, torch.empty(0))

  • 第一个参数 logprob_token_ids:空张量(没有索引,因为返回全部)
  • 第二个参数 logprobs:完整的 [batch_size, vocab_size] logprobs 张量
  • 第三个参数 selected_token_ranks:空张量(没有排名,因为返回全部)

使用 torch.empty(0) 而非 None:保持张量类型一致性,避免下游代码的类型检查问题。

python 复制代码
        else:
            # Gather the logprobs and ranks of the topk and sampled token.
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

第 109-112 行Step 8 --- 正常的 Top-K logprobs 收集

num_logprobs 为 0 或正整数时,调用 gather_logprobs 收集 Top-K logprobs 和采样 token 的 logprob/rank。

传入 raw_logprobs :可能是原始 logprobs("raw_logprobs" 模式)或处理后的 logprobs("processed_*" 模式),取决于之前的覆盖逻辑。

传入 token_ids=sampled:采样 token 的 ID,用于获取该 token 的 logprob 和排名。

python 复制代码
        # If we have both num_logprobs and logprob_token_ids, prefer
        # logprob_token_ids as it's more specific
        if logprob_token_ids_tensors is not None and num_logprobs is not None:
            logprobs_tensors = logprob_token_ids_tensors

第 114-116 行当同时存在 Top-K logprobs 和指定 token logprobs 时,优先使用指定 token logprobs

设计理由logprob_token_ids 是更精确的请求------用户明确知道要哪些 token 的 logprobs。Top-K 是更泛化的请求。当两者共存时,精确请求应优先。

注意 :这个覆盖发生在 gather_logprobs 之后,意味着 Top-K 的计算被浪费了。但这种情况(同时请求 Top-K 和指定 token)在实际中极少出现,性能损失可接受。

python 复制代码
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

第 118-119 行:将采样 token ID 从 int64 转回 int32。

为什么先转 int64 再转 int32 :int64 是中间态,为了 gather_logprobs 中的 torch.topklogprobs.gather 操作(它们需要 int64 索引)。计算完成后转回 int32 以减少传输和存储开销------int32 的张量大小是 int64 的一半。

python 复制代码
        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

第 121-122 行 (实际跨越第 120-127 行):Step 9 --- 构建并返回 SamplerOutput

sampled.unsqueeze(-1) :将 [batch_size] 扩展为 [batch_size, 1]。这是因为输出格式约定每个请求可能生成多个 token(投机解码场景),1 表示当前只生成了 1 个 token。

logprobs_tensors :可能是 None(无 logprobs 请求)、LogprobsTensors(Top-K 或指定 token 的 logprobs)。

注释 "These are GPU tensors":提醒下游代码,这些张量位于 GPU 上,如果需要 CPU 访问需要显式传输。

SamplerOutput 的结构

python 复制代码
@dataclass
class SamplerOutput:
    sampled_token_ids: torch.Tensor    # [batch_size, 1], int32
    logprobs_tensors: LogprobsTensors | None

3.8 gather_specific_token_logprobs 方法(第 124-185 行)

python 复制代码
    def gather_specific_token_logprobs(
        self,
        logits: torch.Tensor,
        logprob_token_ids: dict[int, list[int]],
        sampled: torch.Tensor,
    ) -> LogprobsTensors | None:

第 124-128 行:方法签名。

参数 logits[batch_size, vocab_size],经过所有处理的 logits(float32)。

参数 logprob_token_ids :字典,键为请求索引(int),值为该请求需要计算 logprobs 的 token ID 列表。注意这是异构的------不同请求可能有不同数量和不同 ID 的 token。

参数 sampled[batch_size],已采样的 token ID(int64)。

返回值LogprobsTensors | None。如果没有请求需要指定 token logprobs,返回 None

python 复制代码
        """Compute logprobs for specific token IDs using Triton kernel.

        This method handles heterogeneous token ID lists across requests by
        padding shorter lists to max length and using a fused Triton kernel
        for efficient log_softmax + gather computation.

        Benchmarks show the Triton kernel approach is ~1.4x faster than sparse
        gather for batch sizes > 1 due to the fused kernel reducing memory
        bandwidth requirements.

        Args:
            logits: [batch_size, vocab_size] tensor of logits
            logprob_token_ids: dict mapping req_index -> list of token IDs
            sampled: [batch_size] tensor of sampled token IDs

        Returns:
            LogprobsTensors with logprobs for the specified tokens, or None
            if no requests have logprob_token_ids.
        """

第 129-150 行:方法文档字符串,包含重要的性能注释。

关键性能数据:Triton fused kernel 比稀疏 gather 快约 1.4 倍(batch_size > 1 时),因为融合内核减少了内存带宽需求------不需要先写出完整的 log_softmax 结果再 gather,而是在一次 kernel 调用中完成。

python 复制代码
        if not logprob_token_ids:
            return None

第 151-152 行 :空字典快速返回。如果没有任何请求指定了 logprob_token_ids,直接返回 None,避免后续不必要的张量创建。

python 复制代码
        batch_size = logits.shape[0]
        device = logits.device

第 154-155 行 :提取 batch 大小和设备信息。device 用于确保所有创建的张量都在同一设备上。

python 复制代码
        # Find max number of tokens across all requests
        max_num_tokens = max(len(tids) for tids in logprob_token_ids.values())

第 157-158 行:计算所有请求中最大的 token ID 列表长度。这决定了填充后的张量宽度。由于不同请求可能有不同数量的指定 token,需要填充到统一长度。

python 复制代码
        # Create padded token_ids tensor: [batch_size, max_num_tokens + 1]
        # +1 for sampled token in first position
        token_ids_tensor = torch.zeros(
            batch_size, max_num_tokens + 1, dtype=torch.int64, device=device
        )
        token_ids_tensor[:, 0] = sampled  # First column is sampled token

第 160-165 行:创建填充后的 token_ids 张量。

形状 [batch_size, max_num_tokens + 1]:+1 是因为第一列固定存放采样 token 的 ID。

第 0 列为采样 token:这是一个设计约定------采样 token 总是在第一列,后续列是用户指定的 token。这确保采样 token 的 logprob 总是可获取的。

torch.zeros :初始化为 0,后续填充有效位置。0 在词表中可能对应有效 token,但配合 valid_mask 使用不会引起歧义。

python 复制代码
        # Create mask for valid positions (True = valid, False = padded)
        valid_mask = torch.zeros(
            batch_size, max_num_tokens + 1, dtype=torch.bool, device=device
        )
        valid_mask[:, 0] = True  # Sampled token is always valid

第 167-171 行:创建有效性掩码。

用途 :填充位置的 logprobs 需要被标记为无效(-inf),否则 0 号 token 的 logprob 会是有效值,造成误导。

第 0 列始终有效:采样 token 的 logprob 总是需要。

python 复制代码
        # Fill in token IDs for each request
        for req_idx, token_ids in logprob_token_ids.items():
            num_tokens = len(token_ids)
            token_ids_tensor[req_idx, 1 : num_tokens + 1] = torch.tensor(
                token_ids, dtype=torch.int64, device=device
            )
            valid_mask[req_idx, 1 : num_tokens + 1] = True

第 173-178 行:填充每个请求的指定 token ID。

req_idx:字典的键,即请求在 batch 中的索引。

1 : num_tokens + 1 :从第 1 列开始填充(第 0 列是采样 token),填充 num_tokens 个元素。

torch.tensor(token_ids, ...) :将 Python list 转为 GPU 张量。注意这是在循环中创建小张量,虽然不是最优的,但因为 logprob_token_ids 通常只有少量请求且每个请求的 token 数很少,性能影响可忽略。

valid_mask[req_idx, 1 : num_tokens + 1] = True:标记这些位置为有效。

python 复制代码
        # Compute logprobs using the fused Triton kernel (log_softmax + gather)
        logprobs = compute_token_logprobs(logits, token_ids_tensor)

第 180-181 行:使用 Triton fused kernel 计算 logprobs。

compute_token_logprobs :定义在 v1/worker/gpu/sample/logprob.py 中的 Triton JIT 内核。它融合了 log_softmax + gather 两个操作:

  1. 对 logits 执行 log_softmax(数值稳定版:减最大值 → exp → 归一化 → log)
  2. 根据 token_ids_tensor 中的索引 gather 出指定位置的 logprobs

优势 :不需要先写出 [batch_size, vocab_size] 的完整 logprobs 再 gather,而是在 kernel 内部直接计算目标位置的值,大幅减少内存写入。

输入logits [batch_size, vocab_size]token_ids_tensor [batch_size, max_num_tokens + 1]
输出logprobs [batch_size, max_num_tokens + 1]

python 复制代码
        # Mask invalid (padded) positions with -inf
        logprobs = logprobs.masked_fill(~valid_mask, float("-inf"))

第 183-184 行 :将填充位置的 logprobs 设为 -inf

~valid_mask:取反,True 变 False,False 变 True。即标记无效位置。

float("-inf") :负无穷,语义上表示"这个位置的 logprob 无效/不存在"。下游代码可以通过 -inf 识别并跳过这些位置。

python 复制代码
        # Compute ranks for the sampled token
        sampled_logits = logits.gather(-1, sampled.unsqueeze(-1))
        token_ranks = (logits > sampled_logits).sum(dim=-1)

第 186-187 行:计算采样 token 的排名。

sampled_logits[batch_size, 1],采样 token 对应的 logit 值。gather 从 logits 的最后一维按索引取值。

token_ranks[batch_size],每个 batch 行中有多少 token 的 logit 严格大于采样 token 的 logit。排名从 0 开始------如果采样 token 的 logit 是最大的,排名为 0。

注意 :这里使用 >(严格大于)而非 >=。这意味着如果有多个 token 与采样 token 的 logit 相同,它们的排名相同。

python 复制代码
        return LogprobsTensors(
            logprob_token_ids=token_ids_tensor.to(torch.int32),
            logprobs=logprobs,
            selected_token_ranks=token_ranks,
        )

第 189-193 行:构建并返回 LogprobsTensors。

token_ids_tensor.to(torch.int32):将 token ID 从 int64 转为 int32 以节省空间。int32 可表示到 2^31 ≈ 21 亿,远超任何词表大小。

logprobs[batch_size, max_num_tokens + 1],float32,指定 token 的 logprobs(填充位置为 -inf)。

token_ranks[batch_size],采样 token 的排名。


3.9 apply_temperature 静态方法(第 187-194 行)

python 复制代码
    @staticmethod
    def apply_temperature(
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:

第 187-191 行:方法签名。

@staticmethod :标记为静态方法,不需要 self。这意味着此方法不访问实例状态,可以独立调用。

参数 logits[batch_size, vocab_size],float32 logits。

参数 temp[batch_size],每个请求的温度参数。

参数 all_randombool,是否所有请求都是随机采样。

python 复制代码
        # Use in-place division to avoid creating a new tensor.
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)

第 192-194 行:避免除以零。

all_random=False:意味着存在贪心请求(温度为 0 或接近 0)。直接除以接近 0 的温度会导致数值爆炸或 NaN。

torch.where:元素级条件------如果温度 < EPS,替换为 1.0;否则保持原值。这样贪心请求的 logits 不变(除以 1.0),随机请求正常缩放。

temp = torch.where(...) :创建新的 temp 张量,不影响原始 sampling_metadata.temperature

all_random=True :所有请求都是随机的(温度 > 0),不需要除零保护,跳过 torch.where 操作节省计算。

python 复制代码
        return logits.div_(temp.unsqueeze(dim=1))

第 195 行 :执行温度缩放,原地除法

temp.unsqueeze(dim=1) :将 [batch_size] 扩展为 [batch_size, 1],以便与 [batch_size, vocab_size] 的 logits 广播。

.div_():原地除法(注意尾随下划线)。修改 logits 张量本身,不创建新张量。

温度的数学含义logits / T 等效于 Boltzmann 分布的温度参数。高温(T > 1)使分布更均匀,低温(T < 1)使分布更尖锐,T → 0 趋近于 argmax。


3.10 greedy_sample 静态方法(第 196-197 行)

python 复制代码
    @staticmethod
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

第 196-197 行:贪心采样实现。

logits.argmax(dim=-1) :沿最后一维(词表维度)取最大值的索引。返回 [batch_size] 的 int64 张量。

.view(-1) :确保输出为 1D。虽然 argmax(dim=-1) 已经返回 1D,但 view(-1) 是防御性编码,确保即使输入形状变化也不会出错。

数学含义:选择 logit 值最大的 token,等效于概率最大(未加温度时)。


3.11 sample 方法(第 199-244 行)

这是采样决策的核心方法,实现了 Step 7 的全部逻辑。

python 复制代码
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

第 199-207 行:方法签名和文档。

参数 logits[batch_size, vocab_size],经过 Step 2-6 处理后的 float32 logits。

参数 sampling_metadata:采样参数。

参数 logprobs_mode_override:运行时 logprobs 模式覆盖。

返回值tuple[Tensor, Tensor | None]。第一个是采样 token,第二个是处理后的 logprobs(可能为 None)。

文档注释:"may update the logits tensor in-place"------提醒调用方,此方法可能修改传入的 logits 张量。

python 复制代码
        logprobs_mode = logprobs_mode_override or self.logprobs_mode

第 208 行 :与 forward 中相同的模式------优先使用覆盖值。

python 复制代码
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)

第 209 行:不变量断言------不可能同时 all_greedy 和 all_random。这是逻辑约束:如果所有请求都是贪心的,不可能同时所有请求都是随机的。

为什么用 assert 而非 if:这是编程错误(Scheduler 构建的 metadata 不应违反此约束),而非运行时异常。assert 在生产环境(Python -O 模式)下会被移除,不影响性能。

python 复制代码
        if sampling_metadata.all_random:
            greedy_sampled = None

第 210-211 行:如果所有请求都是随机的,不需要贪心采样结果。

greedy_sampled = None :设为 None 而非跳过计算,是为了统一后续的分支逻辑------后续代码通过 greedy_sampled is None 判断是否存在贪心请求。

python 复制代码
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

第 212-220 行:贪心采样分支。

第 213 行 :执行贪心采样 argmax(dim=-1)

第 214 行 :如果所有请求都是贪心的,直接返回,跳过后续的温度/Top-K/Top-P 等随机采样步骤。这是快速路径优化------全贪心 batch 不需要任何随机采样计算。

第 215-219 行:如果请求了 logprobs,根据模式计算:

  • "processed_logits":直接使用处理后的 logits(已经是 float32)
  • "processed_logprobs":对处理后的 logits 执行 log_softmax

注意:此时 logits 已经经过 Step 2-6 的所有处理(白名单、坏词、处理器、惩罚),所以是"processed"的。

processed_logprobs = logits(非 clone) :因为 logits 在后续不会被修改(直接 return),所以可以直接引用。但如果调用方后续修改 logits 会有问题------不过此方法已 return,不会发生。

python 复制代码
        assert sampling_metadata.temperature is not None

第 221 行:断言温度参数存在。如果不是全贪心,必须有温度参数。

python 复制代码
        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

第 223-225 行Step 7b --- 温度缩放

注意 logits 被重新绑定。apply_temperature 返回的是原地修改后的 logits(.div_()),所以实际上 logits 对象没变,只是内容被修改了。但写 logits = ... 使得代码意图更清晰。

python 复制代码
        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

第 227-229 行Step 7c --- 应用 argmax 不变处理器

argmax_invariant :不影响 argmax 结果的处理器,仅在随机采样时有意义。典型代表是 MinPLogitsProcessor

为什么在温度之后应用:argmax 不变处理器作用于温度缩放后的 logits。例如 min-p 需要基于 softmax 概率(即温度后的分布)来设定阈值,因此必须在温度缩放后。

循环遍历 :可能有多个 argmax 不变处理器,按顺序应用。每个处理器的 apply() 返回 logits(可能是原地修改)。

python 复制代码
        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

第 231-235 行Step 7d-e --- Top-K/Top-P 过滤 + 随机采样

self.topk_topp_sampler() :调用 TopKTopPSampler.__call__()(即 forward()),它会:

  1. 根据 top_k 和 top_p 参数选择后端(CUDA/FlashInfer/CPU/HIP)
  2. 执行 Top-K 过滤(保留概率最高的 K 个 token)
  3. 执行 Top-P 过滤(保留累积概率达到 P 的最小 token 集)
  4. 从过滤后的分布中随机采样

参数

  • logits:温度缩放 + argmax 不变处理后的 logits
  • generators:每请求的随机数生成器(保证可复现性)
  • top_k:Top-K 参数张量
  • top_p:Top-P 参数张量

返回值

  • random_sampled[batch_size],随机采样的 token ID
  • processed_logprobsTensor | None,TopKTopPSampler 可能计算的处理后 logprobs
python 复制代码
        if greedy_sampled is None:
            return random_sampled, processed_logprobs

第 237-238 行 :如果不存在贪心请求(all_random=True),直接返回随机采样结果。

python 复制代码
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

第 240-244 行Step 7f --- 混合贪心/随机结果

torch.where(condition, x, y):元素级条件选择:

  • 如果 temperature < EPS(贪心请求),使用 greedy_sampled
  • 否则(随机请求),使用 random_sampled

out=greedy_sampled :复用 greedy_sampled 张量的内存,避免分配新张量。这是一个重要的内存优化------torch.where 的结果可以写入 out 指定的张量中。

为什么可以复用greedy_sampledwhere 操作后就不再需要了,所以可以安全地覆盖。

注意 :此时 sampledgreedy_sampled 可能指向同一块内存(out 的效果),但这是有意为之。

返回值分析

  • 混合 batch:贪心请求得到 argmax 结果,随机请求得到随机采样结果
  • processed_logprobs:从 topk_topp_sampler 返回的,可能为 None

3.12 compute_logprobs 静态方法(第 246-247 行)

python 复制代码
    @staticmethod
    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

第 246-247 行:计算 log softmax。

log_softmax(dim=-1, dtype=torch.float32):沿最后一维(词表维度)计算 log softmax,强制使用 float32 精度。

dtype=torch.float32 :即使 logits 是 float16,计算也在 float32 下进行,保证数值稳定性。PyTorch 的 log_softmax 在内部会先转为 float32 计算,再转回输入类型。显式指定 dtype=torch.float32 确保输出也是 float32。

数学公式log_softmax(x_i) = x_i - log(sum(exp(x_j))),等价于 softmax 后取 log,但数值更稳定(直接计算 softmaxlog 可能导致 log(0) = -inf 的问题)。

用途

  1. forward 中计算 raw_logprobs(Step 1a)
  2. sample 中计算 processed_logprobs(全贪心 + processed_logprobs 模式)

3.13 gather_logprobs 静态方法(第 249-293 行)

python 复制代码
    @staticmethod
    def gather_logprobs(
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:

第 249-253 行:方法签名。

参数 logprobs[num_tokens, vocab_size],log softmax 后的 logprobs。

参数 num_logprobs:要收集的 Top-K 数量。

参数 token_ids[num_tokens],采样 token 或 prompt token 的 ID(int64)。

python 复制代码
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: maximum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """

第 254-273 行:详细文档。注意:

  • 输出形状 (num tokens) x (num_logprobs + 1)------+1 是因为采样/prompt token 的 logprob 被拼接到 Top-K 之前
  • token_ids 必须是 int64
python 复制代码
        assert token_ids.dtype == torch.int64

第 274 行 :断言 token_ids 为 int64。这是因为 logprobs.gather(-1, index) 要求 index 为 int64,否则 PyTorch 会抛出错误。

python 复制代码
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

第 276-277 行:执行 Top-K 选择。

torch.topk(input, k, dim):沿指定维度取前 K 大的值和索引。

返回值

  • topk_logprobs[num_tokens, num_logprobs],Top-K 的 logprob 值
  • topk_indices[num_tokens, num_logprobs],Top-K 对应的 token ID

注意topk 返回的是降序排列的------第一列是最大值,最后一列是第 K 大值。

python 复制代码
        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

第 279-281 行:获取采样/prompt token 的 logprob。

token_ids.unsqueeze(-1) :从 [num_tokens] 变为 [num_tokens, 1],为 gather 操作准备索引形状。

logprobs.gather(-1, token_ids) :沿最后一维按索引取值。返回 [num_tokens, 1],每个位置是对应 token 的 logprob。

为什么不在 topk 中直接查找:采样 token 可能不在 Top-K 中(如果它的 logprob 较低),因此需要单独 gather。

python 复制代码
        # Compute the ranks of the actual token.
        # Avoid 0/1 specialization recompile on the batch dimension
        # of the compiled batched_count_greater_than. mark_unbacked makes
        # the size fully symbolic so dynamo doesn't specialize when
        # batch_size transitions from 1 to >=2.
        torch._dynamo.decorators.mark_unbacked(logprobs, 0)
        torch._dynamo.decorators.mark_unbacked(token_logprobs, 0)
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

第 283-288 行:计算采样 token 的排名。

mark_unbacked :这是 PyTorch Dynamo(torch.compile 的前端)的内部 API。mark_unbacked(tensor, dim) 将指定维度标记为"非编译时已知",使 Dynamo 不会在 batch_size 从 1 变为 2 时触发重新编译。

为什么需要batched_count_greater_than 内部使用 (x >= values).sum(-1)。Dynamo 在编译时会尝试特化(specialize)张量的形状。当 batch_size=1 时编译的 kernel 在 batch_size=2 时不能直接使用。mark_unbacked 使 batch 维度保持符号化,避免重复编译。

batched_count_greater_than(logprobs, token_logprobs) :计算每行中有多少 logprob 值 >= 采样 token 的 logprob。由于 logprobs 是降序排列的,这个计数就是采样 token 的排名(0-indexed:最大 token 的排名为 1,因为自身也满足 >= 条件...等等,实际上需要仔细看)。

排名的精确定义(x >= values).sum(-1) 其中 values 是采样 token 的 logprob。这计算的是有多少 token 的 logprob >= 采样 token 的 logprob。如果采样 token 的 logprob 是最大的,只有它自己满足条件,排名为 1。如果它是第 3 大的,排名为 3。

注意 :这是 >= 而非 >,所以如果有多个 token 的 logprob 与采样 token 相同,它们都会被计入。

python 复制代码
        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

第 290-292 行:拼接采样 token 和 Top-K。

拼接顺序:采样/prompt token 在前(第 0 列),Top-K 在后(第 1 到 num_logprobs 列)。

形状[num_tokens, num_logprobs + 1]

可能的重复 :如果采样 token 在 Top-K 中,它会出现两次(第 0 列和 Top-K 中的某个位置)。这由下游的 LogprobsProcessor 去重处理。

python 复制代码
        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

第 294-297 行:类型转换并返回。

indices.to(torch.int32):将 token ID 从 int64 转为 int32,节省一半存储。

LogprobsTensors(indices, logprobs, token_ranks)

  • indices[num_tokens, num_logprobs + 1],int32
  • logprobs[num_tokens, num_logprobs + 1],float32
  • token_ranks[num_tokens],int64(由 batched_count_greater_than 返回)

3.14 _combine_outputs_with_spec_tokens 静态方法(第 295-303 行)

python 复制代码
    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

第 295-303 行:将投机解码的 draft token 合并到已生成 token 列表中。

前导下划线:表示内部方法,不对外暴露。

参数 output_token_idslist[list[int]],每个请求已生成的 token ID 列表。

参数 spec_token_idslist[list[int]] | None,每个请求的投机(draft)token ID 列表。

spec_token_ids is None:如果没有投机 token,直接返回原始列表。

列表推导式 :对每个请求,如果存在投机 token,将其追加到输出 token 后面。[*out, *spec] 是 Python 的解包合并语法。

条件 if spec else out:如果某个请求没有投机 token(空列表),不追加。这处理了异构 batch 的情况------部分请求可能有投机 token,部分没有。

用途:在惩罚计算中,已生成 token 集合应包含投机 token。因为如果模型已经(通过投机解码)生成了某些 token,这些 token 的重复/频率/出现惩罚应该被计算。


3.15 apply_logits_processors 方法(第 305-333 行)

python 复制代码
    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool,
    ) -> torch.Tensor:

第 305-309 行:方法签名。此方法编排 Steps 3-6 的 logits 预处理管线。

python 复制代码
        bad_words_token_ids = sampling_metadata.bad_words_token_ids
        any_penalties_or_bad_words = (
            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
        )

第 310-313 行:预计算条件标志。

bad_words_token_ids:提取到局部变量,减少属性访问。

any_penalties_or_bad_words:是否存在任何惩罚或坏词。用于决定是否需要合并投机 token。

bool(bad_words_token_ids):空字典为 False,非空为 True。

not sampling_metadata.no_penaltiesno_penalties=True 表示没有惩罚,取反表示有惩罚。

python 复制代码
        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

第 315-320 行:投机 token 合并。

条件 predict_bonus_token and any_penalties_or_bad_words:仅在投机解码模式下且存在惩罚/坏词时才合并。如果没有任何惩罚或坏词,投机 token 不影响 logits 处理,无需合并。

优化意义:避免不必要的列表合并操作。

python 复制代码
        # Apply allowed token ids.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

第 322-324 行Step 3 --- 白名单掩码

allowed_token_ids_mask[max_batch_size, vocab_size] 的布尔张量。注意语义------True 表示不允许的 token(需要被掩码的位置),这是反向的(通常 True 表示有效,但这里是 True 表示无效)。这可能与构建掩码时的效率考量有关。

masked_fill_(mask, value) :原地操作,将 mask 为 True 的位置填充为 -inf

数学效果-inf 经过 softmax 后概率为 0,等效于完全排除这些 token。

为什么在坏词之前:白名单是更严格的约束------如果 token 不在白名单中,无论是否是坏词都应该排除。

python 复制代码
        # Apply bad words exclusion.
        if bad_words_token_ids:
            apply_bad_words(logits, bad_words_token_ids, output_token_ids)

第 326-328 行Step 4 --- 坏词排除

apply_bad_words :定义在 ops/bad_words.py 中,通过前缀匹配实现坏词排除。对于每个坏词,检查已生成 token 序列的末尾是否匹配坏词的前缀部分,如果匹配则将坏词最后一个 token 的 logit 设为 -inf

output_token_ids:已生成的 token 列表,用于前缀匹配。注意这可能已经合并了投机 token。

python 复制代码
        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

第 330-332 行Step 5 --- 非 argmax 不变处理器

non_argmax_invariant:会影响 argmax 结果的处理器,必须在贪心采样之前应用。包括:

  • MinTokensLogitsProcessor:强制生成至少 N 个 token(将 EOS logit 设为 -inf
  • LogitBiasLogitsProcessor:对特定 token 加偏置(可能改变 argmax)

循环应用:每个处理器的输出作为下一个的输入,形成链式处理。

python 复制代码
        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
        return logits

第 334-336 行Step 6 --- 惩罚施加

委托给 apply_penalties 静态方法,该方法进一步委托给 ops/penalties.apply_all_penalties

返回 logits:经过 Steps 3-6 完整处理后的 logits。


3.16 apply_penalties 静态方法(第 338-349 行)

python 复制代码
    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:

第 338-342 行:方法签名。

参数 output_token_ids :独立传递而非从 sampling_metadata 获取,因为在投机解码场景下可能已经被合并了投机 token。

python 复制代码
        if sampling_metadata.no_penalties:
            return logits

第 343-344 行:快速路径------如果没有惩罚,直接返回 logits,跳过所有惩罚计算。

no_penalties :这是一个预计算标志,由 Scheduler 在构建 SamplingMetadata 时设置。如果所有请求的三种惩罚值都为默认值(0),则 no_penalties=True。这是一个重要的性能优化------大多数推理请求不使用惩罚,跳过可以避免不必要的张量操作。

python 复制代码
        assert sampling_metadata.prompt_token_ids is not None

第 345 行:断言 prompt_token_ids 存在。惩罚计算需要区分 prompt token 和生成 token:

  • repetition_penalty:对所有已出现的 token(prompt + output)统一缩放
  • frequency_penalty:仅对 output token 的出现次数线性递减
  • presence_penalty:仅对 output token 是否出现做常数递减

因此 prompt_token_ids 是必须的。

python 复制代码
        return apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )

第 346-352 行 :委托给 ops/penalties.apply_all_penalties

传入参数

  • logits[batch_size, vocab_size]
  • prompt_token_ids[batch_size, max_prompt_len],prompt token
  • presence_penalties[batch_size],出现惩罚系数
  • frequency_penalties[batch_size],频率惩罚系数
  • repetition_penalties[batch_size],重复惩罚系数
  • output_token_idslist[list[int]],已生成 token

apply_all_penalties 内部逻辑 (在 ops/penalties.py 中):

  1. output_token_ids 转为填充张量(make_tensor_with_pad
  2. 将 -1 占位符替换为 vocab_size(无效 token ID)
  3. 调用 vllm.model_executor.layers.utils.apply_penalties 执行三种惩罚

3.17 完整 9 步管线回顾

forward 方法的流程与类文档中描述的 9 步对应:

步骤 对应代码 操作 输入→输出
1 forward 第73-79行 计算原始 logprobs logitsraw_logprobs
2 forward 第81行 logits → float32 logitsfloat32 logits
3 apply_logits_processors 第322-324行 白名单掩码 logitsfiltered logits
4 apply_logits_processors 第326-328行 坏词排除 logitsclean logits
5 apply_logits_processors 第330-332行 非argmax不变处理器 logitsprocessed logits
6 apply_logits_processors 第334行 → apply_penalties 惩罚 logitspenalized logits
7a sample 第210-220行 贪心采样/快速返回 logitsgreedy_sampled
7b sample 第223-225行 温度缩放 logitsscaled logits
7c sample 第227-229行 argmax不变处理器 logitsfiltered logits
7d-e sample 第231-235行 Top-K/Top-P + 随机采样 logitsrandom_sampled
7f sample 第240-244行 混合贪心/随机 sampled
8 forward 第109-112行 收集 Top-K logprobs raw_logprobslogprobs_tensors
9 forward 第120-127行 构建 SamplerOutput SamplerOutput

3.18 关键设计决策总结

  1. 原始 vs 处理后 logprobs:V1 选择返回基于原始 logits 的 logprobs(V0 返回处理后的),这更符合"模型真实置信度"的语义

  2. 贪心/随机混合批处理 :通过 torch.where + out=greedy_sampled 实现零额外内存的混合采样

  3. 两阶段处理器分类:argmax 不变 vs 非不变,确保贪心采样不受 min-p 等处理器影响

  4. 投机 token 合并的惰性策略:仅在存在惩罚/坏词时才合并投机 token,避免不必要的列表操作

  5. int32 ↔ int64 的精确转换时机:仅在需要作为索引时使用 int64,存储和传输时使用 int32

  6. 原地操作优先div_()masked_fill_()out=greedy_sampled 等减少内存分配

  7. Triton fused kernelgather_specific_token_logprobs 使用融合的 log_softmax+gather,比分离操作快 1.4 倍

  8. Dynamo 编译优化mark_unbacked 避免 batch_size 变化时的重复编译


本部分完成。第二部分将覆盖 metadata.py 逐行解析、ops/ 算子层深度分析、以及 logits_processor/ 处理器层详解。

相关推荐
SamDeepThinking2 小时前
秒杀下单,用户点一下按钮,后端要过六道关卡
java·后端·架构
一只AI打工虾的自我修养2 小时前
DeepSeek V4 Hybrid Attention Architecture 技术解析
人工智能·ai·开源·aigc
虾米Life2 小时前
MVC与MVVM 架构
架构·mvc·mvvm
薛定谔的猫3692 小时前
基于 MCP (Model Context Protocol) 的智能 Agent 开发指南
ai·llm·agent·mcp·software engineering
Sam_Deep_Thinking2 小时前
适合中小型企业的出口入口网关微服务
java·微服务·架构
阿珊和她的猫2 小时前
大模型在客服场景:落地路径 + 效果评估
ai·agent·llama·cli·mcp
jinanwuhuaguo2 小时前
生态融合与基座成型——OpenClaw v2026.4.24 的功能完备性跃迁与基础设施化拐点(第七篇)
人工智能·安全·架构·kotlin·openclaw
阿泽的AI工具笔记3 小时前
OpenClaw 接入大模型 API 的完整配置流程(Windows 实测可用)
windows·ai
ofoxcoding3 小时前
OpenClaw 自动化交易机器人怎么配置?从零搭建 + 踩坑全记录(2026)
运维·ai·机器人·自动化