【大模型推理】ScheduleBatch 学习

很好问题!让我详细解释EXTEND和预填充(prefill)的概念,以及它们与解码(decode)的区别。

1. EXTEND 的含义

定义

EXTEND模式指的是扩展序列的操作,即处理新的输入token并将其KV缓存写入缓存池。

具体场景

python 复制代码
# 场景1: 新请求的首次处理
输入: "Hello, how are you?"
操作: 将整个输入序列的KV缓存写入KV缓存池

# 场景2: 长文本的分块处理  
输入: "This is a very long document that needs to be processed in chunks..."
操作: 分块处理,每次处理一部分token

# 场景3: 流式输入追加
已处理: "The weather is"
新输入: " nice today"
操作: 只处理新追加的token

在代码中的体现

python 复制代码
def prepare_for_extend(self):
    self.forward_mode = ForwardMode.EXTEND
    
    # 计算需要扩展的token
    input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]
    extend_num_tokens = sum(len(ids) for ids in input_ids)
    
    # 这些token都是新的,需要计算并缓存它们的KV

2. Prefill (预填充) vs EXTEND

传统意义上的Prefill

Prefill 通常指首次填充,即处理整个输入提示(prompt):

python 复制代码
# 传统Prefill
输入: "Translate the following English to French: 'Hello world'"
输出: 无(只填充KV缓存)

SGLang中的EXTEND概念

在SGLang中,EXTEND是更广义的概念:

python 复制代码
# 包含传统Prefill,但更广泛
class ForwardMode(Enum):
    EXTEND    # 扩展序列(包含预填充和续写)
    DECODE    # 自回归解码生成
    MIXED     # 混合模式
    IDLE      # 空闲

关键区别

方面 传统Prefill SGLang EXTEND
范围 只处理初始prompt 处理任何新token
输出 通常不生成输出 可能生成输出token
缓存 填充初始KV缓存 可能利用已有前缀缓存

3. EXTEND 的具体工作流程

3.1 前缀缓存利用

python 复制代码
def init_next_round_input(self, tree_cache=None):
    self.fill_ids = self.origin_input_ids + self.output_ids
    
    # 查询前缀缓存,找到可重用的部分
    if tree_cache is not None:
        self.prefix_indices, self.last_node = tree_cache.match_prefix(
            rid=self.rid, key=self.adjust_max_prefix_ids()
        )
    
    # 只处理新的部分
    self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

示例

复制代码
已有缓存: [1, 2, 3, 4]  # 前缀索引
完整序列: [1, 2, 3, 4, 5, 6, 7]  # fill_ids
EXTEND部分: [5, 6, 7]  # 只需要处理这些新token

3.2 内存分配策略

python 复制代码
def prepare_for_extend(self):
    # 分配请求槽位
    req_pool_indices = self.alloc_req_slots(len(self.reqs))
    
    # 分配KV缓存位置
    if self.token_to_kv_pool_allocator.page_size == 1:
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
    else:
        # 分页分配
        out_cache_loc = self.alloc_paged_token_slots_extend(...)
    
    # 建立映射关系
    write_req_to_token_pool_triton(...)

4. EXTEND vs DECODE 的对比

4.1 处理模式对比

python 复制代码
# EXTEND 模式 - 并行处理多个token
输入: [token1, token2, token3, ...]  # 可变长度序列
处理: 并行计算所有token的注意力
输出: 可能生成下一个token的概率分布

# DECODE 模式 - 单步自回归
输入: [last_token]  # 单个token(上一步的输出)
处理: 基于整个历史生成下一个token
输出: 下一个token的ID

4.2 计算特征对比

特征 EXTEND模式 DECODE模式
计算复杂度 O(n²) - 全注意力 O(n) - 增量注意力
并行度 高 - 多个token并行 低 - 单个token
内存访问 不规则 - 可变长度 规则 - 固定批次
KV缓存 写入新的缓存位置 读取现有缓存

4.3 实际代码对比

EXTEND 准备
python 复制代码
def prepare_for_extend(self):
    # 处理可变长度序列
    input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]
    extend_num_tokens = sum(len(ids) for ids in input_ids)  # 可变总数
    
    # 扁平化处理
    input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64)
    # 形状: [extend_num_tokens]
DECODE 准备
python 复制代码
def prepare_for_decode(self):
    # 处理固定批次(每个请求1个token)
    bs = len(self.reqs)
    
    # 输入是上一步的输出
    self.input_ids = self.output_ids  # 形状: [bs]
    
    # 每个请求分配1个新KV位置
    self.out_cache_loc = self.alloc_token_slots(bs)  # 形状: [bs]

5. EXTEND 的使用场景

5.1 新请求初始化

python 复制代码
# 用户发送新请求
request = Req(
    rid="req1",
    origin_input_text="What is the capital of France?",
    origin_input_ids=[123, 456, 789, ...],
    sampling_params=sampling_params
)

# 创建批次并准备EXTEND
batch = ScheduleBatch.init_new([request], ...)
batch.prepare_for_extend()  # 首次处理整个prompt

5.2 长文本分块处理

python 复制代码
# 长文档处理
def process_long_document(document_tokens, chunk_size=512):
    for i in range(0, len(document_tokens), chunk_size):
        chunk = document_tokens[i:i+chunk_size]
        
        # 准备EXTEND处理这个分块
        req.fill_ids = existing_tokens + chunk
        req.init_next_round_input(tree_cache)
        
        batch.prepare_for_extend()  # 处理这个分块
        # 执行模型前向传播...

5.3 流式输入追加

python 复制代码
# 流式对话场景
def handle_streaming_input(req, new_input_tokens):
    # 追加新输入到现有序列
    req.output_ids.extend(previous_output)
    req.fill_ids = req.origin_input_ids + req.output_ids + new_input_tokens
    
    # 只EXTEND新输入的部分
    req.init_next_round_input(tree_cache)
    batch.prepare_for_extend()

6. 性能优化考虑

6.1 EXTEND的挑战

python 复制代码
# 问题: 可变长度导致计算效率低
extend_num_tokens = sum(len(ids) for ids in input_ids)  # 可能很大且不规则

# 解决方案: 分块和填充
if enable_chunked_prefill:
    # 将长序列分成固定大小的块
    process_in_chunks(req, chunk_size=256)

6.2 前缀缓存的价值

python 复制代码
def init_next_round_input(self, tree_cache):
    # 利用前缀缓存避免重复计算
    self.prefix_indices, self.last_node = tree_cache.match_prefix(...)
    self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
    
    # 只有 extend_input_len 个token需要实际计算

性能收益

  • 减少计算量:只计算新token
  • 减少内存分配:重用已有KV缓存位置
  • 提高吞吐量:更快的处理速度

7. 实际执行流程示例

完整请求生命周期

python 复制代码
# 1. 新请求 - EXTEND处理整个prompt
batch.prepare_for_extend()
# 输入: [What, is, the, capital, of, France, ?]
# 输出: 第一个生成token的概率分布

# 2. 解码循环 - DECODE模式
for step in range(max_steps):
    batch.prepare_for_decode()
    # 输入: [last_generated_token]
    # 输出: 下一个token的概率分布
    
    # 采样得到新token
    new_token = sample_from_logits(logits)
    req.output_ids.append(new_token)
    
    # 检查是否结束
    req.check_finished()
    if req.finished():
        break

# 3. 如果有新输入 - 再次EXTEND
if has_new_user_input:
    req.init_next_round_input(tree_cache)
    batch.prepare_for_extend()  # 处理新输入

8. 与Profile的关系

Profile模式的作用

PROFILE模式通常用于性能分析,而不是实际推理:

python 复制代码
# Profile模式 - 测量性能,不产生实际输出
if profile_mode:
    # 测量EXTEND操作的性能
    start_time = time.time()
    batch.prepare_for_extend()
    extend_time = time.time() - start_time
    
    # 测量DECODE操作的性能  
    start_time = time.time()
    batch.prepare_for_decode()
    decode_time = time.time() - start_time
    
    logger.info(f"EXTEND: {extend_time:.3f}s, DECODE: {decode_time:.3f}s")

实际使用中的Profile

在真实系统中,Profile可能用于:

  1. 容量规划:测量系统能处理的最大批次大小
  2. 优化调试:识别性能瓶颈
  3. 资源分配:根据EXTEND/DECODE比例调整资源

总结

EXTEND在SGLang中是一个核心概念,它:

  1. 广义的预填充:不仅处理初始prompt,还处理任何新token
  2. 前缀缓存感知:智能利用共享前缀避免重复计算
  3. 性能关键路径:通常比DECODE更计算密集,需要精心优化
  4. 流式处理基础:支持长文本分块和流式输入

与传统的Prefill相比,EXTEND更灵活,支持更复杂的应用场景,是现代LLM推理服务的重要组成部分。

我们继续使用EXTEND模式来处理新的输入,这类似于预填充,但有一些关键区别。

在传统的预填充(prefill)中,我们通常是指处理整个提示(prompt)以生成第一个输出token。而在SGLang中,EXTEND模式用于处理任何新的输入token,这些新输入token可能是:

整个提示(第一次处理)

长提示中的后续分块(当使用分块处理时)

在流式对话中,用户新输入的消息(追加到现有对话之后)

所以,EXTEND模式比传统的预填充更通用。它允许我们利用前缀缓存(如果存在)来避免重复计算,只计算新token的KV缓存。

具体来说,当我们有新的用户输入时,我们会将新输入追加到现有序列中,然后通过EXTEND模式处理这些新token。这与第一次预填充的不同之处在于:

第一次预填充(整个提示)时,前缀缓存是空的,所以我们需要计算整个提示的KV缓存。

当处理后续分块或新输入时,我们可能已经缓存了部分序列(前缀),所以只需要计算新追加的token。

因此,EXTEND模式是一种增量式的预填充,它利用已有的缓存来减少计算量。

在代码中,我们通过以下步骤准备EXTEND:

将新输入追加到请求的fill_ids(即完整序列:原始输入+已生成输出+新输入)

调用init_next_round_input来更新前缀缓存和计算需要扩展的token数量(extend_input_len)

调用prepare_for_extend来分配KV缓存位置并构建模型输入。

这样,模型只会处理新添加的token,而不会重新处理整个序列。

举个例子:

假设我们有一个请求,已经处理了初始提示:"Hello, how are you?",并生成了一些输出:"I'm fine."。现在用户又输入了新的消息:"What about you?"。

那么,完整的序列是:"Hello, how are you? I'm fine. What about you?"。

但是,由于我们已经缓存了"Hello, how are you? I'm fine."的KV,所以我们只需要处理新追加的"What about you?"部分。

这就是EXTEND模式与传统预填充的区别:它允许我们利用缓存来增量处理,而不是每次都要处理整个序列。

因此,在代码中,当我们有新的用户输入时,我们会再次进入EXTEND模式,但只处理新输入的部分,从而大大提高效率。

非常好问题!让我详细解释"再次EXTEND"的含义,以及它与传统Prefill的根本区别。

非常好问题!让我详细解释"再次EXTEND"的含义,以及它与传统Prefill的根本区别。

1. 场景示例:多轮对话

用户与AI的对话流程

python 复制代码
# 第一轮对话
用户: "What is the capital of France?"  # 初始输入
AI: "The capital of France is Paris."

# 第二轮对话  
用户: "And what about Germany?"  # 新输入
AI: "The capital of Germany is Berlin."

在系统中的处理流程

python 复制代码
# 第一轮:处理初始问题
req1 = Req(rid="conv1", input_text="What is the capital of France?")
batch.prepare_for_extend()  # 传统Prefill:处理整个问题
# → 生成回答: "The capital of France is Paris."

# 第二轮:处理追加问题
req1.add_input("And what about Germany?")  # 追加新输入
req1.init_next_round_input(tree_cache)
batch.prepare_for_extend()  # 再次EXTEND:只处理新问题
# → 生成回答: "The capital of Germany is Berlin."

2. 再次EXTEND vs 传统Prefill的关键区别

2.1 计算范围不同

传统Prefill(首次EXTEND)
python 复制代码
# 处理整个序列
输入序列: [What, is, the, capital, of, France, ?]
需要计算的token: 7个
KV缓存写入: 7个新位置
前缀缓存: [] (空的,没有可重用的)
再次EXTEND(增量EXTEND)
python 复制代码
# 只处理新增部分
完整序列: [What, is, the, capital, of, France, ?, The, capital, of, France, is, Paris, ., And, what, about, Germany, ?]
                            ↑ 已缓存部分 ↑                ↑ 已生成部分 ↑          ↑ 新增输入部分 ↑
前缀缓存: [What, is, the, capital, of, France, ?, The, capital, of, France, is, Paris, .]
需要计算的token: [And, what, about, Germany, ?]  # 只有5个
KV缓存写入: 5个新位置

2.2 前缀缓存利用

代码层面的差异
python 复制代码
def init_next_round_input(self, tree_cache):
    # 构建完整序列
    self.fill_ids = self.origin_input_ids + self.output_ids + new_input_ids
    
    # 关键:查询前缀缓存,找到可重用的部分
    if tree_cache is not None:
        self.prefix_indices, self.last_node = tree_cache.match_prefix(
            key=self.adjust_max_prefix_ids()  # 查询已有缓存
        )
    
    # 只计算新增部分
    self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
    # 首次Prefill: extend_input_len = 整个序列长度
    # 再次EXTEND: extend_input_len = 新增部分长度

2.3 内存分配差异

首次Prefill的内存分配
python 复制代码
def prepare_for_extend(self):
    # 整个序列都是新的
    input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]  # 整个序列
    extend_num_tokens = sum(len(ids) for ids in input_ids)  # 可能很大
    
    # 需要分配大量KV缓存位置
    out_cache_loc = self.alloc_token_slots(extend_num_tokens)
再次EXTEND的内存分配
python 复制代码
def prepare_for_extend(self):
    # 只有新增部分是新的
    input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]  # 只有新增部分
    extend_num_tokens = sum(len(ids) for ids in input_ids)  # 通常较小
    
    # 只需要分配少量KV缓存位置
    out_cache_loc = self.alloc_token_slots(extend_num_tokens)  # 更少的内存需求

3. 具体技术实现细节

3.1 前缀缓存的工作原理

python 复制代码
# 假设的对话历史
history = {
    "session1": [
        [1, 2, 3, 4],      # "What is the capital"
        [1, 2, 3, 4, 5],   # "What is the capital of"  
        [1, 2, 3, 4, 5, 6] # "What is the capital of France"
    ]
}

def match_prefix(self, key):
    """在缓存树中查找最长的匹配前缀"""
    # key = [1, 2, 3, 4, 5, 6, 10, 11]  # 完整序列
    # 返回: [1, 2, 3, 4, 5, 6] 的缓存位置
    # 剩余: [10, 11] 需要处理
    return prefix_indices, last_node

3.2 序列状态的变化

首次Prefill后
python 复制代码
req = Req(...)
# 初始状态
req.origin_input_ids = [1, 2, 3, 4, 5, 6]  # "What is the capital of France?"
req.output_ids = [7, 8, 9, 10, 11, 12]     # "The capital of France is Paris."
req.fill_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

# 前缀缓存状态
req.prefix_indices = [100, 101, 102, 103, 104, 105]  # 对应KV缓存位置
收到新输入后
python 复制代码
# 用户新输入
new_input = [13, 14, 15, 16, 17]  # "And what about Germany?"

# 更新序列
req.fill_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
#                                ↑ 已缓存 ↑              ↑ 新输入 ↑

# 重新查询前缀缓存
req.init_next_round_input(tree_cache)
# prefix_indices = [100, 101, 102, 103, 104, 105, ...]  # 可能扩展到输出部分
# extend_input_len = 5  # 只有新输入需要处理

4. 性能优势分析

4.1 计算复杂度对比

传统方法(无缓存)
python 复制代码
# 每次都要处理整个历史
总计算量 = O(L²)  # L是累计序列长度
第二轮: O(17²) = 289
第三轮: O(22²) = 484  # 越来越慢
再次EXTEND(有缓存)
python 复制代码
# 只处理新增部分
总计算量 = O(L_existing²) + O(L_new²)  # 但L_existing的KV已缓存
第二轮: O(5²) = 25  # 只处理5个新token
第三轮: O(5²) = 25  # 保持稳定

4.2 内存使用对比

KV缓存重用
python 复制代码
# 首次Prefill
KV缓存分配: 位置[100-111]  # 12个位置

# 再次EXTEND  
KV缓存分配: 位置[112-116]  # 只新增5个位置
# 位置[100-111] 被重用,不需要重新计算

5. 实际应用场景

5.1 多轮对话系统

python 复制代码
class ChatSession:
    def __init__(self, session_id):
        self.session_id = session_id
        self.req = None
        
    def add_message(self, user_input):
        if self.req is None:
            # 首次Prefill
            self.req = Req(rid=self.session_id, input_text=user_input)
            batch.prepare_for_extend()
        else:
            # 再次EXTEND
            self.req.add_input(user_input)
            self.req.init_next_round_input(tree_cache)
            batch.prepare_for_extend()
        
        # 执行推理并返回结果
        return self.execute_inference()

5.2 长文档处理

python 复制代码
def process_long_document(document, chunk_size=512):
    req = Req(rid="doc1", input_text=document[:chunk_size])
    
    # 首次Prefill第一个分块
    batch.prepare_for_extend()
    
    # 处理剩余分块
    for i in range(chunk_size, len(document), chunk_size):
        chunk = document[i:i+chunk_size]
        req.add_input(chunk)
        req.init_next_round_input(tree_cache)
        batch.prepare_for_extend()  # 再次EXTEND处理新分块

5.3 流式输入应用

python 复制代码
# 实时翻译或语音识别
def stream_processor():
    req = Req(rid="stream1")
    
    while True:
        new_tokens = get_new_tokens_from_stream()
        if new_tokens:
            req.add_input(new_tokens)
            req.init_next_round_input(tree_cache)
            
            if req.extend_input_len > 0:  # 有新内容需要处理
                batch.prepare_for_extend()
                yield generate_output()

6. 与传统方法的根本区别

6.1 状态持续性

传统Prefill(无状态)
python 复制代码
# 每次请求都是独立的
def handle_request(request_text):
    req = Req(input_text=request_text)  # 新建请求
    batch.prepare_for_extend()  # 完整处理
    # 完成后丢弃所有状态
再次EXTEND(有状态)
python 复制代码
# 保持会话状态
def handle_conversation(session_id, new_input):
    req = get_session(session_id)  # 获取已有请求
    req.add_input(new_input)       # 追加新输入
    req.init_next_round_input(tree_cache)  # 智能确定需要处理的部分
    batch.prepare_for_extend()     # 增量处理
    # 保持所有状态供下次使用

6.2 缓存策略

传统:每次重新计算
复制代码
用户: "A B C" → 计算: A, B, C
AI: "X Y Z"
用户: "D E"   → 计算: A, B, C, X, Y, Z, D, E  # 重复计算!
再次EXTEND:智能重用
复制代码
用户: "A B C" → 计算: A, B, C → 缓存: [A, B, C]
AI: "X Y Z"   → 计算: X, Y, Z → 缓存: [A, B, C, X, Y, Z]  
用户: "D E"   → 重用: [A, B, C, X, Y, Z] → 只计算: D, E

7. 代码层面的具体差异

7.1 init_next_round_input 的关键作用

python 复制代码
def init_next_round_input(self, tree_cache, enable_hierarchical_cache=False):
    # 更新完整序列
    self.fill_ids = self.origin_input_ids + self.output_ids
    
    # 关键区别点:
    if tree_cache is not None:
        # 再次EXTEND:这里能找到大量匹配的前缀
        # 首次Prefill:这里可能只能找到很少或没有匹配
        self.prefix_indices, self.last_node = tree_cache.match_prefix(
            key=self.adjust_max_prefix_ids()
        )
    
    # 计算需要实际处理的部分
    self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
    # 首次Prefill: extend_input_len ≈ 整个序列长度
    # 再次EXTEND: extend_input_len ≪ 整个序列长度

7.2 prepare_for_extend 的实际影响

python 复制代码
def prepare_for_extend(self):
    # 输入数据收集 - 关键差异点
    input_ids = [r.fill_ids[len(r.prefix_indices):] for r in self.reqs]
    
    # 首次Prefill: input_ids 包含整个prompt
    # 再次EXTEND: input_ids 只包含新增部分
    
    extend_num_tokens = sum(len(ids) for ids in input_ids)
    # 这个数字在再次EXTEND时显著变小
    
    # 内存分配相应减少
    out_cache_loc = self.alloc_token_slots(extend_num_tokens)

总结

再次EXTEND 与传统Prefill的根本区别:

方面 传统Prefill 再次EXTEND
计算范围 整个输入序列 只新增部分
缓存利用 无或很少 大量重用已有缓存
内存分配 分配整个序列 只分配新增部分
性能特征 计算密集型 相对轻量
应用场景 新请求初始化 会话延续、流式输入

核心价值 :再次EXTEND使得LLM能够以增量方式处理输入,而不是每次都重新处理整个历史,这对于多轮对话、长文档处理和流式应用至关重要,能显著提升性能和用户体验。

在 SGLang(以及大多数现代 LLM 推理引擎,如 vLLM)中,extend 模式 (也常被称为 prefill 模式 )是指 对一个新请求(或新 chunk)的输入序列进行首次 KV Cache 填充和 logits 计算的过程


一、核心定义

  • ForwardMode.EXTEND :表示当前 batch 中的所有请求都在执行 prefill(首次处理)
  • 它对应的是 从输入 token 到生成第一个输出 token 之前的完整前向计算

✅ 简单说:extend = prefill,只是 SGLang 用词为 "extend"。


二、为什么叫 "extend"?

因为:

  • 每个请求可能已有 共享前缀 (来自 RadixCache / ChunkCache
  • 当前操作是 "扩展"这个前缀 ,把 尚未计算的 token 部分 (即 extend_input_len)填入 KV Cache

例如:

  • 请求已有前缀 100 个 token(已缓存)
  • 输入总长 120 个 token
  • extend_input_len = 20,需要 extend 这 20 个 token

三、与 decode 模式的对比

特性 extend(prefill) decode(自回归生成)
输入长度 可变(1 ~ 几千) 固定为 1(每个请求 1 个 token)
KV Cache 操作 写入多个新 token 的 K/V 写入 1 个新 token 的 K/V
计算量 大(O(n²) Attention) 小(O(n) Attention)
batch shape 不规则(ragged) 规则(可启用 CUDA Graph)
是否支持 CUDA Graph ❌ 否 ✅ 是
典型场景 请求首次进入、chunked prefill 生成第 2、3、4... 个 token

四、在 SGLang 中的关键字段(ForwardBatch

forward_mode == ForwardMode.EXTEND 时,以下字段有效:

  • extend_num_tokens: 本次 batch 总共要 extend 的 token 数
  • extend_seq_lens: 每个请求要 extend 的 token 数(如 [20, 30, 10]
  • extend_prefix_lens: 每个请求已缓存的前缀长度(如 [100, 50, 0]
  • positions: 每个 extend token 的绝对位置(如 [100,101,...,119, 50,51,...,79, 0,1,...,9]
  • extend_start_loc: 每个请求在 flatten input_ids 中的起始偏移

五、工作流程示例

python 复制代码
# 请求:input_ids = [1,2,3,4,5],无前缀缓存
req.prefix_indices = []          # 已缓存 0 个
req.fill_ids = [1,2,3,4,5]       # 完整序列
req.extend_input_len = 5         # 需要 extend 5 个 token

# 调度器分配 KV slots
out_cache_loc = [100, 101, 102, 103, 104]

# 写入 ReqToTokenPool
req_to_token_pool[req.req_pool_idx, 0:5] = [100,101,102,103,104]

# 构建 ForwardBatch
forward_batch = ForwardBatch(
    forward_mode=ForwardMode.EXTEND,
    input_ids=[1,2,3,4,5],
    positions=[0,1,2,3,4],
    extend_prefix_lens=[0],
    extend_seq_lens=[5],
    out_cache_loc=[100,101,102,103,104],
    ...
)

# ModelRunner 执行前向
logits = model_runner.forward(forward_batch)  # 输出 5 个位置的 logits
next_token = sample(logits[-1])  # 通常只采样最后一个

六、相关模式

SGLang 还定义了其他与 extend 相关的模式:

  • ForwardMode.MIXED:batch 中 同时包含 extend 和 decode 请求(用于 overlap schedule)
  • ForwardMode.DRAFT_EXTEND:在 EAGLE 草稿模型 中执行 extend
  • ForwardMode.TARGET_VERIFY:在 主模型 中验证草稿 token(本质是 extend 整个序列)

总结

术语 含义
extend 对请求的 未缓存部分 进行首次前向计算(= prefill)
目的 填充 KV Cache,计算 logits,为 decode 阶段做准备
特点 输入长度可变、计算量大、不支持 CUDA Graph
触发时机 请求首次进入、chunked prefill 的每个 chunk

理解 extend 模式,是掌握 LLM 推理中 prefill/decode 分离优化 的关键。

init_next_round_input 是 SGLang 中 Req(请求)类的一个关键方法,它的核心作用是:

为下一轮调度/前向计算准备输入状态,包括更新完整 token 序列(fill_ids)、匹配前缀缓存(prefix_indices),并计算需要新处理的 token 数量(extend_input_len


一、函数签名与上下文

python 复制代码
def init_next_round_input(
    self,
    tree_cache: Optional[BasePrefixCache] = None,
    enable_hierarchical_cache=False,
):
  • 调用时机 :每次请求即将进入 prefill 或 decode 阶段前 ,由调度器(Scheduler)调用。
  • 目的:让请求对象"知道"自己当前的状态,以便调度器决定如何分配内存、构建 batch。

二、逐行详解

1. 重置 placeholder embeddings(多模态支持)

python 复制代码
self.input_placeholder_embs = self.original_input_placeholder_embs
  • 如果请求包含多模态 placeholder(如 <image>),恢复原始嵌入,避免上一轮修改污染。

2. 构建完整 token 序列 fill_ids

python 复制代码
self.fill_ids = self.origin_input_ids + self.output_ids
  • origin_input_ids:用户原始输入 token(如 prompt)
  • output_ids:模型已生成的 token
  • fill_ids = 完整上下文 = [prompt tokens] + [generated tokens]

✅ 这是当前请求的 完整逻辑序列,用于后续 KV Cache 匹配和 extend 计算。


3. 前缀缓存匹配(Radix/Chunk/HiRadix Cache)

python 复制代码
if tree_cache is not None:
    if enable_hierarchical_cache:
        self.prefix_indices, self.last_node, self.last_node_global = (
            tree_cache.match_prefix(
                key=self.adjust_max_prefix_ids(), include_evicted=True
            )
        )
    else:
        self.prefix_indices, self.last_node = tree_cache.match_prefix(
            rid=self.rid, key=self.adjust_max_prefix_ids()
        )
关键点:
  • tree_cache.match_prefix(...):在前缀缓存中查找 fill_ids 的最长匹配前缀。
  • 返回:
    • prefix_indices:匹配到的 token 在 KV Cache 中的物理 slot 索引列表 (如 [100,101,102]
    • last_node:缓存树中对应的节点(用于后续 evict/lock)

✅ 这一步实现了 KV Cache 前缀共享,避免重复 prefill。


4. 处理 Hierarchical Cache 的特殊情况

python 复制代码
elif enable_hierarchical_cache:
    while self.last_node.evicted:
        # 如果 last_node 被驱逐,回退 prefix_indices
        self.prefix_indices = self.prefix_indices[:-len(self.last_node.host_value)]
        self.last_node = self.last_node.parent
  • HiRadixCache 中,部分节点可能被换出到 CPU。
  • 如果 last_node 已被 evict,则逐步回退,直到找到未 evict 的节点。

5. 计算需要新处理的 token 数(核心输出)

python 复制代码
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
  • extend_input_len = 需要新计算 KV Cache 的 token 数
  • 例如:
    • fill_ids = [1,2,3,4,5](总长 5)
    • prefix_indices = [100,101](前 2 个已缓存)
    • extend_input_len = 3(需处理 token 3,4,5)

✅ 调度器将根据此值决定分配多少 out_cache_loc(KV Cache slot)。


三、辅助方法:adjust_max_prefix_ids()

python 复制代码
def adjust_max_prefix_ids(self):
    self.fill_ids = self.origin_input_ids + self.output_ids
    input_len = len(self.fill_ids)
    max_prefix_len = input_len - 1  # 至少留 1 个 token 用于生成 logits
    if self.return_logprob:
        max_prefix_len = min(max_prefix_len, self.logprob_start_len)
    return self.fill_ids[:max_prefix_len]
  • 作用:限制前缀匹配的最大长度,避免影响 logprob 计算或生成。
  • 关键逻辑
    • 如果需要返回 logprob,前缀不能超过 logprob_start_len
    • 总是保留至少 1 个 token 用于当前 step 的 logits 计算。

四、典型工作流程示例

场景:用户输入 "Hello",模型已生成 " world"

  1. 初始状态

    • origin_input_ids = [100, 101]("Hello")
    • output_ids = [200, 201](" world")
    • fill_ids = [100,101,200,201]
  2. 调用 init_next_round_input()

    • tree_cache.match_prefix([100,101,200,201]) → 返回 prefix_indices = [50,51,52](前 3 个已缓存)
    • extend_input_len = 4 - 3 = 1
  3. 调度器行为

    • 分配 1 个新 KV slot(如 out_cache_loc = [53]
    • 构建 input_ids = [201](最后一个 token 作为 decode 输入)
    • 写入 req_to_token_pool[req_idx, 3] = 53

五、为什么这个函数重要?

  1. 实现前缀共享:大幅减少重复 prefill,提升吞吐。
  2. 支持 chunked prefill:大 prompt 分块处理时,每块都调用此函数。
  3. 支持 retract:OOM 时回退请求后,重新调用此函数恢复状态。
  4. logprob 兼容:确保 logprob 计算区域不被错误缓存。

总结

步骤 作用
构建 fill_ids 获取完整上下文
匹配 prefix_indices 查找已缓存的前缀
计算 extend_input_len 确定需要新处理的 token 数
处理 hierarchical cache 支持 CPU/GPU 混合缓存

init_next_round_inputSGLang 高效调度与缓存机制的核心枢纽,它让每个请求"自知状态",从而实现灵活、高性能的批处理。

这里明确说是为了绕过logprob计算中的bug,所以暂时确保每个请求至少有一个token(不被缓存?)。也就是说,我们最多只能缓存到倒数第二个token,最后一个token不被缓存。这样,在扩展阶段,我们就会处理最后一个token(以及之后可能的新token)。但是,为什么这样做能绕过logprob计算的bug呢?我们不得而知。

总结

adjust_max_prefix_ids函数的作用是确定可以用于前缀缓存的最大token序列。它返回一个子序列,这个子序列是当前完整序列(输入+输出)的前缀,并且这个前缀的长度最多为input_len-1,即至少保留一个token不被缓存。

这样做的原因可能是由于logprob计算中的bug,暂时采用这种策略。在未来,这个限制可能会被放宽。

举例

假设当前请求的fill_ids是[1,2,3,4,5](长度为5)。

那么,adjust_max_prefix_ids返回的前缀可能是[1,2,3,4](长度为4),即最多缓存前4个token,第5个token不被缓存。

在接下来的扩展阶段,我们将处理第5个token(以及可能的新token)。

很好问题!我来详细解释adjust_max_prefix_ids的作用,特别是"确保至少1个token用于生成"的含义。

1. 函数完整代码

python 复制代码
def adjust_max_prefix_ids(self):
    self.fill_ids = self.origin_input_ids + self.output_ids
    input_len = len(self.fill_ids)

    # FIXME: To work around some bugs in logprob computation, we need to ensure each
    # request has at least one token. Later, we can relax this requirement and use `input_len`.
    max_prefix_len = input_len - 1

    if self.sampling_params.max_new_tokens > 0:
        # Need at least one token to compute logits
        max_prefix_len = min(max_prefix_len, input_len - 1)

    if self.return_logprob:
        max_prefix_len = min(max_prefix_len, self.logprob_start_len)

    max_prefix_len = max(max_prefix_len, 0)
    return self.fill_ids[:max_prefix_len]

2. "确保至少1个token用于生成"的含义

2.1 LLM生成的基本原理

在LLM中,生成下一个token需要:

  • 输入: 前N个token
  • 输出: 第N+1个token的概率分布
python 复制代码
# 生成过程示例
输入: [1, 2, 3, 4]    # "The cat sat on"
输出: [5]             # "the" 的概率分布

# 如果所有token都被缓存为前缀,就没有token用于生成了

2.2 具体例子说明

例子1:正常生成情况
python 复制代码
# 当前序列
fill_ids = [1, 2, 3, 4, 5]  # "The cat sat on the"
input_len = 5

# 调整后的最大前缀长度
max_prefix_len = input_len - 1  # = 4

# 这意味着:
# 前缀缓存最多可以缓存: [1, 2, 3, 4]  # "The cat sat on"
# 必须保留: [5]  # "the" 用于生成下一个token
例子2:如果缓存所有token会怎样
python 复制代码
# 错误情况:缓存所有token
fill_ids = [1, 2, 3, 4, 5]  # 完整序列
max_prefix_len = 5  # 错误!缓存了所有token

# 问题:没有token用于计算下一个token的概率
# 输入序列: []  # 空,因为没有未缓存的token
# 无法生成下一个token!

3. 函数执行步骤详解

3.1 基础限制

python 复制代码
max_prefix_len = input_len - 1  # 核心限制:保留最后1个token

为什么是input_len - 1

  • 序列有input_len个token
  • 需要至少1个token作为生成下一个token的输入
  • 所以最多只能缓存input_len - 1个token

3.2 考虑生成需求

python 复制代码
if self.sampling_params.max_new_tokens > 0:
    # Need at least one token to compute logits
    max_prefix_len = min(max_prefix_len, input_len - 1)

作用:如果还需要生成新token,必须确保有token用于计算logits。

3.3 考虑logprob计算

python 复制代码
if self.return_logprob:
    max_prefix_len = min(max_prefix_len, self.logprob_start_len)

logprob计算的特殊要求

  • logprob计算需要知道每个token的前一个token
  • logprob_start_len指定从哪个位置开始计算logprob
  • 不能缓存超过这个位置的token

3.4 边界保护

python 复制代码
max_prefix_len = max(max_prefix_len, 0)  # 确保非负

4. 具体场景分析

场景1:新请求的预填充

python 复制代码
# 新请求,没有输出
req.origin_input_ids = [1, 2, 3, 4]  # "What is AI"
req.output_ids = []                   # 无输出
req.fill_ids = [1, 2, 3, 4]
input_len = 4

# adjust_max_prefix_ids 过程:
max_prefix_len = 4 - 1 = 3           # 保留最后1个token
# 返回: [1, 2, 3]  # 最多缓存前3个token

为什么这样设计?

  • 缓存前3个token的KV值:[1, 2, 3] → "What is AI"
  • 使用第4个token "AI" 生成第一个输出token

场景2:生成过程中的调整

python 复制代码
# 生成过程中
req.origin_input_ids = [1, 2, 3]     # "Hello world"
req.output_ids = [4, 5]              # "How are"
req.fill_ids = [1, 2, 3, 4, 5]
input_len = 5

# adjust_max_prefix_ids 过程:
max_prefix_len = 5 - 1 = 4           # 保留最后1个token
# 返回: [1, 2, 3, 4]  # 最多缓存前4个token

作用

  • 缓存:[1, 2, 3, 4] → "Hello world How are"
  • 使用:[5] → "are" 生成下一个token "you"

场景3:logprob计算的影响

python 复制代码
# 需要计算logprob的情况
req.origin_input_ids = [1, 2, 3, 4, 5, 6]  # 长序列
req.output_ids = [7, 8]
req.fill_ids = [1, 2, 3, 4, 5, 6, 7, 8]
input_len = 8
req.return_logprob = True
req.logprob_start_len = 5            # 从第5个token开始计算logprob

# adjust_max_prefix_ids 过程:
max_prefix_len = 8 - 1 = 7           # 基础限制
max_prefix_len = min(7, 5) = 5       # logprob限制更严格
# 返回: [1, 2, 3, 4, 5]  # 最多缓存到logprob_start_len

为什么logprob需要这个限制?

python 复制代码
# logprob计算示例
序列: [1, 2, 3, 4, 5, 6, 7, 8]
logprob_start_len = 5

# 需要计算token 6,7,8的logprob
# 但计算token6的logprob需要token5的隐藏状态
# 如果token5被缓存了,就无法计算token6的logprob

5. 技术原理深度解析

5.1 Transformer的生成机制

python 复制代码
# Transformer生成下一个token的过程
def generate_next_token(sequence):
    # 输入: 整个序列的token [t1, t2, ..., tn]
    # 输出: 下一个token tn+1的概率分布
    
    # 1. 计算序列的隐藏状态
    hidden_states = transformer(sequence)  # 形状: [n, hidden_size]
    
    # 2. 只使用最后一个token的隐藏状态预测下一个token
    last_hidden = hidden_states[-1]       # 形状: [hidden_size]
    next_token_logits = lm_head(last_hidden)  # 形状: [vocab_size]
    
    return next_token_logits

关键点 :生成只需要最后一个token的隐藏状态 ,但计算这个隐藏状态需要所有前序token的KV缓存

5.2 前缀缓存与生成的协调

python 复制代码
# 有前缀缓存时的生成过程
def generate_with_prefix_cache(sequence, prefix_indices):
    # sequence: 完整序列 [1, 2, 3, 4, 5]
    # prefix_indices: 已缓存的KV位置 [100, 101, 102] (对应token 1,2,3)
    
    # 1. 从缓存中读取前3个token的KV
    k_cache = get_k_cache(prefix_indices)  # 形状: [3, heads, dim]
    v_cache = get_v_cache(prefix_indices)
    
    # 2. 只计算后2个token的KV并生成
    new_tokens = sequence[3:]  # [4, 5]
    new_k, new_v = compute_kv(new_tokens)
    
    # 3. 合并KV缓存进行注意力计算
    full_k = concat(k_cache, new_k)
    full_v = concat(v_cache, new_v)
    
    # 4. 使用最后一个token生成
    last_hidden = attention_with_kv_cache(full_k, full_v, sequence)
    next_token_logits = lm_head(last_hidden[-1])
    
    return next_token_logits

6. 实际应用中的影响

6.1 对性能的影响

python 复制代码
# 序列长度: 1000个token
# 情况A: 缓存999个token,计算1个新token
计算量 ≈ O(1²)  # 主要计算1个新token

# 情况B: 缓存0个token,计算1000个token  
计算量 ≈ O(1000²)  # 计算整个序列

# 性能差异: 1000000倍!

6.2 对内存的影响

python 复制代码
# KV缓存内存占用
缓存999个token: 需要999个KV位置
缓存0个token: 需要0个KV位置(但计算代价高)

# 权衡: 用内存换计算速度

7. 特殊情况处理

7.1 极短序列

python 复制代码
# 只有1个token的情况
req.fill_ids = [1]    # 单个token
input_len = 1

# adjust_max_prefix_ids:
max_prefix_len = 1 - 1 = 0  # 不能缓存任何token
# 返回: []  # 空列表

处理:必须计算整个序列,因为没有token可用于生成。

7.2 边界条件保护

python 复制代码
max_prefix_len = max(max_prefix_len, 0)  # 确保非负

作用:防止在空序列或极短序列时出现负数。

8. 总结

adjust_max_prefix_ids的核心作用是:

  1. 生成保障:确保至少1个token可用于生成下一个token
  2. 性能优化:最大化前缀缓存的使用,减少计算量
  3. 功能协调:平衡生成需求与logprob计算需求
  4. 边界保护:处理各种极端序列长度情况

关键理解:在LLM推理中,我们不能把所有token都缓存为前缀,必须保留至少一个token作为"种子"来生成后续内容。这个函数就是智能地确定可以安全缓存的最大前缀长度,在性能和功能之间找到最佳平衡点。

在你提供的 SGLang 代码中,并没有一个名为 alloc_token_slot 的独立函数,但有一个非常关键且功能完全对应的方法:

ScheduleBatch.alloc_token_slots(num_tokens: int, backup_state: bool = False)

这是你实际想了解的函数。下面我将详细讲解它的 作用、实现逻辑、调用时机和使用方式


一、函数签名与目的

python 复制代码
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):

功能:

num_tokens 个新 token 分配 KV Cache 中的物理存储位置(slot indices)

  • 返回一个 torch.Tensor,shape [num_tokens],每个元素是全局 token slot ID(如 [100, 101, 102])。
  • 如果内存不足,会尝试从 前缀缓存(Radix/Chunk Cache)驱逐(evict) 可回收的请求以释放空间。
  • 如果仍不足,抛出 RuntimeError(OOM)。

二、内部实现逻辑(逐行解析)

python 复制代码
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
    # 1. 检查内存是否足够;若不足,尝试从 tree_cache 驱逐
    if self.token_to_kv_pool_allocator.available_size() < num_tokens:
        if self.tree_cache is not None:
            self.tree_cache.evict(num_tokens)  # 释放至少 num_tokens 个 slots

    # 2. (可选)备份当前空闲列表状态,用于失败回滚
    if backup_state:
        state = self.token_to_kv_pool_allocator.backup_state()

    # 3. 执行实际分配
    out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

    # 4. 处理分配失败
    if out_cache_loc is None:
        phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
        error_msg = f"{phase_str} out of memory. ..."
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    # 5. 返回结果
    if backup_state:
        return out_cache_loc, state
    else:
        return out_cache_loc

三、核心组件协作

组件 作用
TokenToKVPoolAllocator 管理空闲 slot 列表(free_slots),提供 alloc() / free() 接口
tree_cache(Radix/Chunk/HiRadix Cache) 前缀缓存;当内存不足时,evict(N) 会释放至少 N 个 token 的 KV Cache
out_cache_loc 分配结果:token → KV Cache slot 的映射

设计亮点 :通过 evict 实现 弹性内存管理,避免直接拒绝新请求。


四、调用时机与使用场景

1. Prefill 阶段(prepare_for_extend

python 复制代码
# 计算需要 extend 的总 token 数
extend_num_tokens = sum(len(ids) for ids in input_ids)

# 分配 slots
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
  • 用于为 新请求的 promptchunked prefill 的 chunk 分配 KV Cache。

2. Decode 阶段(prepare_for_decode

python 复制代码
# 每个请求生成 1 个新 token,共 bs 个
self.out_cache_loc = self.alloc_token_slots(bs)
  • 用于为 自回归生成 分配新 token 的 KV Cache。

3. Retract 场景(内存不足时回退请求)

  • 虽然 retract_decode 主要调用 free(),但后续重新调度时会再次调用 alloc_token_slots

五、返回值 out_cache_loc 的用途

分配得到的 out_cache_loc 会被用于:

  1. 写入 ReqToTokenPool(建立请求 ↔ token 位置映射):

    python 复制代码
    req_to_token_pool.write((req_idx, pos_slice), out_cache_loc[...])
  2. 构建 ModelWorkerBatchForwardBatch

    python 复制代码
    ForwardBatch(out_cache_loc=out_cache_loc, ...)
  3. 模型前向时写入 KV Cache

    python 复制代码
    kv_cache.set_kv_buffer(loc=out_cache_loc, cache_k=..., cache_v=...)

✅ 它是 调度器与模型执行器之间的"内存契约"


六、与分页(Paged KV Cache)的关系

  • page_size == 1(默认):使用 alloc_token_slots
  • page_size > 1:使用 alloc_paged_token_slots_extend / alloc_paged_token_slots_decode
    • 因为分页模式下需处理 非连续 pages续页 等复杂逻辑

七、示例

假设:

  • 请求 A 需要 extend 3 个 token
  • 请求 B 需要 extend 2 个 token
  • num_tokens = 5

调用:

python 复制代码
out_cache_loc = batch.alloc_token_slots(5)
# 假设返回: tensor([100, 101, 102, 103, 104])

后续:

  • 请求 A 的 token → slots [100, 101, 102]
  • 请求 B 的 token → slots [103, 104]

八、注意事项

  1. 仅用于 page_size == 1
  2. OOM 时会自动 evict,但可能影响缓存命中率
  3. backup_state=True 用于安全分配(失败时回滚)
  4. 返回的是 GPU tensor.to(device, non_blocking=True) 在上层完成)

总结

特性 说明
作用 为新 token 分配 KV Cache 位置
调用方 ScheduleBatch.prepare_for_extend/decode
依赖 TokenToKVPoolAllocator + RadixCache(evict)
关键机制 内存不足时 evict、失败回滚
输出 out_cache_loc(GPU tensor of slot indices)
适用模式 page_size == 1(非分页)

alloc_token_slots 是 SGLang 实现 高效、安全、弹性 KV Cache 管理 的核心环节。

mix_with_running 这个函数干啥的,举例说明

def mix_with_running(self, running_batch: "ScheduleBatch"):

self.forward_mode = ForwardMode.MIXED

running_bs = running_batch.batch_size()

复制代码
for req in running_batch.reqs:
    req.fill_ids = req.origin_input_ids + req.output_ids
    req.extend_input_len = 1

input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc

# For overlap scheduler, the output_ids has one step delay
delta = 0 if self.enable_overlap else -1

# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend(
    [
        len(r.origin_input_ids) + len(r.output_ids) + delta
        for r in running_batch.reqs
    ]
)
self.extend_lens.extend([1] * running_bs)
self.extend_num_tokens += running_bs
# TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs)

函数作用

mix_with_running函数的作用是将当前批次(通常是一个预填充批次)与一个正在运行的解码批次合并,形成一个混合批次。这种混合批次可以同时处理新请求的预填充和已运行请求的解码。

使用场景

在流式推理服务中,经常会有新的请求到来(需要预填充)和已有的请求在生成(解码步骤)。为了充分利用GPU,我们可以将预填充和解码一起在一个批次中处理,这就是混合批次。

详细步骤

设置前向模式为混合模式:self.forward_mode = ForwardMode.MIXED

更新运行中批次的每个请求:

将每个请求的fill_ids更新为原始输入ID加上输出ID(即完整的当前序列)

设置每个请求的extend_input_len为1(因为解码阶段每次只生成一个token)

合并输入ID和输出缓存位置:

将当前批次的input_ids和运行中批次的input_ids拼接

将当前批次的out_cache_loc和运行中批次的out_cache_loc拼接

合并两个批次:调用merge_batch方法将运行中批次的请求合并到当前批次

更新前缀长度、扩展长度等元数据:

对于运行中批次的每个请求,计算其前缀长度(即已经缓存了的token数量)。注意这里有一个delta,在重叠调度模式下为0,非重叠为-1。

扩展长度(extend_lens)每个运行中请求都是1(因为每个请求只生成一个token)

扩展token数量(extend_num_tokens)增加运行中批次的大小

扩展logprob起始长度(extend_logprob_start_lens)对于运行中批次的每个请求设置为0

举例说明

假设我们有两个批次:

当前批次(预填充批次):包含2个新请求

请求A: 输入长度为10,需要预填充10个token

请求B: 输入长度为15,需要预填充15个token

运行中批次(解码批次):包含3个正在生成的请求

请求C: 已经生成了5个token,原始输入长度为8,当前序列长度为8+5=13

请求D: 已经生成了10个token,原始输入长度为20,当前序列长度为30

请求E: 已经生成了3个token,原始输入长度为12,当前序列长度为15

合并过程

设置模式:当前批次变为混合模式。

更新运行中批次的请求:

对于请求C、D、E,设置fill_ids为完整的当前序列(原始输入+输出),并设置extend_input_len=1。

合并张量:

当前批次的input_ids(长度为10+15=25)与运行中批次的input_ids(长度为3,因为每个请求一个token,即上次生成的token)拼接,得到长度为28的input_ids。

类似地,out_cache_loc也会合并。

合并请求:将请求C、D、E加入到当前批次的请求列表中,现在当前批次有5个请求。

更新元数据:

前缀长度(prefix_lens):

对于预填充的请求A和B,已经在prepare_for_extend中设置好了(比如可能利用了前缀缓存,假设请求A的前缀长度为0,请求B的前缀长度为0,即全部需要计算)

现在加入运行中批次的前缀长度:对于每个运行中请求,前缀长度是len(origin_input_ids) + len(output_ids) + delta。注意,这里的前缀长度是指已经缓存的token数量。在非重叠模式下(delta=-1),因为最后一个token还没有缓存,所以前缀长度是当前序列长度减1;在重叠模式下,可能是当前序列长度(因为重叠调度可能已经缓存了最后一个token?)。代码中的注释说明"我们不缓存每个解码步骤",所以这里可能需要仔细设计。实际上,在解码阶段,每次生成一个token,我们会缓存这个新生成的token,但这里的前缀长度指的是在本次扩展之前已经缓存的数量。所以对于运行中请求,在本次扩展(生成下一个token)之前,已经缓存了len(origin_input_ids) + len(output_ids) - 1(非重叠)或len(origin_input_ids) + len(output_ids)(重叠)?这个根据调度策略有所不同。

扩展长度(extend_lens):预填充的请求A和B分别是10和15,运行中批次每个请求都是1,所以扩展长度列表变为[10, 15, 1, 1, 1]。

扩展token总数(extend_num_tokens)变为25(预填充)+3(解码)=28。

扩展logprob起始长度(extend_logprob_start_lens)对于运行中批次请求设置为0。

注意点

这个函数用于混合预填充和解码,以提高GPU利用率。

重叠调度和非重叠调度在计算前缀长度时有差异,通过delta调整。

合并后,模型前向传播会同时处理预填充token(多个)和解码token(每个请求1个)。

总结

mix_with_running实现了预填充和解码的混合批次处理,是提高推理效率的关键技术之一。它允许推理引擎同时处理新请求的初始化和老请求的继续生成,从而充分利用GPU计算资源。

我来详细解析mix_with_running这个函数,它实现了混合批次处理的重要功能。

1. 函数完整代码

python 复制代码
def mix_with_running(self, running_batch: "ScheduleBatch"):
    self.forward_mode = ForwardMode.MIXED
    running_bs = running_batch.batch_size()

    for req in running_batch.reqs:
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = 1

    input_ids = torch.cat([self.input_ids, running_batch.input_ids])
    out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

    self.merge_batch(running_batch)
    self.input_ids = input_ids
    self.out_cache_loc = out_cache_loc

    # For overlap scheduler, the output_ids has one step delay
    delta = 0 if self.enable_overlap else -1

    # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
    self.prefix_lens.extend(
        [
            len(r.origin_input_ids) + len(r.output_ids) + delta
            for r in running_batch.reqs
        ]
    )
    self.extend_lens.extend([1] * running_bs)
    self.extend_num_tokens += running_bs
    # TODO (lianmin): Revisit this. It should be seq_len - 1
    self.extend_logprob_start_lens.extend([0] * running_bs)

2. 函数核心作用

主要功能 :将预填充批次正在运行的解码批次 合并,形成一个混合批次,在一次前向传播中同时处理新请求的预填充和已有请求的解码。

3. 使用场景和动机

3.1 为什么要混合批次?

在LLM推理服务中,通常有两种类型的请求:

  • 新请求:需要预填充(Prefill)处理整个输入序列
  • 运行中请求:需要解码(Decode)生成下一个token

传统方式:分别处理,效率低

复制代码
时间线:
[预填充批次] → [空闲] → [解码批次] → [空闲] → [预填充批次] ...

混合方式:同时处理,提高GPU利用率

复制代码
时间线:
[混合批次:预填充 + 解码] → [混合批次:预填充 + 解码] → ...

3.2 性能优势

python 复制代码
# 分开处理的GPU利用率
预填充: 100% GPU利用率 → 解码: 100% GPU利用率 → 空闲: 0% GPU利用率

# 混合处理的GPU利用率  
混合: 100% GPU利用率 → 混合: 100% GPU利用率 → 混合: 100% GPU利用率

4. 详细执行步骤

4.1 设置前向模式

python 复制代码
self.forward_mode = ForwardMode.MIXED

作用:标记当前批次为混合模式,模型需要同时处理预填充和解码。

4.2 准备运行中批次的请求

python 复制代码
running_bs = running_batch.batch_size()

for req in running_batch.reqs:
    req.fill_ids = req.origin_input_ids + req.output_ids
    req.extend_input_len = 1

关键操作

  • 重建完整序列fill_ids = origin_input_ids + output_ids
  • 设置扩展长度extend_input_len = 1(解码每次只生成1个token)

示例

python 复制代码
# 运行中请求的状态
req.origin_input_ids = [1, 2, 3]     # "Hello world"
req.output_ids = [4, 5, 6]           # "How are you"
# 处理后:
req.fill_ids = [1, 2, 3, 4, 5, 6]    # 完整序列
req.extend_input_len = 1              # 每次解码只处理1个新token

4.3 合并输入数据

python 复制代码
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

数据合并

  • input_ids:预填充token + 解码token
  • out_cache_loc:预填充KV位置 + 解码KV位置

4.4 合并批次元数据

python 复制代码
self.merge_batch(running_batch)

调用之前分析过的merge_batch方法,合并所有请求和采样信息。

4.5 计算前缀长度(关键逻辑)

python 复制代码
# For overlap scheduler, the output_ids has one step delay
delta = 0 if self.enable_overlap else -1

# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend(
    [
        len(r.origin_input_ids) + len(r.output_ids) + delta
        for r in running_batch.reqs
    ]
)

前缀长度计算逻辑

  • 重叠调度delta=0):前缀包含所有已生成token
  • 非重叠调度delta=-1):前缀不包含最后一个token(用于生成)

具体示例

python 复制代码
# 运行中请求
req.origin_input_ids = [1, 2, 3]     # 3个token
req.output_ids = [4, 5, 6]           # 3个已生成token
当前序列长度 = 6

# 非重叠调度 (delta = -1)
prefix_len = 3 + 3 - 1 = 5  # 缓存前5个token,第6个用于生成

# 重叠调度 (delta = 0)  
prefix_len = 3 + 3 + 0 = 6  # 缓存所有6个token

4.6 设置扩展参数

python 复制代码
self.extend_lens.extend([1] * running_bs)
self.extend_num_tokens += running_bs
self.extend_logprob_start_lens.extend([0] * running_bs)

参数含义

  • extend_lens:运行中请求每个扩展1个token
  • extend_num_tokens:总扩展token数增加
  • extend_logprob_start_lens:运行中请求不计算输入logprob

5. 具体示例分析

5.1 场景设置

假设我们有:

  • 预填充批次:2个新请求
  • 运行中批次:3个解码请求
预填充批次状态
python 复制代码
# 请求A(新)
reqA.origin_input_ids = [10, 11, 12, 13]  # "What is AI"
reqA.output_ids = []                       # 无输出
reqA.fill_ids = [10, 11, 12, 13]
reqA.extend_input_len = 4

# 请求B(新)  
reqB.origin_input_ids = [20, 21, 22]      # "Hello world"
reqB.output_ids = []                       # 无输出
reqB.fill_ids = [20, 21, 22]
reqB.extend_input_len = 3

# 批次数据
self.input_ids = [10, 11, 12, 13, 20, 21, 22]  # 7个token
self.out_cache_loc = [100, 101, 102, 103, 104, 105, 106]  # 7个位置
运行中批次状态
python 复制代码
# 请求C(运行中)
reqC.origin_input_ids = [30, 31]          # "The weather"
reqC.output_ids = [32, 33]                # "is nice"
reqC.fill_ids = [30, 31, 32, 33]          # 需要更新
reqC.extend_input_len = 1                  # 需要设置

# 请求D(运行中)
reqD.origin_input_ids = [40, 41, 42]      # "Machine learning"
reqD.output_ids = [43]                    # "is"
reqD.fill_ids = [40, 41, 42, 43]          # 需要更新  
reqD.extend_input_len = 1                  # 需要设置

# 请求E(运行中)
reqE.origin_input_ids = [50]              # "AI"
reqE.output_ids = [51, 52, 53]            # "will change"
reqE.fill_ids = [50, 51, 52, 53]          # 需要更新
reqE.extend_input_len = 1                  # 需要设置

# 批次数据
running_batch.input_ids = [33, 43, 53]    # 每个请求的上次输出token
running_batch.out_cache_loc = [200, 201, 202]  # 为这次解码分配的位置

5.2 混合过程执行

步骤1:更新运行中请求
python 复制代码
for req in running_batch.reqs:
    req.fill_ids = req.origin_input_ids + req.output_ids
    req.extend_input_len = 1

更新后

复制代码
请求C: fill_ids = [30, 31, 32, 33], extend_input_len = 1
请求D: fill_ids = [40, 41, 42, 43], extend_input_len = 1  
请求E: fill_ids = [50, 51, 52, 53], extend_input_len = 1
步骤2:合并张量数据
python 复制代码
input_ids = torch.cat([[10,11,12,13,20,21,22], [33,43,53]]) 
# 结果: [10,11,12,13,20,21,22,33,43,53]  # 10个token

out_cache_loc = torch.cat([[100,101,102,103,104,105,106], [200,201,202]])
# 结果: [100,101,102,103,104,105,106,200,201,202]  # 10个位置
步骤3:合并批次
python 复制代码
self.merge_batch(running_batch)
# 现在 self.reqs = [reqA, reqB, reqC, reqD, reqE]  # 5个请求
步骤4:计算前缀长度(假设非重叠调度)
python 复制代码
delta = -1  # 非重叠调度

# 预填充批次的前缀长度(在prepare_for_extend中已设置)
# 假设: self.prefix_lens = [0, 0]  # 请求A和B没有前缀缓存

# 添加运行中请求的前缀长度
self.prefix_lens.extend([
    len(reqC.origin_input_ids) + len(reqC.output_ids) + (-1),  # 2+2-1=3
    len(reqD.origin_input_ids) + len(reqD.output_ids) + (-1),  # 3+1-1=3  
    len(reqE.origin_input_ids) + len(reqE.output_ids) + (-1)   # 1+3-1=3
])
# 结果: prefix_lens = [0, 0, 3, 3, 3]
步骤5:设置扩展参数
python 复制代码
self.extend_lens.extend([1, 1, 1])        # 原来[4,3] + [1,1,1] = [4,3,1,1,1]
self.extend_num_tokens += 3               # 原来7 + 3 = 10
self.extend_logprob_start_lens.extend([0, 0, 0])  # 运行中请求不计算输入logprob

5.3 最终混合批次状态

python 复制代码
# 混合批次最终状态
forward_mode = ForwardMode.MIXED
reqs = [reqA, reqB, reqC, reqD, reqE]     # 5个请求

# 输入数据
input_ids = [10,11,12,13,20,21,22,33,43,53]  # 10个token
out_cache_loc = [100,101,102,103,104,105,106,200,201,202]  # 10个位置

# 扩展参数
prefix_lens = [0, 0, 3, 3, 3]      # 每个请求的前缀长度
extend_lens = [4, 3, 1, 1, 1]      # 每个请求需要扩展的token数
extend_num_tokens = 10              # 总扩展token数

6. 模型前向传播时的处理

6.1 混合模式下的注意力计算

在混合模式下,模型需要同时处理:

  • 预填充请求:计算多个token的注意力(全序列)
  • 解码请求:计算单个token的注意力(增量)
python 复制代码
def mixed_attention_forward(input_ids, prefix_lens, extend_lens):
    # 输入: [10,11,12,13,20,21,22,33,43,53]
    # prefix_lens: [0,0,3,3,3]
    # extend_lens: [4,3,1,1,1]
    
    # 请求A: 处理token [10,11,12,13] (4个token,无前缀)
    # 请求B: 处理token [20,21,22] (3个token,无前缀)  
    # 请求C: 处理token [33] (1个token,前缀3个)
    # 请求D: 处理token [43] (1个token,前缀3个)
    # 请求E: 处理token [53] (1个token,前缀3个)
    
    # 实际实现使用 RaggedTensor 或类似结构处理不规则序列

6.2 性能优势体现

python 复制代码
# 分开处理的GPU使用
预填充: 处理7个token → GPU利用率: 70%
解码: 处理3个token → GPU利用率: 30%

# 混合处理的GPU使用  
混合: 处理10个token → GPU利用率: 100%

7. 调度策略的影响

7.1 重叠调度 vs 非重叠调度

python 复制代码
delta = 0 if self.enable_overlap else -1

重叠调度delta=0):

  • 前缀包含最后一个token
  • 解码步骤可以提前准备
  • 更高的吞吐量,但可能增加延迟

非重叠调度delta=-1):

  • 前缀不包含最后一个token
  • 更简单的实现
  • 较低的吞吐量,但延迟更可预测

7.2 实际调度示例

python 复制代码
# 时间线示例
时间点 t0: [预填充批次] 开始
时间点 t1: [运行中批次] 准备好
时间点 t2: mix_with_running() 创建混合批次
时间点 t3: [混合批次] 执行前向传播

# 重叠调度可以在t1就开始准备运行中批次
# 非重叠调度需要等待t0完成

8. 总结

mix_with_running函数的核心价值:

  1. 提高GPU利用率:通过合并预填充和解码请求,减少GPU空闲时间
  2. 降低延迟:新请求不用等待当前解码完成
  3. 提高吞吐量:一次性处理更多样化的计算负载
  4. 灵活调度:支持重叠和非重叠两种调度策略

关键理解:混合批次处理是现代LLM推理服务中的关键技术,它通过将计算密集的预填充和内存带宽受限的解码组合在一起,实现了更好的硬件资源利用,显著提升了整体推理性能。

相关推荐
CRMEB系统商城3 小时前
【新版发布】标准版PHP v5.6.4正式版,优化部分用户体验
java·大数据·小程序·php·ux
青云交3 小时前
Java 大视界 -- Java 大数据在智慧养老服务需求分析与个性化服务匹配中的应用
java·需求分析·智慧养老·健康管理·java 大数据·个性化服务·生活照料
丈剑走天涯3 小时前
kubernetes 源码编译(ubuntu) kubernetes-1.34.1
java·容器·kubernetes·1024程序员节
今天没ID3 小时前
Java 变量类型转换🙌🙌
java
AL流云。3 小时前
学习Docker前提:多环境安装Docker
学习·docker·eureka·1024程序员节
m0_674031434 小时前
GitHub等平台形成的开源文化正在重也有人
java·windows·mysql
懒惰蜗牛4 小时前
Day44 | J.U.C中的LockSupport详解
java·开发语言·后端·java-ee
2301_803554524 小时前
Http学习
网络协议·学习·http
5pace4 小时前
Mac Nginx安装、启动、简单命令(苍穹外卖、黑马点评前端环境搭建)
java·前端·nginx·macos·tomcat