capture_cudagraph
py
if not self.enforce_eager:
self.capture_cudagraph()
@torch.inference_mode()
def capture_cudagraph(self):
config = self.config
hf_config = config.hf_config
max_bs = min(self.config.max_num_seqs, 512)
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
input_ids = torch.zeros(max_bs, dtype=torch.int64)
positions = torch.zeros(max_bs, dtype=torch.int64)
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
context_lens = torch.zeros(max_bs, dtype=torch.int32)
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
outputs = torch.zeros(max_bs, hf_config.hidden_size)
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
self.graphs = {}
self.graph_pool = None
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
torch.cuda.synchronize()
reset_context()
self.graph_vars = dict(
input_ids=input_ids,
positions=positions,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
outputs=outputs,
)
为什么需要 CUDA Graph?
在 LLM 推理等小算子、高频次 的场景中,CPU 逐个调度任务的开销往往比 GPU 实际计算的时间还要长,导致 GPU 大量空闲等待;CUDA Graph 通过将一系列 GPU 操作"录制"为静态图,在执行时只需一次 CPU 指令即可驱动整个计算流程,从而彻底消除 CPU 调度瓶颈,填满 GPU 流水线,显著降低推理延迟。
下面先初始化一些变量
py
@torch.inference_mode() # 禁用梯度计算,节省显存并加速
def capture_cudagraph(self):
config = self.config
hf_config = config.hf_config
# 设定最大 Batch Size,为了安全起见限制在 512 以内
max_bs = min(self.config.max_num_seqs, 512)
# 计算单个序列 KV Cache 所需的最大 block 数量
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
接下来是关键的一步 :CUDA Graph 要求输入和输出的内存地址是固定的。因此,代码预先分配了一组全零的张量(Tensors)作为"静态缓冲区"。
py
# 分配静态输入/输出 Tensor,它们将驻留在 GPU 上
input_ids = torch.zeros(max_bs, dtype=torch.int64)
positions = torch.zeros(max_bs, dtype=torch.int64)
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
context_lens = torch.zeros(max_bs, dtype=torch.int32)
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
outputs = torch.zeros(max_bs, hf_config.hidden_size)
torch.zeros(),PyTorch 会在内存中申请一块空间,并明确地将所有位置都填充为 0。
以后在推理时,我们不能直接传新 Tensor 给模型,而是必须把数据 copy_ 到这些 input_ids 等静态 Tensor 中,然后重放 Graph。
假设 max_bs 是最大批次大小,max_num_blocks 是为每个序列分配的最大内存块数量。
为什么
input_ids和positions是int64?而slot_mapping,context_lens,block_tables是int32?
因为PyTorch 的核心层 nn.Embedding 以及很多官方算子,强制要求输入的索引张量必须是 LongTensor (即 int64)。而另外三个不是给 PyTorch 标准层用的,而是传给 vLLM 自定义的 CUDA Kernel(例如 PagedAttention 算子)用的,这些元数据在推理过程中会被高频读取,使用int32相比于int64可以节省一半的显存读取,节省带宽。而且int32大概是2e9,显存再大,物理块的数量也不可能超过 21 亿,context_lens和block_tables也是一样。
| 变量名 | 形状 (Shape) | 维度含义解析 | 作用 |
|---|---|---|---|
input_ids |
(max_bs,) |
输入 Token ID。每一行代表当前 Batch 中每个序列正在处理的那个 Token。 | 输入token |
positions |
(max_bs,) |
位置索引。对应每个 Token 在其原始序列中的绝对位置(用于 Position Embedding)。 | 它决定了 RoPE(旋转位置编码) 的旋转角度。只有知道了位置,模型才能分清"我喜欢你"和"你喜欢我"的区别。 |
slot_mapping |
(max_bs,) |
槽位映射。指示当前 Token 应该存储在物理 KV Cache 内存池中的哪个具体位置。 | 表示每个 token 在 KV Cache 池中的物理存储位置(slot 索引),用于 CUDA kernel 将新计算的 K/V 写入正确的缓存位置。 |
context_lens |
(max_bs,) |
有效长度。记录每个序列到目前为止总共拥有多少个有效的 Token | 它是推理时的"边界守护者",告诉模型在计算注意力时,应该回溯查看多少长度的 KV Cache。 |
block_tables |
(max_bs, max_num_blocks) |
物理块表。每一行映射了一个序列所占用的所有不连续内存块的编号。 | 显存页表,由于模型在计算 Attention 时,它需要回头看过去所有的历史记录(Key/Value),所以需要一个页表记录每个序列的kv cache存在哪些块里。 |
outputs |
(max_bs, hidden_size) |
隐层输出。保存模型最后一层计算出的向量,准备进行 Logits 映射。 | 输出token |
接下来定义 Batch Size 策略
py
# 定义需要捕获的 Batch Size 列表
# 小 BS:1, 2, 4, 8 (不仅是为了速度,也是为了精确匹配)
# 大 BS:16, 32, 48... 直到 max_bs (以 16 为步长)
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
self.graphs = {}
self.graph_pool = None # 用于共享内存池
为什么要分这么多size?
我们不可能为 1 到 512 的每个 batch size 都录制一个图(那样显存会爆炸)。通常的做法是"向上取整":如果你来了 12 个请求,就用 batch size = 16 的图来跑,多余的槽位填 padding。
接下来为不同 batch size 预先录制 CUDA Graph,从而在后续推理时直接 replay 图,避免 Python 调度和 CUDA kernel launch 的开销
py
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph() # 这是一个录制器,用来记录接下来所有的 CUDA kernel 调用
# 1. 设置上下文 (Set Context)
# 这通常用于 PagedAttention 等自定义算子,告诉它们当前只处理前 bs 个数据
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
# 2. Warmup (预热)
# 在录制前必须先运行一次。这是为了让 PyTorch/CUDA 内部完成 lazy initialization(懒加载),
# 确保所有显存分配和 buffer 初始化都已完成,避免录制到内存分配操作。
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
# 3. 开启录制 (Capture)
# self.graph_pool 用于在不同的 Graph 之间共享显存,极大节省内存占用。
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
# 4. 内存池管理
# 如果是第一次循环(最大的 Batch Size),获取其内存池,供后续较小的 Graph 复用。
if self.graph_pool is None:
self.graph_pool = graph.pool()
# 5. 保存图并清理
self.graphs[bs] = graph
torch.cuda.synchronize() # 等待 GPU 完成
reset_context() # 清理上下文
为什么要倒序遍历batch size,即从大到小?
因为 CUDA Graph capture 过程中会分配 GPU 内存,如果从大到小录制,可以:先为最大 bs 分配最大需要的内存,后面小 bs 可以复用 graph memory pool,可以减少碎片和额外分配
为什么需要内存池?为什么需要共享的内存池?
在大模型的一趟前向传播中,每一层都会产生大量的中间临时变量。在普通模式下,PyTorch 会在算到这一层时,临时向系统申请一块显存(cudaMalloc)来装这些中间变量,算完再释放。 但是,动态申请显存非常耗时 。CUDA Graph 为了追求极致的速度,严禁 在运行期间进行任何动态内存分配。 因此,在录制 Graph 之前,系统必须提前一次性申请好一块足够大的连续显存作为**"专属草稿纸"。这块提前划定好的草稿纸,就是内存池**。有了它,Graph 才能把所有中间变量的地址当成常量"死死地刻在"指令里。
因为 CUDA Graph 会将所需地址永久锁定,如果不共享内存池,录制多个不同 Batch Size 的图会导致每一张图都霸占一份专属的临时显存,最终令显存占用随图的数量成倍爆炸;而引入共享内存池(graph_pool)后,系统只需按照最大尺寸的图申请唯一一块 显存,巧妙利用大模型推理任务严格串行执行且中间激活值**"朝生夕死"**的特性,让所有不同尺寸的图轮流在这张相同的"临时草稿纸"上安全地覆盖擦写,从而在维持硬件级极速调度的同时,将底层的显存开销压缩到了绝对的最低极限。
内存池里究竟放什么东西, 不放什么东西?
在一次推理过程中,我们会产生一些临时的tensor,比如
py
Q = linear(x)
K = linear(x)
V = linear(x)
attn_out = softmax(QK^T)V
mlp_out = ...
residual = ...
这些中间 tensor 的生命周期只在一次 forward 内,下次 forward 可以覆盖,且不需要跨 step 保存,我们会把他放在memory pool中。
而对于KV cache、input_ids、slot_mapping、position_ids、block_tables、模型参数等,是不会存在graph memory pool的
我们的GPU显存可以抽象成这样:
txt
┌────────────────────────────┐
│ 模型权重 (长期存在) │
├────────────────────────────┤
│ KV Cache (长期存在) │
├────────────────────────────┤
│ slot_mapping 等输入张量 │
├────────────────────────────┤
│ Graph Pool (长期存在) │ ← 只放中间 tensor
│ Q tensor │
│ K tensor │
│ V tensor │
│ attn_out │
│ mlp_out │
│ ... │
└────────────────────────────┘
graph 为什么要固定地址?
因为"固定地址"是让 GPU 能够"脱离 CPU 独立全速运行"的唯一代价。
CUDA Graph 必须固定显存地址,是为了彻底消除动态内存分配(cudaMalloc)的高昂开销,并实现 GPU 的纯硬件级自主调度。在 Graph 的录制阶段,底层驱动会将所有算子的输入、输出及中间变量的**显存指针直接作为常量"硬编码"**到预编译的执行蓝图中。只有将这些"数据交互的仓库位置"完全固定死,GPU 在实际重放时才能彻底脱离 CPU 的干预,像自动化流水线一样:前一个算子算完直接塞进固定地址,后一个算子毫无延迟地从该地址读取,从而实现微秒级的无缝衔接与极致推理性能。
graph 什么时候被释放?
Graph 被保存在了一个字典里:self.graphs[bs] = graph。 只要包含这个字典的实例对象ModelRunner还存活,Graph 就一直活着。在像 vLLM 这样的实际工业应用中,这些 Graph 几乎永远不会被主动释放。
所以说,只要服务还在运行,这块显存就一直被死死霸占着吗,即使目前没有任何请求也不会被释放,原因主要是
- 为了数据的绝对安全,CUDA Graph 在录制的时候,已经把显存地址(比如
0x1A2B)刻在了底层的机器指令里。它在运行时是直接往0x1A2B这个地址读写数据的。如果把他释放掉了,且被别的东西申请了,那调用graph的时候会该地址直接写东西影响别的数据,造成数据污染 - 为了机制的速度,一直把地圈着,就省去了申请和释放空间的时间,用空间换时间
为什么可以共享graph pool
-
降序录制带来的"空间包容"
有较小 Batch Size 的图,在物理内存上完全内含于最大的那个图。因为大图能装下,小图就绝对不会发生内存越界。
-
执行机制的"时间互斥"
如果 BS=128 和 BS=64 的图在 GPU 上同时运行,共享内存池瞬间就会发生灾难性的数据踩踏。但大模型推理引擎的设计避免了这一点。
- 对于同一个模型实例(同一个 Worker),推理引擎的调度器(Scheduler)是严格串行处理请求的。
- 在任何一个微秒级的物理时间点上,GPU 的这个 CUDA Stream 中绝对只有一个 Graph 在运行。
-
池内数据的"绝对无状态"
在大模型推理中,需要跨时间步保留的"有状态数据"只有两样:
- 模型权重 (Weights):只读,不在 Pool 里。
- KV Cache :记录用户历史对话,必须保留,所以它们存在由
block_tables动态管理的专门显存区里,绝对不在 Pool 里。
共享 Pool 全是纯粹的中间激活值。比如:
- 第 1 层 Transformer 算完后准备传给第 2 层的隐藏状态矩阵。
- Softmax 计算时的临时分母。
这些数据有一个共同特点:朝生夕死。一旦这一次前向传播(即这个 Graph 的执行)结束,它们就彻底变成了电子垃圾。当下一个 Graph(哪怕是不同 BS 的图)被启动时,直接用新数据覆盖这些垃圾,不仅毫无影响,反而省去了清理内存的时间。
py
self.graph_vars = dict(
input_ids=input_ids,
positions=positions,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
outputs=outputs,
)
这是一个字典,为了少传一些参数,我们用一个字典来方便传递参数。我们可以把计算图抽象成一个黑盒测试,他知道去哪里读输入数据,把输出数据写到哪里,因为这些地址都在录制的过程中写好了,所以每次推理的时候,你就需要把本次输入的数据复制到对应地址就可以
比如,你拿着用户新输入的 Token(比如 new_input_ids),把它填进那个早就固定好地址的池子里:
py
# 把新数据复制到静态张量对应的切片中
self.graph_vars["input_ids"][:bs].copy_(new_input_ids)
self.graph_vars["positions"][:bs].copy_(new_positions)
# ... 其他变量同理
关键点:这里用的是 .copy_(),也就是"原地替换"(In-place)。进料口的物理显存地址纹丝不动,只是里面的数字被刷新了。
py
# 找到对应 Batch Size 的那盘录像带,直接执行!
self.graphs[bs].replay()
此时,CPU 瞬间"下班"(开销降到极低),GPU 的硬件调度器接管一切,闭着眼睛根据写死的地址一顿狂算,把临时草稿全部扔进 graph_pool 里复写。
py
# 结果已经乖乖躺在静态的 outputs 槽位里了,直接按需要的长度取走
final_logits = self.graph_vars["outputs"][:bs]