从0开始复现nano-vllm「ModelRunner.capture_cudagraph()」

capture_cudagraph

py 复制代码
if not self.enforce_eager:
    self.capture_cudagraph()


@torch.inference_mode()
def capture_cudagraph(self):
    config = self.config
    hf_config = config.hf_config
    max_bs = min(self.config.max_num_seqs, 512)
    max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
    input_ids = torch.zeros(max_bs, dtype=torch.int64)
    positions = torch.zeros(max_bs, dtype=torch.int64)
    slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
    context_lens = torch.zeros(max_bs, dtype=torch.int32)
    block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
    outputs = torch.zeros(max_bs, hf_config.hidden_size)
    self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
    self.graphs = {}
    self.graph_pool = None

    for bs in reversed(self.graph_bs):
        graph = torch.cuda.CUDAGraph()
        set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
        with torch.cuda.graph(graph, self.graph_pool):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
        if self.graph_pool is None:
            self.graph_pool = graph.pool()
        self.graphs[bs] = graph
        torch.cuda.synchronize()
        reset_context()

    self.graph_vars = dict(
        input_ids=input_ids,
        positions=positions,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        outputs=outputs,
    )

为什么需要 CUDA Graph?

在 LLM 推理等小算子、高频次 的场景中,CPU 逐个调度任务的开销往往比 GPU 实际计算的时间还要长,导致 GPU 大量空闲等待;CUDA Graph 通过将一系列 GPU 操作"录制"为静态图,在执行时只需一次 CPU 指令即可驱动整个计算流程,从而彻底消除 CPU 调度瓶颈,填满 GPU 流水线,显著降低推理延迟。

下面先初始化一些变量

py 复制代码
@torch.inference_mode()  # 禁用梯度计算,节省显存并加速
def capture_cudagraph(self):
    config = self.config
    hf_config = config.hf_config
    
    # 设定最大 Batch Size,为了安全起见限制在 512 以内
    max_bs = min(self.config.max_num_seqs, 512)
    
    # 计算单个序列 KV Cache 所需的最大 block 数量
    max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size

接下来是关键的一步 :CUDA Graph 要求输入和输出的内存地址是固定的。因此,代码预先分配了一组全零的张量(Tensors)作为"静态缓冲区"。

py 复制代码
# 分配静态输入/输出 Tensor,它们将驻留在 GPU 上
input_ids = torch.zeros(max_bs, dtype=torch.int64)
positions = torch.zeros(max_bs, dtype=torch.int64)
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
context_lens = torch.zeros(max_bs, dtype=torch.int32)
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
outputs = torch.zeros(max_bs, hf_config.hidden_size)

torch.zeros(),PyTorch 会在内存中申请一块空间,并明确地将所有位置都填充为 0。

以后在推理时,我们不能直接传新 Tensor 给模型,而是必须把数据 copy_ 到这些 input_ids 等静态 Tensor 中,然后重放 Graph。

假设 max_bs 是最大批次大小,max_num_blocks 是为每个序列分配的最大内存块数量。

为什么 input_idspositionsint64?而slot_mapping, context_lens, block_tablesint32

因为PyTorch 的核心层 nn.Embedding 以及很多官方算子,强制要求输入的索引张量必须是 LongTensor (即 int64)。而另外三个不是给 PyTorch 标准层用的,而是传给 vLLM 自定义的 CUDA Kernel(例如 PagedAttention 算子)用的,这些元数据在推理过程中会被高频读取,使用int32相比于int64可以节省一半的显存读取,节省带宽。而且int32大概是2e9,显存再大,物理块的数量也不可能超过 21 亿,context_lens和block_tables也是一样。

变量名 形状 (Shape) 维度含义解析 作用
input_ids (max_bs,) 输入 Token ID。每一行代表当前 Batch 中每个序列正在处理的那个 Token。 输入token
positions (max_bs,) 位置索引。对应每个 Token 在其原始序列中的绝对位置(用于 Position Embedding)。 它决定了 RoPE(旋转位置编码) 的旋转角度。只有知道了位置,模型才能分清"我喜欢你"和"你喜欢我"的区别。
slot_mapping (max_bs,) 槽位映射。指示当前 Token 应该存储在物理 KV Cache 内存池中的哪个具体位置。 表示每个 token 在 KV Cache 池中的物理存储位置(slot 索引),用于 CUDA kernel 将新计算的 K/V 写入正确的缓存位置。
context_lens (max_bs,) 有效长度。记录每个序列到目前为止总共拥有多少个有效的 Token 它是推理时的"边界守护者",告诉模型在计算注意力时,应该回溯查看多少长度的 KV Cache。
block_tables (max_bs, max_num_blocks) 物理块表。每一行映射了一个序列所占用的所有不连续内存块的编号。 显存页表,由于模型在计算 Attention 时,它需要回头看过去所有的历史记录(Key/Value),所以需要一个页表记录每个序列的kv cache存在哪些块里。
outputs (max_bs, hidden_size) 隐层输出。保存模型最后一层计算出的向量,准备进行 Logits 映射。 输出token

接下来定义 Batch Size 策略

py 复制代码
# 定义需要捕获的 Batch Size 列表
# 小 BS:1, 2, 4, 8 (不仅是为了速度,也是为了精确匹配)
# 大 BS:16, 32, 48... 直到 max_bs (以 16 为步长)
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))

self.graphs = {}
self.graph_pool = None  # 用于共享内存池

为什么要分这么多size?

我们不可能为 1 到 512 的每个 batch size 都录制一个图(那样显存会爆炸)。通常的做法是"向上取整":如果你来了 12 个请求,就用 batch size = 16 的图来跑,多余的槽位填 padding。

接下来为不同 batch size 预先录制 CUDA Graph,从而在后续推理时直接 replay 图,避免 Python 调度和 CUDA kernel launch 的开销

py 复制代码
for bs in reversed(self.graph_bs):
    graph = torch.cuda.CUDAGraph() # 这是一个录制器,用来记录接下来所有的 CUDA kernel 调用

    # 1. 设置上下文 (Set Context)
    # 这通常用于 PagedAttention 等自定义算子,告诉它们当前只处理前 bs 个数据
    set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])


    # 2. Warmup (预热)
    # 在录制前必须先运行一次。这是为了让 PyTorch/CUDA 内部完成 lazy initialization(懒加载),
    # 确保所有显存分配和 buffer 初始化都已完成,避免录制到内存分配操作。
    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])

    # 3. 开启录制 (Capture)
    # self.graph_pool 用于在不同的 Graph 之间共享显存,极大节省内存占用。
    with torch.cuda.graph(graph, self.graph_pool):
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])

    # 4. 内存池管理
    # 如果是第一次循环(最大的 Batch Size),获取其内存池,供后续较小的 Graph 复用。
    if self.graph_pool is None:
        self.graph_pool = graph.pool()

    # 5. 保存图并清理
    self.graphs[bs] = graph
    torch.cuda.synchronize() # 等待 GPU 完成
    reset_context()          # 清理上下文

为什么要倒序遍历batch size,即从大到小?

因为 CUDA Graph capture 过程中会分配 GPU 内存,如果从大到小录制,可以:先为最大 bs 分配最大需要的内存,后面小 bs 可以复用 graph memory pool,可以减少碎片和额外分配

为什么需要内存池?为什么需要共享的内存池?

在大模型的一趟前向传播中,每一层都会产生大量的中间临时变量。在普通模式下,PyTorch 会在算到这一层时,临时向系统申请一块显存(cudaMalloc)来装这些中间变量,算完再释放。 但是,动态申请显存非常耗时 。CUDA Graph 为了追求极致的速度,严禁 在运行期间进行任何动态内存分配。 因此,在录制 Graph 之前,系统必须提前一次性申请好一块足够大的连续显存作为**"专属草稿纸"。这块提前划定好的草稿纸,就是内存池**。有了它,Graph 才能把所有中间变量的地址当成常量"死死地刻在"指令里。

因为 CUDA Graph 会将所需地址永久锁定,如果不共享内存池,录制多个不同 Batch Size 的图会导致每一张图都霸占一份专属的临时显存,最终令显存占用随图的数量成倍爆炸;而引入共享内存池(graph_pool)后,系统只需按照最大尺寸的图申请唯一一块 显存,巧妙利用大模型推理任务严格串行执行且中间激活值**"朝生夕死"**的特性,让所有不同尺寸的图轮流在这张相同的"临时草稿纸"上安全地覆盖擦写,从而在维持硬件级极速调度的同时,将底层的显存开销压缩到了绝对的最低极限。

内存池里究竟放什么东西, 不放什么东西?

在一次推理过程中,我们会产生一些临时的tensor,比如

py 复制代码
Q = linear(x)
K = linear(x)
V = linear(x)
attn_out = softmax(QK^T)V
mlp_out = ...
residual = ...

这些中间 tensor 的生命周期只在一次 forward 内,下次 forward 可以覆盖,且不需要跨 step 保存,我们会把他放在memory pool中。

而对于KV cache、input_ids、slot_mapping、position_ids、block_tables、模型参数等,是不会存在graph memory pool的

我们的GPU显存可以抽象成这样:

txt 复制代码
┌────────────────────────────┐
│ 模型权重 (长期存在)        │
├────────────────────────────┤
│ KV Cache (长期存在)        │
├────────────────────────────┤
│ slot_mapping 等输入张量    │
├────────────────────────────┤
│ Graph Pool (长期存在)        │  ← 只放中间 tensor
│  Q tensor                  │
│  K tensor                  │
│  V tensor                  │
│  attn_out                  │
│  mlp_out                   │
│  ...                       │
└────────────────────────────┘

graph 为什么要固定地址?

因为"固定地址"是让 GPU 能够"脱离 CPU 独立全速运行"的唯一代价。

CUDA Graph 必须固定显存地址,是为了彻底消除动态内存分配(cudaMalloc)的高昂开销,并实现 GPU 的纯硬件级自主调度。在 Graph 的录制阶段,底层驱动会将所有算子的输入、输出及中间变量的**显存指针直接作为常量"硬编码"**到预编译的执行蓝图中。只有将这些"数据交互的仓库位置"完全固定死,GPU 在实际重放时才能彻底脱离 CPU 的干预,像自动化流水线一样:前一个算子算完直接塞进固定地址,后一个算子毫无延迟地从该地址读取,从而实现微秒级的无缝衔接与极致推理性能。

graph 什么时候被释放?

Graph 被保存在了一个字典里:self.graphs[bs] = graph。 只要包含这个字典的实例对象ModelRunner还存活,Graph 就一直活着。在像 vLLM 这样的实际工业应用中,这些 Graph 几乎永远不会被主动释放。

所以说,只要服务还在运行,这块显存就一直被死死霸占着吗,即使目前没有任何请求也不会被释放,原因主要是

  • 为了数据的绝对安全,CUDA Graph 在录制的时候,已经把显存地址(比如 0x1A2B)刻在了底层的机器指令里。它在运行时是直接往 0x1A2B 这个地址读写数据的。如果把他释放掉了,且被别的东西申请了,那调用graph的时候会该地址直接写东西影响别的数据,造成数据污染
  • 为了机制的速度,一直把地圈着,就省去了申请和释放空间的时间,用空间换时间

为什么可以共享graph pool

  1. 降序录制带来的"空间包容"

    有较小 Batch Size 的图,在物理内存上完全内含于最大的那个图。因为大图能装下,小图就绝对不会发生内存越界。

  2. 执行机制的"时间互斥"

    如果 BS=128 和 BS=64 的图在 GPU 上同时运行,共享内存池瞬间就会发生灾难性的数据踩踏。但大模型推理引擎的设计避免了这一点。

    • 对于同一个模型实例(同一个 Worker),推理引擎的调度器(Scheduler)是严格串行处理请求的。
    • 在任何一个微秒级的物理时间点上,GPU 的这个 CUDA Stream 中绝对只有一个 Graph 在运行。
  3. 池内数据的"绝对无状态"

    在大模型推理中,需要跨时间步保留的"有状态数据"只有两样:

    1. 模型权重 (Weights):只读,不在 Pool 里。
    2. KV Cache :记录用户历史对话,必须保留,所以它们存在由 block_tables 动态管理的专门显存区里,绝对不在 Pool 里

    共享 Pool 全是纯粹的中间激活值。比如:

    • 第 1 层 Transformer 算完后准备传给第 2 层的隐藏状态矩阵。
    • Softmax 计算时的临时分母。

    这些数据有一个共同特点:朝生夕死。一旦这一次前向传播(即这个 Graph 的执行)结束,它们就彻底变成了电子垃圾。当下一个 Graph(哪怕是不同 BS 的图)被启动时,直接用新数据覆盖这些垃圾,不仅毫无影响,反而省去了清理内存的时间。

py 复制代码
self.graph_vars = dict(
    input_ids=input_ids,
    positions=positions,
    slot_mapping=slot_mapping,
    context_lens=context_lens,
    block_tables=block_tables,
    outputs=outputs,
)

这是一个字典,为了少传一些参数,我们用一个字典来方便传递参数。我们可以把计算图抽象成一个黑盒测试,他知道去哪里读输入数据,把输出数据写到哪里,因为这些地址都在录制的过程中写好了,所以每次推理的时候,你就需要把本次输入的数据复制到对应地址就可以

比如,你拿着用户新输入的 Token(比如 new_input_ids),把它填进那个早就固定好地址的池子里:

py 复制代码
# 把新数据复制到静态张量对应的切片中
self.graph_vars["input_ids"][:bs].copy_(new_input_ids)
self.graph_vars["positions"][:bs].copy_(new_positions)
# ... 其他变量同理

关键点:这里用的是 .copy_(),也就是"原地替换"(In-place)。进料口的物理显存地址纹丝不动,只是里面的数字被刷新了。

py 复制代码
# 找到对应 Batch Size 的那盘录像带,直接执行!
self.graphs[bs].replay()

此时,CPU 瞬间"下班"(开销降到极低),GPU 的硬件调度器接管一切,闭着眼睛根据写死的地址一顿狂算,把临时草稿全部扔进 graph_pool 里复写。

py 复制代码
# 结果已经乖乖躺在静态的 outputs 槽位里了,直接按需要的长度取走
final_logits = self.graph_vars["outputs"][:bs]
相关推荐
武汉唯众智创2 小时前
云边端协同落地:唯众AI实训平台技术架构实操解析
人工智能·人工智能实训·ai 实训平台·职教 ai 实训·职教院校实训方案·高校职校实训方案
大猫子的技术日记2 小时前
Playwright 自动化测试入门指南:Python 开发者的端到端实战
开发语言·人工智能·python
数琨创享TQMS质量数智化2 小时前
数琨创享:以数智化质量目标管理闭环赋能可量化、可追溯、可驱动的质量运营
大数据·人工智能·qms质量管理系统
laplace01232 小时前
Kv cache
人工智能·agent·claude·rag·skills
Maynor9962 小时前
OpenClaw 中转站配置完全指南
linux·运维·服务器·人工智能·飞书
马拉AI2 小时前
Transformer范式改变?稀疏线性混合SALA架构发布,单卡5090跑通百万长文!
深度学习·架构·transformer
Eric2232 小时前
CLI-Agent-Manager:面向 Vibe Coding 的多 Agent 统一管理面板
人工智能·后端·开源
如若1232 小时前
SoftGroup训练FORinstance森林点云数据集——从零到AP=0.506完整复现
人工智能·python·深度学习·神经网络·计算机视觉
InternLM2 小时前
LMDeploy重磅更新:从支撑模型到被模型反哺,推理引擎迈入协同进化时代!
人工智能·大模型·多模态大模型·大模型推理·书生大模型