【LLM系列】PagedAttention 完整详解+最小可运行 Demo

一、前置背景:传统 KV 缓存的致命缺陷

大模型推理分为两个阶段:

  • Prefill(预填充):一次性输入整段 Prompt,并行计算所有输入 Token 的 K、V 向量,存入显存 KV 缓存;
  • Decode(逐 Token 生成):每轮输出 1 个新 Token,计算该 Token 的 K/V 并追加到缓存,循环直到结束。

传统连续 KV 缓存方案问题

  • 框架会给每个请求预先分配一块固定长度、连续的显存张量,长度等于模型最大上下文窗口(如 2048/8192)。
  • 举个直观痛点:
    • 最大上下文 8192 Token,用户 Prompt 仅 300 Token、生成 100 Token,实际只用 400 位置,但必须占用完整 8192 长度连续显存;
    • 多请求长短不一,释放后内存碎成大量无法合并的小段,显存利用率常低于 15%;
    • Beam Search 多生成分支时,需要完整复制多份 KV 张量,显存占用成倍暴涨。

PagedAttention 核心目标:用操作系统虚拟内存分页思路管理 KV 缓存,消除碎片、按需分配、支持 KV 块共享,显存利用率提升至 85%~95%。

重要前提:PagedAttention不改变注意力数学公式(QK^T V),只是一套 KV 缓存的显存管理机制,常与FlashAttention 算子搭配使用。

二、PagedAttention 四大核心基础概念

1. 物理块 Block(物理 Page)

  • 全局显存划分为大量固定大小的独立显存块,是最小分配单元。
  • 超参 block_size:单块可存储的 Token 数量,工业界常用 16、32;
  • 每个 Block 存储单层单批的 K/V 数据,维度:num_heads, block_size, head_dim
  • 所有 Block 统一存放在全局空闲块池,统一分配、回收、复用。

2. 逻辑 Token & 逻辑块

  • 对任意一条对话序列,生成的 Token 按顺序拥有连续逻辑编号(0,1,2,3...);
  • 按 block_size 切割逻辑 Token,得到连续逻辑块:
    逻辑块编号 = token_pos // block_size,块内偏移 = token_pos % block_size。
  • 逻辑上 Token 连续,但底层物理内存完全可以不连续。

3. Block Table(页表,每条请求独有)

  • 每个推理序列维护一张数组 block_table\[\],映射关系:
    block_table逻辑块ID = 对应物理Block的显存地址/指针
  • 作用:屏蔽物理内存离散性,让上层算子看起来逻辑空间是完整连续的。
  • 多个生成分支(Beam)可以共享只读的历史 KV 物理块;
  • 只有当某个分支需要修改 / 追加新 Token、覆盖块内数据时,才复制一份独立物理块供该分支私有,避免完整复制全部 KV 缓存。

三、单条对话序列完整工作流程

固定配置

  • Block_size = 16 Token(一块存 16 个 Token 的 K/V)
  • 全局空闲物理块池:B0、B1、B2、B3、B4......
  • 用户输入 Prompt:22 个 Token(逻辑编号 0~21),后续生成 3 个新 Token(22、23、32)

阶段 1:Prefill 预填充,处理输入 22 个 Token

  1. 拆分逻辑块
    逻辑 Token 0~21,按 16 切割:
  • 逻辑块 0:Token 0~15(填满 16 个)
  • 逻辑块 1:Token 16~21(仅 6 个,块内剩余 10 个空位)
  1. 向全局块池申请 2 个空闲物理块,分配 B0、B1
  2. 创建当前序列专属 Block Table:
bash 复制代码
block_table[0] = B0
block_table[1] = B1
  1. 逐层执行 Prefill 计算,将 0-15 的 K/V 写入物理块 B0;16-21 的 K/V 写入物理块 B1;
  2. 状态总结:仅占用 2 个物理块;传统方案需预分配 8192 长度大张量,显存占用差距几十倍。

此时内存布局示意:

阶段 2:Decode 第一轮,生成第 22 号 Token

  1. 新 Token 逻辑位置 pos=22
  2. 计算归属:逻辑块 ID = 22 // 16 = 1;块内偏移 = 22 % 16 = 6
  3. 查询 Block Table,逻辑块 1 绑定物理块 B1,无需申请新块;
  4. 直接向 B1 的偏移 6 位置写入新 Token 的 K/V;
    序列总 Token 更新为 0~22。

阶段 3:Decode 第二轮,生成第 23 号 Token

pos=23 → 逻辑块 1、偏移 7,继续复用 B1 写入数据,总 Token 0~23。

阶段 4:持续生成直到填满逻辑块 1

不断生成 Token 到 pos=31:逻辑块 1 全部 16 个位置(16~31)写满。

阶段 5:下一个新 Token pos=32

  1. pos=32 // 16 = 2 → 逻辑块 2;
  2. 查询 Block Table,block_table2 为空,无对应物理块;
  3. 向全局块池申请新物理块 B2,更新页表:block_table2 = B2;
  4. 将 pos=32 的 K/V 写入 B2 的偏移 0 位置;

后续 32~47 所有新 Token 都会复用 B2,直到填满再申请 B3。

阶段 6:对话结束,释放显存

  1. 推理终止,遍历当前序列 Block Table 中所有物理块 B0、B1、B2,全部归还至全局空闲块池;
  2. 其他新请求可直接复用这些回收的 Block,无内存碎片。

场景:beam_width=2,同一段 22Token Prompt 生成两条候选分支 Beam0、Beam1

  1. Prefill 完成后,基础序列页表:B0, B1,存储 0~21 历史 KV;
  2. 初始化两个 Beam 分支时,不复制任何物理块,两个分支共用同一份 Block Table,共享只读 B0、B1;
  3. 第一轮 Decode,两个分支各自生成独立新 Token(pos=22):
  • 逻辑块 1 绑定 B1,而 B1 是多分支共享的只读块,不能直接覆盖;
  • 触发 CoW 机制:
    ① 分配新空闲块 B1_copy;
    ② 将原 B1 内全部 KV 数据完整拷贝至 B1_copy;
    ③ Beam0 更新自身页表:block_table1 = B1_copy,写入自己的 pos22 数据;
    ④ Beam1 仍保留原共享块 B1,写入自身 pos22 数据;
  1. 收益:长达上千 Token 的历史上下文块(B0)全程共享,仅尾部未填满的小块才需要复制;传统 Beam 方案需要完整复制两份 8192 长度 KV 张量,显存开销大幅降低 70% 以上。

五、底层关键配套优化:离散 KV 原生注意力算子

离散 KV 原生注意力算子是 vLLM 等高性能推理框架实现 PagedAttention 的定制 CUDA 内核,核心能力是直接对非连续物理内存中的 KV 块执行注意力计算,无需先拼接为连续张量。它不是新的注意力数学公式,而是对传统注意力算子的软件架构重构,解决了长期困扰大模型推理的 "KV 必须连续存储" 的软件设计枷锁。

为什么以前绝对不支持离散 KV?(核心根源)

  1. 硬件层面:GPU 本身没有 "必须连续" 的限制(先破除误区)
    GPU SM 核心可以访问显存任意地址,硬件指令集支持随机内存访问。限制完全来自上层软件,而非硬件能力不足。
  2. 软件层面:三大核心枷锁锁死了离散计算路径
  3. 历史背景:早期算子设计的场景局限
  • FlashAttention v1/v2 研发目标:解决连续 KV下的片上分块计算,减少 HBM 显存占用,完全没考虑分页、离散存储、多请求混合调度场景
  • 业务场景限制:早期大模型推理以单用户长文本为主,预分配连续显存即可满足需求,没有多用户并发导致的碎片化问题
  • 开发优先级:开发者优先优化计算效率,而非内存管理灵活性,"连续内存 + 高效计算" 是当时最优解

完整工作流程

以序列长度 10、block_size=4、num_heads=2、head_dim=16 为例:

Step 1:输入准备

  • Query:当前解码的 1 个 token(1, 2, 1, 16
  • 块表:0, 3, 5(逻辑块 0→物理块 0,逻辑块 1→物理块 3,逻辑块 2→物理块 5)
  • 物理 KV 池:全局共享的 8 个物理块,每个存 4 个 token 的 KV

Step 2:内核启动与块表遍历

  • CUDA 内核启动,每个线程处理一个注意力头
  • 内核按逻辑顺序遍历块表,不做任何拼接,直接获取每个物理块的显存地址
  • 块表常驻 GPU 共享内存,查表延迟可忽略不计

Step 3:块级 QKV 计算(并行执行,无中间张量)

对每个物理块执行:

  • 分散读取:直接从离散物理块读取 KV(无拷贝)
  • 块内计算:复用 FlashAttention 分块计算逻辑(softmax / 加权求和)
  • 聚合结果:累加所有块输出,无拼接操作

Step 4:最终输出

所有块计算完成后,直接得到最终注意力结果,用于后续层计算。

与传统方案的根本差异

六、常见误区澄清

  • PagedAttention 不是新式注意力算法,只是 KV 缓存内存管理方案,QKV 计算逻辑和标准 Self-Attention 完全一致;
  • 和 FlashAttention 是互补关系:Flash 优化单块张量注意力计算速度,PagedAttention 优化 KV 显存存储;
  • Block Size 并非越大越好:
    • block_size=32:页表更短、查表开销小,但单块空闲空位变多,内存利用率轻微下降;
    • block_size=16:工业界平衡内存利用率与查表延迟的通用选择。

七、主流落地框架

vLLM 原创并原生实现 PagedAttention;

SGLang、TensorRT-LLM、DeepSpeed-MII、LightLLM 均借鉴分页 KV 缓存架构。

八、PyTorch 2.x FlexAttention 模拟 PagedAttention 最小可运行 Demo

本示例是教学级纯 PyTorch 模拟实现:用 FlexAttention + 张量索引模拟分页 KV 块表逻辑,不等于 vLLM 底层手写 CUDA 内核版本;工业高性能 PagedAttention 仍依赖定制 CUDA Kernel,Flex 方案胜在零 CUDA 代码、可编译跨算子。

python 复制代码
import torch
import torch.nn as nn
from torch.nn.attention import flex_attention, FlexAttentionMask
from torch.utils._triton import has_triton

# ===================== 1. 环境校验 =====================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available(), "必须使用CUDA GPU运行FlexAttention"
assert torch.__version__ >= "2.4.0", "PyTorch版本需≥2.4"
print(f"Torch版本: {torch.__version__}, Device: {device}")

# ===================== 2. 超参配置(分页核心参数) =====================
BLOCK_SIZE = 4        # 每个物理块存储4个token(PagedAttention最小分配单元)
NUM_HEADS = 2         # 注意力头数
HEAD_DIM = 16         # 单头维度
NUM_PHYSICAL_BLOCKS = 8  # 全局总物理块池数量
SEQ_LOGICAL_LEN = 10  # 当前序列总逻辑token长度(0~9)

# ===================== 3. 全局物理KV块池(模拟显存离散块) =====================
# 物理池形状: [总物理块数, num_heads, block_size, head_dim]
# 所有请求共享这块显存池,按需分配/回收物理块
physical_k_pool = torch.randn(
    NUM_PHYSICAL_BLOCKS, NUM_HEADS, BLOCK_SIZE, HEAD_DIM,
    device=device, dtype=torch.float16
)
physical_v_pool = torch.randn_like(physical_k_pool)

# ===================== 4. 构造PageTable(单条推理序列的页表) =====================
# 逻辑token 0~9,按BLOCK_SIZE=4切分逻辑块:
# 逻辑块0: token0-3; 逻辑块1: token4-7; 逻辑块2: token8-9
# page_table[逻辑块ID] = 对应物理块ID
page_table = torch.tensor(
    [0, 3, 5],  # 逻辑块0→物理块0,逻辑块1→物理块3,逻辑块2→物理块5
    device=device, dtype=torch.long
)
num_logical_blocks = page_table.shape[0]

# ===================== 5. 从离散物理块拼接逻辑连续K/V =====================
def gather_paged_kv(page_table: torch.Tensor, physical_pool: torch.Tensor, block_size: int, seq_len: int):
    """
    根据页表,从离散物理块池中取出逻辑连续的KV
    Args:
        page_table: [num_logical_blocks] 逻辑块→物理块映射
        physical_pool: [num_phys_block, num_heads, block_size, head_dim]
    Return:
        logical_kv: [1, num_heads, seq_len, head_dim] 逻辑连续KV (batch=1)
    """
    num_phys, n_heads, blk_sz, h_dim = physical_pool.shape
    all_blocks = physical_pool[page_table]  # [num_logical_blocks, n_heads, blk_sz, h_dim]
    concat_all = all_blocks.permute(1, 0, 2, 3).reshape(n_heads, -1, h_dim)  # [n_heads, total_logical_tokens, h_dim]
    # 截断到真实序列长度(最后一个块可能没填满)
    logical_kv = concat_all[:, :seq_len].unsqueeze(0)  # [1, n_heads, seq_len, h_dim]
    return logical_kv

# 拼接分页K、V
k = gather_paged_kv(page_table, physical_k_pool, BLOCK_SIZE, SEQ_LOGICAL_LEN)
v = gather_paged_kv(page_table, physical_v_pool, BLOCK_SIZE, SEQ_LOGICAL_LEN)
print(f"拼接后逻辑K shape: {k.shape}")  # [1, 2, 10, 16]

# ===================== 6. 生成Query(Decode阶段单Token Query演示) =====================
# 模拟Decode:每次仅1个新token作为Query,batch=1
q = torch.randn(1, NUM_HEADS, 1, HEAD_DIM, device=device, dtype=torch.float16)

# ===================== 7. FlexAttention 计算分页注意力 =====================
# 构造下三角因果mask(自回归LLM必备,看不到未来token)
seq_q_len = q.shape[-2]
seq_kv_len = k.shape[-2]
mask = FlexAttentionMask.make_causal(seq_q_len, seq_kv_len, device=device)

# 执行FlexAttention(纯PyTorch编译算子,无自定义CUDA源码)
attn_out = flex_attention(
    query=q,
    key=k,
    value=v,
    mask=mask,
    scale=HEAD_DIM ** -0.5
)

# ===================== 8. 输出验证 =====================
print(f"注意力输出shape: {attn_out.shape}")  # [1, num_heads, q_len, head_dim]
print("="*50)
print("运行成功!FlexAttention 分页KV模拟完成")
print(f"页表映射关系: 逻辑块0→物理块{page_table[0]}, 逻辑块1→物理块{page_table[1]}, 逻辑块2→物理块{page_table[2]}")
print(f"物理块池总容量: {NUM_PHYSICAL_BLOCKS} 块,单块容纳 {BLOCK_SIZE} Token")

运行输出示例

python 复制代码
Torch版本: 2.4.0, Device: cuda
拼接后逻辑K shape: torch.Size([1, 2, 10, 16])
注意力输出shape: torch.Size([1, 2, 1, 16])
==================================================
运行成功!FlexAttention 分页KV模拟完成
页表映射关系: 逻辑块0→物理块0, 逻辑块1→物理块3, 逻辑块2→物理块5
物理块池总容量: 8 块,单块容纳 4 Token