深入解析:Chunked Prefill 与 FlashAttention/FlashInfer 如何协同工作

深入解析:Chunked Prefill 与 FlashAttention/FlashInfer 如何协同工作

在现代大语言模型(LLM)推理服务中,为了同时满足低延迟和高吞吐的需求,系统必须高效处理两种截然不同的计算负载:用于处理输入提示(Prompt)的 预填充(Prefill) 阶段和自回归生成新词元(Token)的 解码(Decode) 阶段。预填充是计算密集型(Compute-Bound)的,而解码则是内存带宽密集型(Memory-Bound)的。 "分块预填充"(Chunked Prefill)是一种先进的调度策略,它通过将大型预填充任务分解,并与解码任务混合在同一个计算批次中,从而显著提升GPU利用率和系统整体吞吐量。

Chunked Prefill:化整为零,提升效率

传统的连续批处理(Continuous Batching)在处理新请求时,会执行一个完整的预填充操作,这个操作的耗时由最长的输入序列决定。在预填充期间,已经在进行解码的旧请求要么暂停,要么只能生成一个词元,这导致了解码请求的有效吞吐量下降。

Chunked Prefill 的核心思想是将一个长序列的预填充任务分解成多个较小的"块"(Chunks)。 例如,一个包含 4096 个词元的长提示,可以被分解为 8 个 512 词元的块。推理引擎每次只处理一个预填充块,而不是整个序列。

这种机制带来了两大优势:

  1. 减少流水线气泡,提升GPU利用率:预填充块的计算量相对较小且固定,更容易与许多小的解码任务打包在一起,形成一个计算量饱和的批次。这使得计算密集型的预填充块可以充分利用GPU的计算单元,而内存密集型的解码任务则"搭便车",利用剩余的计算周期完成,从而掩盖了各自的瓶颈。
  2. 改善解码请求的延迟:对于正在进行解码的用户来说,他们感受到的不再是长时间的完全停滞,而是在每个混合批次中均匀地、轻微地"减速"。这大大降低了输出每个词元的平均时间,改善了流式生成的体验。

混合批次的"运输"与执行

当一个批次中既有预填充块,又有解码任务时(即 ForwardMode.MIXED 状态),其数据流和计算过程相比单一模式更为复杂。

  1. 批次构建:调度器会选择一个或多个预填充块,并尽可能多地填充解码请求,以达到最优的批次大小。

  2. 数据准备:所有请求的数据会被整理并加载到GPU。这包括:

    • 对于预填充块:需要加载当前块的词元ID。
    • 对于解码任务:需要加载单个新词元的ID。
    • 对于所有请求:都需要访问各自已经计算并存储在GPU显存中的键值缓存(KV Cache)。
  3. 前向计算:模型执行一次前向传播。在注意力层,这是最关键的部分,也是 FlashAttention 和 FlashInfer 发挥作用的地方。

FlashAttention 与 FlashInfer:混合批次注意力的实现核心

FlashAttention 及其后续版本(如 FlashAttention-2/3)和 FlashInfer 都是为解决标准注意力机制中存在的I/O瓶颈而设计的融合算子(Fused Kernel)。 它们通过分块计算、避免将巨大的注意力矩阵物化到高带宽内存(HBM)中等技术,极大地提升了注意力计算的速度和效率。 在处理 Chunked Prefill 产生的混合批次时,这些库的能力至关重要。

处理非连续、变长序列

混合批次中的序列具有两个显著特点:

  • 长度不一(Variable Length):预填充块的长度(如512)远大于解码任务的长度(始终为1)。
  • KV Cache 历史不一:每个请求已经缓存的KV序列长度都不同。

为了高效处理这种"锯齿状"的批次数据,FlashAttention 和 FlashInfer 采用了专门针对变长序列的算子。

FlashAttention 中的 flash_attn_varlen_qkvpacked_func

FlashAttention 提供了 flash_attn_varlen_qkvpacked_func 这类函数来处理打包好的(packed)变长序列。 其工作方式如下:

  1. 数据打包:所有序列的 Query (Q), Key (K), Value (V) 向量被拼接(concatenate)成一个长的一维张量。
  2. 位置索引 :一个名为 cu_seqlens (cumulative sequence lengths) 的辅助张量被传入内核。这个张量记录了每个序列在拼接后的一维张量中的起始和结束位置。
  3. 融合计算 :FlashAttention 内核利用 cu_seqlens 来正确地为每个序列计算注意力。对于预填充块中的每个词元,它需要关注(attend to)该块内所有在它之前的词元以及该请求之前所有块已经生成的 KV Cache。对于解码词元,它需要关注其全部历史 KV Cache。内核通过精心设计的掩码(masking)机制来确保这种因果关系。

FlashInfer 的角色与 PageAttention

FlashInfer 是一个专注于LLM推理场景的算子库,它提供了针对预填充、解码和追加(Append,用于投机解码等场景)三种阶段优化的内核。 它与 vLLM 等框架中广泛使用的 PagedAttention 内存管理机制紧密结合。

PagedAttention 将 KV Cache 分成固定大小的"页面"(Page),并通过一个"块表"(Block Table)来管理每个序列对应的物理内存页面。 这种方式极大地减少了内存碎片,提高了内存利用率。

在处理混合批次时,FlashInfer 的流程如下:

  1. 更新 KV Cache :对于批次中的预填充块,新计算出的 K 和 V 向量需要被追加到对应序列的 Paged KV Cache 中。FlashInfer 提供了如 append_paged_kv_cache 这样的高效内核来完成此操作。 这个内核接收打包好的 K/V 数据和描述其在批次中位置的索引,然后将其写入由块表指定的非连续物理内存页面中。
  2. 执行注意力计算 :FlashInfer 的预填充和解码算子被调用来执行注意力计算。
    • Prefill 算子 (single_prefill_with_kv_cache): 用于处理批次中的预填充块。该算子被设计为计算密集型,以最大化利用Tensor Core。 它会读取该请求之前所有块的 KV Cache,并与当前块的 Q、K、V 进行计算。
    • Decode 算子: 用于处理批次中的解码请求。该算子被优化为内存带宽密集型,专注于快速地从 Paged KV Cache 中加载大量的历史键值对。

通过将预填充和解码任务打包,并调用 FlashInfer 或 FlashAttention 的变长序列处理内核,整个混合批次可以在一次GPU前向计算中完成。预填充块的高计算强度和解码任务的高内存带宽需求得以互补,实现了GPU资源的高效利用。

相关推荐
陈广亮4 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬4 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia5 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区5 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两8 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
前端付豪8 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
strayCat232558 小时前
Clawdbot 源码解读 7: 扩展机制
人工智能·开源
王鑫星8 小时前
SWE-bench 首次突破 80%:Claude Opus 4.5 发布,Anthropic 的野心不止于写代码
人工智能
lnix8 小时前
当“大龙虾”养在本地:我们离“反SaaS”的AI未来还有多远?
人工智能·aigc