utils/context.py
这段代码定义了一个全局上下文管理器(Global Context Manager)。
它的核心作用是**"传纸条"**。在深度学习框架(如 PyTorch)与底层高性能计算核心(CUDA Kernels,比如 FlashAttention 或 PagedAttention)之间,有很多复杂的元数据(比如这句话有多长、显存存在哪里)需要传递。
为了避免在每个函数的参数列表里写上十几个参数(那样代码会极其丑陋且难以维护),这段代码定义了一个全局的"公告板" _CONTEXT。在运行底层算子之前,Python 端先把这些信息写到公告板上,底层的 C++/CUDA 代码直接去读这个板子里的信息。
py
from dataclasses import dataclass
import torch
@dataclass
class Context:
is_prefill: bool = False
cu_seqlens_q: torch.Tensor | None = None
cu_seqlens_k: torch.Tensor | None = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
slot_mapping: torch.Tensor | None = None
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
_CONTEXT = Context()
def get_context():
return _CONTEXT
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
global _CONTEXT
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
def reset_context():
global _CONTEXT
_CONTEXT = Context()
代码通过全局变量 _CONTEXT 维护了一个单例状态,并提供了 get_context 获取当前状态、set_context 更新推理参数以及 reset_context 重置环境的辅助函数,这种设计模式能够让底层的自定义算子在执行时直接获取必要的调度信息,而无需在每一层网络结构中显式地层层传递复杂的张量参数,从而简化了模型调用的接口逻辑。
-
其中
is_prefill标记了当前处于首词生成的"预填充"阶段还是后续的"解码"阶段; -
cu_seqlens_q和cu_seqlens_k通常代表Query和Key的累积序列长度,是一个有关序列长度的前缀和数组,在推理时,为了快,我们经常把一个 Batch 里长短不一的句子,这个变量就是一个切分点数组,方便FlashAttention知道哪一段数据属于哪一个用户 -
max_seqlen_q表示新输入数据的最大长度,Maximum Sequence Length of Query- Prefill 阶段,算的是所有用户输入的token的最大值
- Decode 阶段,由于模型已经读完了,现在开始一个字一个字往外蹦。 无论历史有多长,每一轮推理,我们只生成 1 个 新的 Token(Query),所以永远等于1
-
max_seqlen_k表示为了算这个新 Token,我需要回头看多少个历史 Token,这个数字的最大值- Prefill 阶段,每个 Token 都要看它自己以及它前面的所有 Token,所以是所有用户输入的token的最大值,同
max_seqlen_q - Decode 阶段,由于模型已经生成了一些,所以是在不断变化的,是每一个序列输入+生成的token总数的最大值
- Prefill 阶段,每个 Token 都要看它自己以及它前面的所有 Token,所以是所有用户输入的token的最大值,同
-
slot_mapping表示每个 token 在 KV Cache 池中的物理存储位置(slot 索引),用于 CUDA kernel 将新计算的 K/V 写入正确的缓存位置。-
这个东西比较有意思,我们对输入的所有token都需要存他们的kv cache,假设输入有
batch_size个序列,每个序列有若干个token,形状大概是个[batch_size, sequence_size],这里面每个序列的长度可能不一样,而kv cache的形状是[num_blocks, block_size],我们这里想搞一个token到kv cahce存储位置的映射表,如果搞一个二维对二维的map,是不太友好的,所以我们可以直接把他们压成一维,即我们搞一个长度为[total_num_tokens]的map,[total_num_tokens]是所有Batch的所有Token数,而映射值slot_mapping[i]是从0到batch_size * block_size的一个数字,代表第i个 Token 对应的物理显存地址索引就是slot_mapping[i],由于我们申请内存的时候,本身也就是一条长长的一维东西,所以很适合这样把我们的kv cache直接压成一维 -
这样做的好处是:底层 CUDA Kernel 在写入时,不需要维护复杂的二维寻址逻辑,省去了 Kernel 内部的乘法和加法运算,直接拿
slot_mapping[i]这个整数当指针偏移量往显存里写数据就行了,速度最快。而且可以实现解耦,底层 Kernel 不需要知道block_size是 16 还是 32,也不需要知道显存是怎么分块的。 它只知道:"给我一个偏移量,我往里写数据。" -
主要是在存数据的时候用
-
-
block_tables是显存页表,由于模型在计算 Attention 时,它需要回头看过去所有的历史记录(Key/Value),所以需要一个页表记录每个序列的kv cache存在哪些块里。- 形状是
[batch_size, max_num_blocks_per_seq],它是一张二维表格。每一行代表一个输入序列,每一列代表这个序列用到的几个物理块(Block ID)。 - 主要是在读数据的时候用
- 形状是
-
_CONTEXT = Context()初始化一个空的全局变量。程序启动时,板子上是空的。
py
def get_context():
return _CONTEXT
读板子:底层的算子通过调用这个函数,拿到当前的配置信息。
*
py
def set_context(is_prefill, ...):
global _CONTEXT
_CONTEXT = Context(...)
写板子 :在每次 model.forward() 之前,Python 主程序会调用这个函数,把当前这一轮推理的所有元数据(是不是 Prefill?块表在哪里?长度是多少?)全部填好,更新到全局变量里。
*
py
def reset_context():
global _CONTEXT
_CONTEXT = Context()
擦板子:推理结束后,清空上下文,防止下一次推理误用了旧数据。