深入解析:Chunked Prefill 与 FlashAttention/FlashInfer 如何协同工作
在现代大语言模型(LLM)推理服务中,为了同时满足低延迟和高吞吐的需求,系统必须高效处理两种截然不同的计算负载:用于处理输入提示(Prompt)的 预填充(Prefill) 阶段和自回归生成新词元(Token)的 解码(Decode) 阶段。预填充是计算密集型(Compute-Bound)的,而解码则是内存带宽密集型(Memory-Bound)的。 "分块预填充"(Chunked Prefill)是一种先进的调度策略,它通过将大型预填充任务分解,并与解码任务混合在同一个计算批次中,从而显著提升GPU利用率和系统整体吞吐量。
Chunked Prefill:化整为零,提升效率
传统的连续批处理(Continuous Batching)在处理新请求时,会执行一个完整的预填充操作,这个操作的耗时由最长的输入序列决定。在预填充期间,已经在进行解码的旧请求要么暂停,要么只能生成一个词元,这导致了解码请求的有效吞吐量下降。
Chunked Prefill 的核心思想是将一个长序列的预填充任务分解成多个较小的"块"(Chunks)。 例如,一个包含 4096 个词元的长提示,可以被分解为 8 个 512 词元的块。推理引擎每次只处理一个预填充块,而不是整个序列。
这种机制带来了两大优势:
- 减少流水线气泡,提升GPU利用率:预填充块的计算量相对较小且固定,更容易与许多小的解码任务打包在一起,形成一个计算量饱和的批次。这使得计算密集型的预填充块可以充分利用GPU的计算单元,而内存密集型的解码任务则"搭便车",利用剩余的计算周期完成,从而掩盖了各自的瓶颈。
- 改善解码请求的延迟:对于正在进行解码的用户来说,他们感受到的不再是长时间的完全停滞,而是在每个混合批次中均匀地、轻微地"减速"。这大大降低了输出每个词元的平均时间,改善了流式生成的体验。
混合批次的"运输"与执行
当一个批次中既有预填充块,又有解码任务时(即 ForwardMode.MIXED
状态),其数据流和计算过程相比单一模式更为复杂。
-
批次构建:调度器会选择一个或多个预填充块,并尽可能多地填充解码请求,以达到最优的批次大小。
-
数据准备:所有请求的数据会被整理并加载到GPU。这包括:
- 对于预填充块:需要加载当前块的词元ID。
- 对于解码任务:需要加载单个新词元的ID。
- 对于所有请求:都需要访问各自已经计算并存储在GPU显存中的键值缓存(KV Cache)。
-
前向计算:模型执行一次前向传播。在注意力层,这是最关键的部分,也是 FlashAttention 和 FlashInfer 发挥作用的地方。
FlashAttention 与 FlashInfer:混合批次注意力的实现核心
FlashAttention 及其后续版本(如 FlashAttention-2/3)和 FlashInfer 都是为解决标准注意力机制中存在的I/O瓶颈而设计的融合算子(Fused Kernel)。 它们通过分块计算、避免将巨大的注意力矩阵物化到高带宽内存(HBM)中等技术,极大地提升了注意力计算的速度和效率。 在处理 Chunked Prefill 产生的混合批次时,这些库的能力至关重要。
处理非连续、变长序列
混合批次中的序列具有两个显著特点:
- 长度不一(Variable Length):预填充块的长度(如512)远大于解码任务的长度(始终为1)。
- KV Cache 历史不一:每个请求已经缓存的KV序列长度都不同。
为了高效处理这种"锯齿状"的批次数据,FlashAttention 和 FlashInfer 采用了专门针对变长序列的算子。
FlashAttention 中的 flash_attn_varlen_qkvpacked_func
FlashAttention 提供了 flash_attn_varlen_qkvpacked_func
这类函数来处理打包好的(packed)变长序列。 其工作方式如下:
- 数据打包:所有序列的 Query (Q), Key (K), Value (V) 向量被拼接(concatenate)成一个长的一维张量。
- 位置索引 :一个名为
cu_seqlens
(cumulative sequence lengths) 的辅助张量被传入内核。这个张量记录了每个序列在拼接后的一维张量中的起始和结束位置。 - 融合计算 :FlashAttention 内核利用
cu_seqlens
来正确地为每个序列计算注意力。对于预填充块中的每个词元,它需要关注(attend to)该块内所有在它之前的词元以及该请求之前所有块已经生成的 KV Cache。对于解码词元,它需要关注其全部历史 KV Cache。内核通过精心设计的掩码(masking)机制来确保这种因果关系。
FlashInfer 的角色与 PageAttention
FlashInfer 是一个专注于LLM推理场景的算子库,它提供了针对预填充、解码和追加(Append,用于投机解码等场景)三种阶段优化的内核。 它与 vLLM 等框架中广泛使用的 PagedAttention 内存管理机制紧密结合。
PagedAttention 将 KV Cache 分成固定大小的"页面"(Page),并通过一个"块表"(Block Table)来管理每个序列对应的物理内存页面。 这种方式极大地减少了内存碎片,提高了内存利用率。
在处理混合批次时,FlashInfer 的流程如下:
- 更新 KV Cache :对于批次中的预填充块,新计算出的 K 和 V 向量需要被追加到对应序列的 Paged KV Cache 中。FlashInfer 提供了如
append_paged_kv_cache
这样的高效内核来完成此操作。 这个内核接收打包好的 K/V 数据和描述其在批次中位置的索引,然后将其写入由块表指定的非连续物理内存页面中。 - 执行注意力计算 :FlashInfer 的预填充和解码算子被调用来执行注意力计算。
- Prefill 算子 (
single_prefill_with_kv_cache
): 用于处理批次中的预填充块。该算子被设计为计算密集型,以最大化利用Tensor Core。 它会读取该请求之前所有块的 KV Cache,并与当前块的 Q、K、V 进行计算。 - Decode 算子: 用于处理批次中的解码请求。该算子被优化为内存带宽密集型,专注于快速地从 Paged KV Cache 中加载大量的历史键值对。
- Prefill 算子 (
通过将预填充和解码任务打包,并调用 FlashInfer 或 FlashAttention 的变长序列处理内核,整个混合批次可以在一次GPU前向计算中完成。预填充块的高计算强度和解码任务的高内存带宽需求得以互补,实现了GPU资源的高效利用。