一、前置背景:传统 KV 缓存的致命缺陷
大模型推理分为两个阶段:
- Prefill(预填充):一次性输入整段 Prompt,并行计算所有输入 Token 的 K、V 向量,存入显存 KV 缓存;
- Decode(逐 Token 生成):每轮输出 1 个新 Token,计算该 Token 的 K/V 并追加到缓存,循环直到结束。
传统连续 KV 缓存方案问题
- 框架会给每个请求预先分配一块固定长度、连续的显存张量,长度等于模型最大上下文窗口(如 2048/8192)。
- 举个直观痛点:
- 最大上下文 8192 Token,用户 Prompt 仅 300 Token、生成 100 Token,实际只用 400 位置,但必须占用完整 8192 长度连续显存;
- 多请求长短不一,释放后内存碎成大量无法合并的小段,显存利用率常低于 15%;
- Beam Search 多生成分支时,需要完整复制多份 KV 张量,显存占用成倍暴涨。
PagedAttention 核心目标:用操作系统虚拟内存分页思路管理 KV 缓存,消除碎片、按需分配、支持 KV 块共享,显存利用率提升至 85%~95%。
重要前提:PagedAttention不改变注意力数学公式(QK^T V),只是一套 KV 缓存的显存管理机制,常与FlashAttention 算子搭配使用。
二、PagedAttention 四大核心基础概念
1. 物理块 Block(物理 Page)
- 全局显存划分为大量固定大小的独立显存块,是最小分配单元。
- 超参 block_size:单块可存储的 Token 数量,工业界常用 16、32;
- 每个 Block 存储单层单批的 K/V 数据,维度:num_heads, block_size, head_dim;
- 所有 Block 统一存放在全局空闲块池,统一分配、回收、复用。
2. 逻辑 Token & 逻辑块
- 对任意一条对话序列,生成的 Token 按顺序拥有连续逻辑编号(0,1,2,3...);
- 按 block_size 切割逻辑 Token,得到连续逻辑块:
逻辑块编号 = token_pos // block_size,块内偏移 = token_pos % block_size。 - 逻辑上 Token 连续,但底层物理内存完全可以不连续。
3. Block Table(页表,每条请求独有)
- 每个推理序列维护一张数组 block_table\[\],映射关系:
block_table逻辑块ID = 对应物理Block的显存地址/指针 - 作用:屏蔽物理内存离散性,让上层算子看起来逻辑空间是完整连续的。
4. Copy-on-Write(CoW 写时复制,Beam Search 核心优化)
- 多个生成分支(Beam)可以共享只读的历史 KV 物理块;
- 只有当某个分支需要修改 / 追加新 Token、覆盖块内数据时,才复制一份独立物理块供该分支私有,避免完整复制全部 KV 缓存。
三、单条对话序列完整工作流程
固定配置
- Block_size = 16 Token(一块存 16 个 Token 的 K/V)
- 全局空闲物理块池:B0、B1、B2、B3、B4......
- 用户输入 Prompt:22 个 Token(逻辑编号 0~21),后续生成 3 个新 Token(22、23、32)
阶段 1:Prefill 预填充,处理输入 22 个 Token
- 拆分逻辑块
逻辑 Token 0~21,按 16 切割:
- 逻辑块 0:Token 0~15(填满 16 个)
- 逻辑块 1:Token 16~21(仅 6 个,块内剩余 10 个空位)
- 向全局块池申请 2 个空闲物理块,分配 B0、B1
- 创建当前序列专属 Block Table:
bash
block_table[0] = B0
block_table[1] = B1
- 逐层执行 Prefill 计算,将 0-15 的 K/V 写入物理块 B0;16-21 的 K/V 写入物理块 B1;
- 状态总结:仅占用 2 个物理块;传统方案需预分配 8192 长度大张量,显存占用差距几十倍。
此时内存布局示意:

阶段 2:Decode 第一轮,生成第 22 号 Token
- 新 Token 逻辑位置 pos=22
- 计算归属:逻辑块 ID = 22 // 16 = 1;块内偏移 = 22 % 16 = 6
- 查询 Block Table,逻辑块 1 绑定物理块 B1,无需申请新块;
- 直接向 B1 的偏移 6 位置写入新 Token 的 K/V;
序列总 Token 更新为 0~22。
阶段 3:Decode 第二轮,生成第 23 号 Token
pos=23 → 逻辑块 1、偏移 7,继续复用 B1 写入数据,总 Token 0~23。
阶段 4:持续生成直到填满逻辑块 1
不断生成 Token 到 pos=31:逻辑块 1 全部 16 个位置(16~31)写满。
阶段 5:下一个新 Token pos=32
- pos=32 // 16 = 2 → 逻辑块 2;
- 查询 Block Table,block_table2 为空,无对应物理块;
- 向全局块池申请新物理块 B2,更新页表:block_table2 = B2;
- 将 pos=32 的 K/V 写入 B2 的偏移 0 位置;
后续 32~47 所有新 Token 都会复用 B2,直到填满再申请 B3。
阶段 6:对话结束,释放显存
- 推理终止,遍历当前序列 Block Table 中所有物理块 B0、B1、B2,全部归还至全局空闲块池;
- 其他新请求可直接复用这些回收的 Block,无内存碎片。
四、进阶案例:Beam Search + Copy-on-Write(写时复制)
场景:beam_width=2,同一段 22Token Prompt 生成两条候选分支 Beam0、Beam1
- Prefill 完成后,基础序列页表:B0, B1,存储 0~21 历史 KV;
- 初始化两个 Beam 分支时,不复制任何物理块,两个分支共用同一份 Block Table,共享只读 B0、B1;
- 第一轮 Decode,两个分支各自生成独立新 Token(pos=22):
- 逻辑块 1 绑定 B1,而 B1 是多分支共享的只读块,不能直接覆盖;
- 触发 CoW 机制:
① 分配新空闲块 B1_copy;
② 将原 B1 内全部 KV 数据完整拷贝至 B1_copy;
③ Beam0 更新自身页表:block_table1 = B1_copy,写入自己的 pos22 数据;
④ Beam1 仍保留原共享块 B1,写入自身 pos22 数据;
- 收益:长达上千 Token 的历史上下文块(B0)全程共享,仅尾部未填满的小块才需要复制;传统 Beam 方案需要完整复制两份 8192 长度 KV 张量,显存开销大幅降低 70% 以上。
五、底层关键配套优化:离散 KV 原生注意力算子
离散 KV 原生注意力算子是 vLLM 等高性能推理框架实现 PagedAttention 的定制 CUDA 内核,核心能力是直接对非连续物理内存中的 KV 块执行注意力计算,无需先拼接为连续张量。它不是新的注意力数学公式,而是对传统注意力算子的软件架构重构,解决了长期困扰大模型推理的 "KV 必须连续存储" 的软件设计枷锁。
为什么以前绝对不支持离散 KV?(核心根源)
- 硬件层面:GPU 本身没有 "必须连续" 的限制(先破除误区)
GPU SM 核心可以访问显存任意地址,硬件指令集支持随机内存访问。限制完全来自上层软件,而非硬件能力不足。 - 软件层面:三大核心枷锁锁死了离散计算路径

- 历史背景:早期算子设计的场景局限
- FlashAttention v1/v2 研发目标:解决连续 KV下的片上分块计算,减少 HBM 显存占用,完全没考虑分页、离散存储、多请求混合调度场景
- 业务场景限制:早期大模型推理以单用户长文本为主,预分配连续显存即可满足需求,没有多用户并发导致的碎片化问题
- 开发优先级:开发者优先优化计算效率,而非内存管理灵活性,"连续内存 + 高效计算" 是当时最优解
完整工作流程
以序列长度 10、block_size=4、num_heads=2、head_dim=16 为例:
Step 1:输入准备
- Query:当前解码的 1 个 token(1, 2, 1, 16)
- 块表:0, 3, 5(逻辑块 0→物理块 0,逻辑块 1→物理块 3,逻辑块 2→物理块 5)
- 物理 KV 池:全局共享的 8 个物理块,每个存 4 个 token 的 KV
Step 2:内核启动与块表遍历
- CUDA 内核启动,每个线程处理一个注意力头
- 内核按逻辑顺序遍历块表,不做任何拼接,直接获取每个物理块的显存地址
- 块表常驻 GPU 共享内存,查表延迟可忽略不计
Step 3:块级 QKV 计算(并行执行,无中间张量)
对每个物理块执行:
- 分散读取:直接从离散物理块读取 KV(无拷贝)
- 块内计算:复用 FlashAttention 分块计算逻辑(softmax / 加权求和)
- 聚合结果:累加所有块输出,无拼接操作
Step 4:最终输出
所有块计算完成后,直接得到最终注意力结果,用于后续层计算。
与传统方案的根本差异

六、常见误区澄清
- PagedAttention 不是新式注意力算法,只是 KV 缓存内存管理方案,QKV 计算逻辑和标准 Self-Attention 完全一致;
- 和 FlashAttention 是互补关系:Flash 优化单块张量注意力计算速度,PagedAttention 优化 KV 显存存储;
- Block Size 并非越大越好:
- block_size=32:页表更短、查表开销小,但单块空闲空位变多,内存利用率轻微下降;
- block_size=16:工业界平衡内存利用率与查表延迟的通用选择。
七、主流落地框架
vLLM 原创并原生实现 PagedAttention;
SGLang、TensorRT-LLM、DeepSpeed-MII、LightLLM 均借鉴分页 KV 缓存架构。
八、PyTorch 2.x FlexAttention 模拟 PagedAttention 最小可运行 Demo
本示例是教学级纯 PyTorch 模拟实现:用 FlexAttention + 张量索引模拟分页 KV 块表逻辑,不等于 vLLM 底层手写 CUDA 内核版本;工业高性能 PagedAttention 仍依赖定制 CUDA Kernel,Flex 方案胜在零 CUDA 代码、可编译跨算子。
python
import torch
import torch.nn as nn
from torch.nn.attention import flex_attention, FlexAttentionMask
from torch.utils._triton import has_triton
# ===================== 1. 环境校验 =====================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
assert torch.cuda.is_available(), "必须使用CUDA GPU运行FlexAttention"
assert torch.__version__ >= "2.4.0", "PyTorch版本需≥2.4"
print(f"Torch版本: {torch.__version__}, Device: {device}")
# ===================== 2. 超参配置(分页核心参数) =====================
BLOCK_SIZE = 4 # 每个物理块存储4个token(PagedAttention最小分配单元)
NUM_HEADS = 2 # 注意力头数
HEAD_DIM = 16 # 单头维度
NUM_PHYSICAL_BLOCKS = 8 # 全局总物理块池数量
SEQ_LOGICAL_LEN = 10 # 当前序列总逻辑token长度(0~9)
# ===================== 3. 全局物理KV块池(模拟显存离散块) =====================
# 物理池形状: [总物理块数, num_heads, block_size, head_dim]
# 所有请求共享这块显存池,按需分配/回收物理块
physical_k_pool = torch.randn(
NUM_PHYSICAL_BLOCKS, NUM_HEADS, BLOCK_SIZE, HEAD_DIM,
device=device, dtype=torch.float16
)
physical_v_pool = torch.randn_like(physical_k_pool)
# ===================== 4. 构造PageTable(单条推理序列的页表) =====================
# 逻辑token 0~9,按BLOCK_SIZE=4切分逻辑块:
# 逻辑块0: token0-3; 逻辑块1: token4-7; 逻辑块2: token8-9
# page_table[逻辑块ID] = 对应物理块ID
page_table = torch.tensor(
[0, 3, 5], # 逻辑块0→物理块0,逻辑块1→物理块3,逻辑块2→物理块5
device=device, dtype=torch.long
)
num_logical_blocks = page_table.shape[0]
# ===================== 5. 从离散物理块拼接逻辑连续K/V =====================
def gather_paged_kv(page_table: torch.Tensor, physical_pool: torch.Tensor, block_size: int, seq_len: int):
"""
根据页表,从离散物理块池中取出逻辑连续的KV
Args:
page_table: [num_logical_blocks] 逻辑块→物理块映射
physical_pool: [num_phys_block, num_heads, block_size, head_dim]
Return:
logical_kv: [1, num_heads, seq_len, head_dim] 逻辑连续KV (batch=1)
"""
num_phys, n_heads, blk_sz, h_dim = physical_pool.shape
all_blocks = physical_pool[page_table] # [num_logical_blocks, n_heads, blk_sz, h_dim]
concat_all = all_blocks.permute(1, 0, 2, 3).reshape(n_heads, -1, h_dim) # [n_heads, total_logical_tokens, h_dim]
# 截断到真实序列长度(最后一个块可能没填满)
logical_kv = concat_all[:, :seq_len].unsqueeze(0) # [1, n_heads, seq_len, h_dim]
return logical_kv
# 拼接分页K、V
k = gather_paged_kv(page_table, physical_k_pool, BLOCK_SIZE, SEQ_LOGICAL_LEN)
v = gather_paged_kv(page_table, physical_v_pool, BLOCK_SIZE, SEQ_LOGICAL_LEN)
print(f"拼接后逻辑K shape: {k.shape}") # [1, 2, 10, 16]
# ===================== 6. 生成Query(Decode阶段单Token Query演示) =====================
# 模拟Decode:每次仅1个新token作为Query,batch=1
q = torch.randn(1, NUM_HEADS, 1, HEAD_DIM, device=device, dtype=torch.float16)
# ===================== 7. FlexAttention 计算分页注意力 =====================
# 构造下三角因果mask(自回归LLM必备,看不到未来token)
seq_q_len = q.shape[-2]
seq_kv_len = k.shape[-2]
mask = FlexAttentionMask.make_causal(seq_q_len, seq_kv_len, device=device)
# 执行FlexAttention(纯PyTorch编译算子,无自定义CUDA源码)
attn_out = flex_attention(
query=q,
key=k,
value=v,
mask=mask,
scale=HEAD_DIM ** -0.5
)
# ===================== 8. 输出验证 =====================
print(f"注意力输出shape: {attn_out.shape}") # [1, num_heads, q_len, head_dim]
print("="*50)
print("运行成功!FlexAttention 分页KV模拟完成")
print(f"页表映射关系: 逻辑块0→物理块{page_table[0]}, 逻辑块1→物理块{page_table[1]}, 逻辑块2→物理块{page_table[2]}")
print(f"物理块池总容量: {NUM_PHYSICAL_BLOCKS} 块,单块容纳 {BLOCK_SIZE} Token")
运行输出示例
python
Torch版本: 2.4.0, Device: cuda
拼接后逻辑K shape: torch.Size([1, 2, 10, 16])
注意力输出shape: torch.Size([1, 2, 1, 16])
==================================================
运行成功!FlexAttention 分页KV模拟完成
页表映射关系: 逻辑块0→物理块0, 逻辑块1→物理块3, 逻辑块2→物理块5
物理块池总容量: 8 块,单块容纳 4 Token