从0开始复现nano-vllm「 utils/contex.py」

utils/context.py

这段代码定义了一个全局上下文管理器(Global Context Manager)

它的核心作用是**"传纸条"**。在深度学习框架(如 PyTorch)与底层高性能计算核心(CUDA Kernels,比如 FlashAttention 或 PagedAttention)之间,有很多复杂的元数据(比如这句话有多长、显存存在哪里)需要传递。

为了避免在每个函数的参数列表里写上十几个参数(那样代码会极其丑陋且难以维护),这段代码定义了一个全局的"公告板" _CONTEXT。在运行底层算子之前,Python 端先把这些信息写到公告板上,底层的 C++/CUDA 代码直接去读这个板子里的信息。

py 复制代码
from dataclasses import dataclass
import torch


@dataclass
class Context:
    is_prefill: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None

_CONTEXT = Context()

def get_context():
    return _CONTEXT

def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    global _CONTEXT
    _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)

def reset_context():
    global _CONTEXT
    _CONTEXT = Context()

代码通过全局变量 _CONTEXT 维护了一个单例状态,并提供了 get_context 获取当前状态、set_context 更新推理参数以及 reset_context 重置环境的辅助函数,这种设计模式能够让底层的自定义算子在执行时直接获取必要的调度信息,而无需在每一层网络结构中显式地层层传递复杂的张量参数,从而简化了模型调用的接口逻辑。

  • 其中 is_prefill 标记了当前处于首词生成的"预填充"阶段还是后续的"解码"阶段;

  • cu_seqlens_qcu_seqlens_k 通常代表Query和Key的累积序列长度,是一个有关序列长度的前缀和数组,在推理时,为了快,我们经常把一个 Batch 里长短不一的句子,这个变量就是一个切分点数组,方便FlashAttention知道哪一段数据属于哪一个用户

  • max_seqlen_q表示新输入数据的最大长度,Maximum Sequence Length of Query

    • Prefill 阶段,算的是所有用户输入的token的最大值
    • Decode 阶段,由于模型已经读完了,现在开始一个字一个字往外蹦。 无论历史有多长,每一轮推理,我们只生成 1 个 新的 Token(Query),所以永远等于1
  • max_seqlen_k表示为了算这个新 Token,我需要回头看多少个历史 Token,这个数字的最大值

    • Prefill 阶段,每个 Token 都要看它自己以及它前面的所有 Token,所以是所有用户输入的token的最大值,同max_seqlen_q
    • Decode 阶段,由于模型已经生成了一些,所以是在不断变化的,是每一个序列输入+生成的token总数的最大值
  • slot_mapping表示每个 token 在 KV Cache 池中的物理存储位置(slot 索引),用于 CUDA kernel 将新计算的 K/V 写入正确的缓存位置。

    • 这个东西比较有意思,我们对输入的所有token都需要存他们的kv cache,假设输入有batch_size个序列,每个序列有若干个token,形状大概是个[batch_size, sequence_size],这里面每个序列的长度可能不一样,而kv cache的形状是[num_blocks, block_size],我们这里想搞一个token到kv cahce存储位置的映射表,如果搞一个二维对二维的map,是不太友好的,所以我们可以直接把他们压成一维,即我们搞一个长度为 [total_num_tokens]的map, [total_num_tokens]是所有Batch的所有Token数,而映射值slot_mapping[i]是从0 batch_size * block_size的一个数字,代表第 i 个 Token 对应的物理显存地址索引就是 slot_mapping[i],由于我们申请内存的时候,本身也就是一条长长的一维东西,所以很适合这样把我们的kv cache直接压成一维

    • 这样做的好处是:底层 CUDA Kernel 在写入时,不需要维护复杂的二维寻址逻辑,省去了 Kernel 内部的乘法和加法运算,直接拿slot_mapping[i]这个整数当指针偏移量往显存里写数据就行了,速度最快。而且可以实现解耦,底层 Kernel 不需要知道 block_size 是 16 还是 32,也不需要知道显存是怎么分块的。 它只知道:"给我一个偏移量,我往里写数据。"

    • 主要是在存数据的时候用

  • block_tables 是显存页表,由于模型在计算 Attention 时,它需要回头看过去所有的历史记录(Key/Value),所以需要一个页表记录每个序列的kv cache存在哪些块里。

    • 形状是 [batch_size, max_num_blocks_per_seq],它是一张二维表格。每一行代表一个输入序列,每一列代表这个序列用到的几个物理块(Block ID)。
    • 主要是在读数据的时候用
  • _CONTEXT = Context() 初始化一个空的全局变量。程序启动时,板子上是空的。

py 复制代码
def get_context():
    return _CONTEXT

读板子:底层的算子通过调用这个函数,拿到当前的配置信息。
*

py 复制代码
def set_context(is_prefill, ...):
    global _CONTEXT
    _CONTEXT = Context(...)

写板子 :在每次 model.forward() 之前,Python 主程序会调用这个函数,把当前这一轮推理的所有元数据(是不是 Prefill?块表在哪里?长度是多少?)全部填好,更新到全局变量里。
*

py 复制代码
def reset_context():
    global _CONTEXT
    _CONTEXT = Context()

擦板子:推理结束后,清空上下文,防止下一次推理误用了旧数据。

相关推荐
程序员鱼皮19 小时前
吴恩达新的免费 AI 课来了,YYDS!我已经学上了
计算机·ai·程序员·编程·ai编程
adouwt19 小时前
给 AI 用的代码索引器-技术篇
ai
俊哥V19 小时前
AI一周事件 · 2026-04-29 至 2026-05-05
人工智能·ai
阿Y加油吧20 小时前
RAG 检索→召回→增强→生成完整流程
ai
动物园猫20 小时前
公共安全打架行为识别数据集分享(适用于YOLO系列深度学习检测任务)
人工智能·深度学习·yolo
AI棒棒牛20 小时前
YOLOv13最新创新改进系列:比闪电还快的医学影像分析!YOLOv13+EMCAD融合实战,改进代码已跑通!cvpr2025最新独家改进!
深度学习·yolo·目标检测·计算机视觉
bst@微胖子20 小时前
PyTorch深度学习框架之基于RNN实现AI歌词生成器
深度学习
赵优秀一一20 小时前
AI入门学习
人工智能·pytorch·深度学习
2zcode20 小时前
原创文档:基于MATLAB深度学习与传统机器学习的脑肿瘤MRI图像分类系统
深度学习·机器学习·分类
这张生成的图像能检测吗20 小时前
(论文速读)基于多模态融合学习的航空发动机叶片损伤检测与测量
人工智能·深度学习·神经网络·计算机视觉·三维测量