深入解析:Chunked Prefill 与 FlashAttention/FlashInfer 如何协同工作

深入解析:Chunked Prefill 与 FlashAttention/FlashInfer 如何协同工作

在现代大语言模型(LLM)推理服务中,为了同时满足低延迟和高吞吐的需求,系统必须高效处理两种截然不同的计算负载:用于处理输入提示(Prompt)的 预填充(Prefill) 阶段和自回归生成新词元(Token)的 解码(Decode) 阶段。预填充是计算密集型(Compute-Bound)的,而解码则是内存带宽密集型(Memory-Bound)的。 "分块预填充"(Chunked Prefill)是一种先进的调度策略,它通过将大型预填充任务分解,并与解码任务混合在同一个计算批次中,从而显著提升GPU利用率和系统整体吞吐量。

Chunked Prefill:化整为零,提升效率

传统的连续批处理(Continuous Batching)在处理新请求时,会执行一个完整的预填充操作,这个操作的耗时由最长的输入序列决定。在预填充期间,已经在进行解码的旧请求要么暂停,要么只能生成一个词元,这导致了解码请求的有效吞吐量下降。

Chunked Prefill 的核心思想是将一个长序列的预填充任务分解成多个较小的"块"(Chunks)。 例如,一个包含 4096 个词元的长提示,可以被分解为 8 个 512 词元的块。推理引擎每次只处理一个预填充块,而不是整个序列。

这种机制带来了两大优势:

  1. 减少流水线气泡,提升GPU利用率:预填充块的计算量相对较小且固定,更容易与许多小的解码任务打包在一起,形成一个计算量饱和的批次。这使得计算密集型的预填充块可以充分利用GPU的计算单元,而内存密集型的解码任务则"搭便车",利用剩余的计算周期完成,从而掩盖了各自的瓶颈。
  2. 改善解码请求的延迟:对于正在进行解码的用户来说,他们感受到的不再是长时间的完全停滞,而是在每个混合批次中均匀地、轻微地"减速"。这大大降低了输出每个词元的平均时间,改善了流式生成的体验。

混合批次的"运输"与执行

当一个批次中既有预填充块,又有解码任务时(即 ForwardMode.MIXED 状态),其数据流和计算过程相比单一模式更为复杂。

  1. 批次构建:调度器会选择一个或多个预填充块,并尽可能多地填充解码请求,以达到最优的批次大小。

  2. 数据准备:所有请求的数据会被整理并加载到GPU。这包括:

    • 对于预填充块:需要加载当前块的词元ID。
    • 对于解码任务:需要加载单个新词元的ID。
    • 对于所有请求:都需要访问各自已经计算并存储在GPU显存中的键值缓存(KV Cache)。
  3. 前向计算:模型执行一次前向传播。在注意力层,这是最关键的部分,也是 FlashAttention 和 FlashInfer 发挥作用的地方。

FlashAttention 与 FlashInfer:混合批次注意力的实现核心

FlashAttention 及其后续版本(如 FlashAttention-2/3)和 FlashInfer 都是为解决标准注意力机制中存在的I/O瓶颈而设计的融合算子(Fused Kernel)。 它们通过分块计算、避免将巨大的注意力矩阵物化到高带宽内存(HBM)中等技术,极大地提升了注意力计算的速度和效率。 在处理 Chunked Prefill 产生的混合批次时,这些库的能力至关重要。

处理非连续、变长序列

混合批次中的序列具有两个显著特点:

  • 长度不一(Variable Length):预填充块的长度(如512)远大于解码任务的长度(始终为1)。
  • KV Cache 历史不一:每个请求已经缓存的KV序列长度都不同。

为了高效处理这种"锯齿状"的批次数据,FlashAttention 和 FlashInfer 采用了专门针对变长序列的算子。

FlashAttention 中的 flash_attn_varlen_qkvpacked_func

FlashAttention 提供了 flash_attn_varlen_qkvpacked_func 这类函数来处理打包好的(packed)变长序列。 其工作方式如下:

  1. 数据打包:所有序列的 Query (Q), Key (K), Value (V) 向量被拼接(concatenate)成一个长的一维张量。
  2. 位置索引 :一个名为 cu_seqlens (cumulative sequence lengths) 的辅助张量被传入内核。这个张量记录了每个序列在拼接后的一维张量中的起始和结束位置。
  3. 融合计算 :FlashAttention 内核利用 cu_seqlens 来正确地为每个序列计算注意力。对于预填充块中的每个词元,它需要关注(attend to)该块内所有在它之前的词元以及该请求之前所有块已经生成的 KV Cache。对于解码词元,它需要关注其全部历史 KV Cache。内核通过精心设计的掩码(masking)机制来确保这种因果关系。

FlashInfer 的角色与 PageAttention

FlashInfer 是一个专注于LLM推理场景的算子库,它提供了针对预填充、解码和追加(Append,用于投机解码等场景)三种阶段优化的内核。 它与 vLLM 等框架中广泛使用的 PagedAttention 内存管理机制紧密结合。

PagedAttention 将 KV Cache 分成固定大小的"页面"(Page),并通过一个"块表"(Block Table)来管理每个序列对应的物理内存页面。 这种方式极大地减少了内存碎片,提高了内存利用率。

在处理混合批次时,FlashInfer 的流程如下:

  1. 更新 KV Cache :对于批次中的预填充块,新计算出的 K 和 V 向量需要被追加到对应序列的 Paged KV Cache 中。FlashInfer 提供了如 append_paged_kv_cache 这样的高效内核来完成此操作。 这个内核接收打包好的 K/V 数据和描述其在批次中位置的索引,然后将其写入由块表指定的非连续物理内存页面中。
  2. 执行注意力计算 :FlashInfer 的预填充和解码算子被调用来执行注意力计算。
    • Prefill 算子 (single_prefill_with_kv_cache): 用于处理批次中的预填充块。该算子被设计为计算密集型,以最大化利用Tensor Core。 它会读取该请求之前所有块的 KV Cache,并与当前块的 Q、K、V 进行计算。
    • Decode 算子: 用于处理批次中的解码请求。该算子被优化为内存带宽密集型,专注于快速地从 Paged KV Cache 中加载大量的历史键值对。

通过将预填充和解码任务打包,并调用 FlashInfer 或 FlashAttention 的变长序列处理内核,整个混合批次可以在一次GPU前向计算中完成。预填充块的高计算强度和解码任务的高内存带宽需求得以互补,实现了GPU资源的高效利用。

相关推荐
富唯智能4 分钟前
移动+协作+视觉:开箱即用的下一代复合机器人如何重塑智能工厂
人工智能·工业机器人·复合机器人
Antonio9151 小时前
【图像处理】图像的基础几何变换
图像处理·人工智能·计算机视觉
新加坡内哥谈技术2 小时前
Perplexity AI 的 RAG 架构全解析:幕后技术详解
人工智能
武子康2 小时前
AI研究-119 DeepSeek-OCR PyTorch FlashAttn 2.7.3 推理与部署 模型规模与资源详细分析
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr
Sirius Wu3 小时前
深入浅出:Tongyi DeepResearch技术解读
人工智能·语言模型·langchain·aigc
忙碌5444 小时前
AI大模型时代下的全栈技术架构:从深度学习到云原生部署实战
人工智能·深度学习·架构
LZ_Keep_Running4 小时前
智能变电巡检:AI检测新突破
人工智能
InfiSight智睿视界4 小时前
AI 技术助力汽车美容行业实现精细化运营管理
大数据·人工智能
没有钱的钱仔5 小时前
机器学习笔记
人工智能·笔记·机器学习
听风吹等浪起5 小时前
基于改进TransUNet的港口船只图像分割系统研究
人工智能·深度学习·cnn·transformer