《从 0 实现 SGLang》第 2 篇 · 核心数据结构: Req 与 SamplingParams

千行代码,一步步搭出一个现代 LLM 推理引擎,吃透大模型推理的每一项关键技术。

本阶段目标 --- 最简推理实现

用最朴素的方式把端到端推理跑通:先搭起整体框架,再逐个模块替换为完整实现。整个阶段共 5 篇短文:

序号 主题
1 序章:LLM 推理引擎到底在做什么
2 核心数据结构:ReqSamplingParams(本篇)
3 一个可运行的 Decoder Layer(一次性版本)
4 Naive KV Cache
5 单 GPU Engine v0 与 greedy 生成

本篇你将学到

  • 理解为什么一条请求需要一个独立的 Req 对象,而不是直接传 prompt;
  • 看懂 Req 的四个长度字段(cached_len / device_len / max_device_len / extend_len)如何同时表达 prefill 与 decode 两阶段;
  • 明白 complete_one / append_host 为何拆成两个方法------host 与 device 的更新职责不同;
  • 掌握 SamplingParams.is_greedy 这一个布尔表达式背后的工程取舍。

1. 为什么需要 Req

上一篇结尾的状态

第 1 篇结束时,单条请求内部的过程已经清楚:prompt 一次 prefill → 每步 decode 一个新 token。但一旦多请求并发涌入引擎,只有"prompt"是不够的。

一条请求当前的"全部状态"包含什么

试着列一下:

  • 原始 prompt 的 token id 序列(在 CPU);
  • KV cache 在 GPU 上已经写到第几个位置(在 device);
  • 还允许再生成多少个 token(取决于用户设定的 max_tokens);
  • 这条请求的采样参数(temperature / top_k / top_p / ...);
  • 是否还没结束、是否被中途取消......

如果把这些散落在引擎各处,Scheduler 每一步都要从七八个不同的数据结构里凑齐"这条请求的当前状态",代码会迅速失控。

因此引入 Req 这个对象

Req 是引擎里"一条推理请求"的容器。它把这条请求当前的全部状态装进一个 dataclass,Scheduler / Engine / Sampler 之间传递的就是它。

具体怎么解决?Req 用四个长度字段(cached_len / device_len / max_device_len / extend_len)统一表达 prefill 与 decode 两阶段的进度,用两个方法(complete_one / append_host)分别推进 device 侧与 host 侧的状态------任何一处代码只要拿到一个 Req,就能算出"这条请求当前在哪一步、下一步要算什么、还能再走多少步"。

2. 四个长度字段及其演化

四个长度字段彼此通过简单算术绑定:

  • cached_len:KV cache 里已经"算好且固化"的前缀长度。它单调不减
  • device_len:device 侧已经"占了位"的 token 数。初值 = prompt 长度,每完成一次 forward(无论 prefill 还是 decode)+1。
  • max_device_len:这条请求允许占到的最长 device 位置。一经创建不变,= prompt_len + output_len。
  • extend_len(property):下一次 forward 要算 KV 的那一段长度,= device_len - cached_len

外加一个派生量:remain_len = max_device_len - device_len,告诉调度器"这条请求还能再走多少步"。

一图看懂演化

下面以 prompt=6, output_len=4 为例,看 5 个时间点上这些字段的值:

读这张图的关键:绿色段(extend)的物理含义在 prefill 和 decode 上不同。prefill 那一步,绿色是整段 prompt 的 6 个 token,模型一次算完所有 KV;而 decode 每一步,绿色都只有 1 个槽位,是"上一步刚生成的新 token,本步要喂回去算 KV"。

extend_len 字段同时表达了这两种情况------这正是用 4 个长度字段就能驱动 prefill / decode 两阶段的关键设计。

python 复制代码
"""§2 Req:长度字段在 prefill / decode 过程中的演化"""
from dataclasses import dataclass

@dataclass
class Req:
    input_ids: list[int]                # CPU 侧的 token 序列
    output_len: int                     # 期望生成的 token 数
    cached_len: int = 0                 # KV cache 已固化的前缀长度

    def __post_init__(self) -> None:
        # device_len:device 侧当前占位数(prompt 起步,每完成 1 步 +1)
        # max_device_len:device 侧能占到的最远位置
        self.device_len = len(self.input_ids)
        self.max_device_len = len(self.input_ids) + self.output_len

    @property
    def extend_len(self) -> int:
        # 本次 forward 要算 KV 的段长:prefill 时 = 整段 prompt,decode 时 = 1
        return self.device_len - self.cached_len

    @property
    def remain_len(self) -> int:
        # 还允许再 decode 多少步
        return self.max_device_len - self.device_len

    @property
    def can_decode(self) -> bool:
        # decode 循环条件:remain_len 用完即停
        return self.remain_len > 0

    def complete_one(self) -> None:
        # 一次 forward 完成后推进 device 侧进度:
        # 把上次的"待算段"标为已 cached,并为下一个 token 预留 device 槽
        self.cached_len = self.device_len
        self.device_len += 1

    def append_host(self, next_token: int) -> None:
        # 在 CPU 侧的 input_ids 末尾追加新生成的 token
        self.input_ids.append(next_token)


def snapshot(req: Req, stage: str) -> None:
    print(
        f"{stage:<14} | cached={req.cached_len:>2}  device={req.device_len:>2}  "
        f"max_device={req.max_device_len:>2}  extend={req.extend_len:>2}  "
        f"remain={req.remain_len:>2}  can_decode={req.can_decode}"
    )


req = Req(
    input_ids=[101, 102, 103, 104, 105, 106],   # prompt 长 6
    output_len=4,                                # 与上方 SVG 的 5 列对齐
)
snapshot(req, "init")

# prefill:把 prompt 一次性算 KV。完成后第 1 个生成 token 已经在 device 上。
req.complete_one()
snapshot(req, "after prefill")

# decode 循环:每步 append_host(新 token) + complete_one()
for i in range(req.output_len):
    if not req.can_decode:
        break
    req.append_host(200 + i)
    req.complete_one()
    snapshot(req, f"after decode {i + 1}")
ini 复制代码
init           | cached= 0  device= 6  max_device=10  extend= 6  remain= 4  can_decode=True
after prefill  | cached= 6  device= 7  max_device=10  extend= 1  remain= 3  can_decode=True
after decode 1 | cached= 7  device= 8  max_device=10  extend= 1  remain= 2  can_decode=True
after decode 2 | cached= 8  device= 9  max_device=10  extend= 1  remain= 1  can_decode=True
after decode 3 | cached= 9  device=10  max_device=10  extend= 1  remain= 0  can_decode=False

3. host 与 device:为什么要 append_host

同一个"长度"在两个地方

上一节的字段表里,device_len 这个名字提了一个细节:它的语义指的是 GPU 侧 的占位数。但 prompt 的 token id 是在 CPU 侧input_ids 里。

这听起来重复,其实不是------两边的角色不同:

  • CPU 侧的 input_ids:是 token id 序列本身,是"输入"
  • GPU 侧的 KV cache:是每个 token 投影出来的 K / V 张量,是"计算结果"

两者长度天然一致(一个 token 对应一个 KV 槽),但维护逻辑不同 ,所以 Req 用两个不同的方法分别推进它们。

append_hostcomplete_one 各自做什么

complete_one() 只动两个整数:

  • cached_len = device_len:把上一次 forward 算过的那段标记为"已 cached"
  • device_len += 1:为下一个待生成的 token 预留位置

不写任何 KV ------KV 是由 attention 的 forward 自己写进 cache 张量的。complete_one 只是"事后告诉 Req:这一步已经走完了"。

append_host(next_token) 负责的是 CPU 侧:把刚刚 sampling 出来的 token id 放进 input_ids,这样下一步 forward 时模型能拿到它作为输入。

为什么拆成两个方法

如果合成一个 step(next_token),看起来确实简洁。但实际工程里,新 token 从 GPU 上的 sampler 出来后,先要从 GPU 传回 CPU(.cpu() 操作) 才能 append 到 input_ids------这一步在并发 batch 的批量回传里会被集中处理,与单个 Req 的 complete_one 时机错开。拆成两个方法,让两件事可以独立调用、独立测试,也方便 Scheduler 做 batch 级别的优化。

4. 一条 Req 的生命周期

把上面所有字段和方法串成一个状态机,就是这条请求从进入引擎到完成所经历的全部:

四个状态由两类事件驱动:

  • 外部事件:Scheduler 决定是否把这条 Req 选入下一个 batch(从 Waiting 进入 Prefilling,或留在 Decoding 自环里)
  • Req 内部事件 :每完成一步 forward,调用 complete_one()(decode 阶段还要 append_host(next_token));每步结束查 can_decode 看是否要转移到 Done

can_decode 是这套状态机的循环条件 ------只要 remain_len > 0(还没用完用户给的 max_tokens 配额)且未被外部判停(EOS / abort),就继续在 Decoding 里转圈。

这套状态机谁来驱?

答:Scheduler。它每一步都会:

  1. 从 Waiting 队列里挑能进 prefill 的 Req
  2. 让正在 Decoding 的 Req 走一步
  3. 根据 can_decode 决定要不要把某些 Req 移到 Done

但本篇不展开 Scheduler------它是阶段 4 的整篇主题。这里只需要建立一个直觉:Req 是被驱动的状态机,而不是自己跑的协程。

5. SamplingParams 与 greedy 快路径

SamplingParams 是另一个核心 dataclass。它把"用户怎么从 logits 里选 token"的所有控制参数打包,独立成一个对象------这样同一批请求共享同一份采样配置时,不需要在每条 Req 里重复存这些字段。下面是它的定义,并对几种典型参数组合验证 is_greedy 的输出。

5 个字段的完整语义将在阶段 2 的 sampler 模块展开,本节聚焦其中一个 property:is_greedy

拆解 is_greedy 表达式

(temperature <= 0.0 or top_k == 1) and top_p == 1.0

三个子句各自的含义:

子句 含义
temperature <= 0.0 不做随机化,直接拿 logits 的 argmax
top_k == 1 只保留最大的 1 个候选,等价于 argmax
top_p == 1.0 不做 nucleus 截断,保留所有 token

下图把这三个参数各自做了什么画在一张图里------is_greedy 关心的恰好是它们都"无操作"的情形。

前两个子句是 OR 关系 ,因为各自单独都能让结果退化为 argmax;最后一个子句必须是 top_p == 1.0------如果 top_p < 1,即使是确定性采样,引擎仍然要算累积概率、做 sort 和 mask,根本走不了快路径。

temperature 的具体作用

temperature 把 logits 除以 T 后再做 softmax: <math xmlns="http://www.w3.org/1998/Math/MathML"> s o f t m a x ( x i / T ) \mathrm{softmax}(x_i / T) </math>softmax(xi/T)。它单一的作用是调节概率分布的尖锐程度:

  • T < 1(如 0.7):分布变尖,高概率 token 更高 → 输出更确定、更收敛
  • T = 1:模型原生分布
  • T > 1(如 1.5):分布变平,各候选概率拉近 → 输出更随机、更多样
  • T → 0 :分布退化到一个点(最大值为 1,其余为 0)→ 等价于 argmax,即 greedy

实际场景的常用取值:

T 适用场景
0(greedy) 代码生成、工具调用、JSON 输出------要稳定可复现
0.3 ~ 0.7 知识问答、摘要------要正确,允许少量措辞变化
0.8 ~ 1.2 创意写作、对话------要多样性

数学直觉很简单:把所有 logits 除以一个大数(T 大),logits 之间的反差被压缩,softmax 后概率被拉平;除以小数则反之。所以 T 控制的是"按多大反差去比较 logits"。

回到 is_greedy:T = 0 时只关心 argmax,根本不需要算 softmax------这就是它被列为快路径条件的原因。

为什么要专门判 "greedy"

完整采样路径需要:softmax → 排序 → top-k 截断 → top-p 截断 → 重新归一化 → 按概率采样 。在 GPU 上这一整套即使融合(阶段 2 用 flashinfer.sampling),也比直接 argmax 慢得多。

判一个布尔条件就能跳过整套流程,对 batch 中那些用 temperature=0 的请求(代码生成、工具调用、结构化输出)是非常划算的。这个 is_greedy property 就是阶段 2 引入 fused sampler 时,用来分流走快路径还是慢路径的入口。

具体怎么走?is_greedy 为真时,采样器直接执行 next_token = logits.argmax(-1),跳过 softmax → top-k → top-p → 重采样这一整套流程;为假时才走完整路径。

python 复制代码
"""§5 SamplingParams:定义 + 几种参数组合下的 is_greedy"""
from dataclasses import dataclass

@dataclass
class SamplingParams:
    temperature: float = 0.0    # 0 表示 deterministic
    top_k: int = -1             # -1 表示不裁
    top_p: float = 1.0          # 1.0 表示不裁
    ignore_eos: bool = False
    max_tokens: int = 1024

    @property
    def is_greedy(self) -> bool:
        # 三个子句联合判定 argmax 快路径:
        # T <= 0 或 top_k == 1 都退化为 argmax;但 top_p < 1 仍需走完整路径
        return (self.temperature <= 0.0 or self.top_k == 1) and self.top_p == 1.0


cases = [
    ("默认 (T=0)",          SamplingParams()),
    ("T=0.7",               SamplingParams(temperature=0.7)),
    ("T=0, top_p=0.9",      SamplingParams(top_p=0.9)),
    ("top_k=1",             SamplingParams(top_k=1)),
]
for name, sp in cases:
    print(f"{name:<22} → is_greedy = {sp.is_greedy}")
ini 复制代码
默认 (T=0)               → is_greedy = True
T=0.7                  → is_greedy = False
T=0, top_p=0.9         → is_greedy = False
top_k=1                → is_greedy = True

6. Req 如何拼进 Batch

到这里 Req 单条已经够用。但生产场景里 Scheduler 一步要把几十、几百条 Req 同时塞给模型------这就需要一个把多条 Req 装在一起的容器:Batch

一条 Req 进 Batch 时,需要补什么字段

模型 forward 只接受张量,不接受 Req 列表。所以组 Batch 这一步,本质上是把每条 Req 的状态翻译成跨请求共享的张量:

  • input_ids:把所有 Req 的 token id 在第 0 维首尾拼接,得到一个 1D 张量 (packed 布局,见下图)
  • cu_seqlens :累积偏移 [0, len₁, len₁+len₂, ...],告诉 attention 每条请求从哪开始、到哪结束
  • page_table:每条 Req 在 KV cache 里占哪几个 page (KV 分页本身留到阶段 2 引入)
  • out_loc:每条 Req 这一步算出来的新 token,要写回到 KV cache 的哪个位置

可以看到职责分工:每条 Req 自己 只描述"我有这些 token、生成进度到哪了";Batch 才描述"这一批请求在张量层面长什么样"。这种分工正是阶段 2 第 6 篇要完成的工作。

为什么不直接 padding 成二维矩阵

朴素方案是按最长请求把短请求补 <pad> 凑成 [batch_size, max_len] 的 2D 张量。三条 Req(长 4 / 6 / 3)按这种方案排出来就是 [3, 6] = 18 个格子,其中只有 13 个是真实 token,剩下 5 个(28%)是 padding:

不用 padding 的原因有三:

  1. 算力浪费:模型要给每个 padding 位置算 Q/K/V 投影和 MLP------即便最后被 mask 掉,前面这些计算实打实跑过了
  2. attention 是平方代价 :padding 位置的 query 也要对整段 max_len 算注意力分数,而不是它"应该"的较短长度,放大效应明显
  3. 真实方差更大 :生产环境里同一批的 prompt 长度可能从 50 到 50000,按最长那条 padding,90%+ 的位置都浪费了

而 packed 方案的 input_ids 长度严格等于真实 token 数之和(上图就是 4 + 6 + 3 = 13);cu_seqlens 替代"行号"告诉 attention 每条请求的边界。

FlashAttention 的 flash_attn_varlen_func 就是为这种 1D 变长输入设计的:kernel 用 cu_seqlens 把 Q × Kᵀ 隐式视作一个"块对角"结构,只算每个块内部的 causal 下三角,跨请求的位置完全跳过、不算、不读、不写。

具体 kernel 实现 (Q/K 怎么切分到 SM、softmax 怎么 online 计算) 在阶段 2 第 9 篇接 FlashAttention 时展开,这里只需要建立直觉:变长 packed 输入 + cu_seqlens 边界 = 一个有效计算只在块对角的注意力,这是它和 padded 2D 方案在 kernel 层面的根本区别。

本阶段的退化形式:batch_size = 1

阶段 1 不展开 Batch------Engine 拿到一条 Req 直接包装成单元素 batch 送入模型:cu_seqlens 退化成 [0, prompt_len],page_table 退化成单行,packed 布局退化成单段。这种"先单条跑通,再扩 batch"的节奏,让我们能先把模型 forward 走通,再回头处理多请求的拼装。

小结

一条推理请求的状态,可以浓缩到一个 dataclass 里:

  • 四个长度字段 同时支撑 prefill 与 decode 两阶段,关键是 extend_len 在两阶段含义不同但形式统一
  • complete_one / append_host 把 device 与 host 的更新职责拆开,为后续的批量优化留口子
  • SamplingParams.is_greedy 是一个布尔表达式,但决定了采样要走 argmax 一行,还是要走完整 softmax+top-k+top-p 一整套

下一篇预告

3 一个可运行的 Decoder Layer :用最朴素的 nn.Module 拼出一个 Qwen3 layer(QKV → RoPE → SDPA → O → MLP),attention 用 torch.nn.functional.scaled_dot_product_attention ------不做 TP、不做 fused kernel。这版代码会在阶段 3 整个推翻重写,但能让我们立刻看到模型跑起来的样子。

相关推荐
AskHarries8 小时前
OpenClaw 是什么?为什么它不是普通 AI Agent
人工智能·后端·程序员
sali-tec8 小时前
C# 基于OpenCv的视觉工作流-章75-线-线角度
图像处理·人工智能·opencv·算法·计算机视觉
AskHarries8 小时前
如何判断一个需求是真需求
人工智能·程序员·产品
ccice018 小时前
硬核教程:用Gemini编排多任务办公智能体,实现周报、数据表与行程单全自动生成(国内免费镜像方案)
人工智能·自动化
liux35288 小时前
第2章:核心功能篇 —— 记忆系统,让 AI 越用越懂你
人工智能·hermes
随身数智备忘录8 小时前
安全生产法详解:安全生产法如何规范企业安全管理行为?
大数据·人工智能
weixin_495248408 小时前
2026年AI视频翻译配音工具深度评测:多角色识别能力对比
人工智能·音视频
数智工坊8 小时前
PyCharm 运行 Python 脚本总自动进 Test 模式?附 RT-DETRv2 依赖缺失终极排坑
开发语言·ide·人工智能·python·pycharm
大熊背8 小时前
Binning模式下和Normal模式下加权平均亮度差异分析以及优化
人工智能·算法·自动曝光