MPK(Mirage Persistent Kernel)源码笔记(2)--- 多层结构化图模型
目录
- [MPK(Mirage Persistent Kernel)源码笔记(2)--- 多层结构化图模型](#MPK(Mirage Persistent Kernel)源码笔记(2)--- 多层结构化图模型)
- [0x00 概要](#0x00 概要)
- [0x01 机制](#0x01 机制)
- [1.1 当前问题](#1.1 当前问题)
- [1.2 解决方案](#1.2 解决方案)
- [1.2.1 μGraphs:多层次计算图表示](#1.2.1 μGraphs:多层次计算图表示)
- [1.2.2 归纳式程序合成:优化范式](#1.2.2 归纳式程序合成:优化范式)
- [0x02 多层次计算图表示](#0x02 多层次计算图表示)
- [2.1 概念](#2.1 概念)
- [2.2 层级关系](#2.2 层级关系)
- [2.3 对比](#2.3 对比)
- [2.4 执行关系](#2.4 执行关系)
- [0x03 内核图](#0x03 内核图)
- [3.1 PersistentKernel调用](#3.1 PersistentKernel调用)
- [3.2 Python 代码](#3.2 Python 代码)
- [3.3 桥梁](#3.3 桥梁)
- [3.4 C++ 代码](#3.4 C++ 代码)
- [3.5 KNOperator](#3.5 KNOperator)
- [3.6 生成样例](#3.6 生成样例)
- [0x04 线程块图](#0x04 线程块图)
- [4.1 属性](#4.1 属性)
- [4.1.1 网格尺寸](#4.1.1 网格尺寸)
- [4.1.2 For-loop 尺寸](#4.1.2 For-loop 尺寸)
- [4.2 Python 代码](#4.2 Python 代码)
- [4.3 桥梁](#4.3 桥梁)
- [4.4 C++代码](#4.4 C++代码)
- [4.5 TBOperator](#4.5 TBOperator)
- [4.6 生成样例](#4.6 生成样例)
- [4.6.1 Python代码直接构建](#4.6.1 Python代码直接构建)
- [4.6.2 PersistentKernel 的 layer 方法间接构建](#4.6.2 PersistentKernel 的 layer 方法间接构建)
- [4.6.3 C++代码直接构建](#4.6.3 C++代码直接构建)
- [4.1 属性](#4.1 属性)
- [0x05 线程图](#0x05 线程图)
- [0xFF 参考](#0xFF 参考)
0x00 概要
Mirage 使用 uGraph 来指定在 GPU 上执行张量程序。uGraph 包含多个级别的层次化图,以表示在内核、块和线程级别的计算。下图是GQA对应的μGraphs,显示了一个用于计算GQA的 uGraph。我们用它作为运行示例来解释 uGraph 的关键组成部分。
0x01 机制
1.1 当前问题
LLM 的计算过程通常以计算图的形式表示,其中每个节点对应一个计算算子(如矩阵乘法、注意力机制)或集合通信原语(如 all-reduce),边表示算子间的数据依赖关系。现有系统通常为每个算子启动独立的 GPU 内核。然而,这种"单算子单内核"的执行模型难以实现 pipeline 优化,因为依赖关系是在整个内核的粗粒度层面强制执行的,而非实际数据单元层面。
例如,矩阵乘法(matmul)后接 all-reduce 操作:现有系统中,all-reduce 内核必须等待整个 matmul 内核完成。而实际上,all-reduce 的每个数据分块仅依赖 matmul 输出的局部结果。这种逻辑依赖与实际依赖的错配,严重限制了计算与通信的重叠潜力。下图的右侧展示次优方案 ------ 其引入不必要的数据依赖与全局屏障,导致跨层流水线优化机会受限。
1.2 解决方案
为了解决这一问题,Mirage实现了多层次计算图表示(μGraphs)与归纳式程序合成(Inductive Program Synthesis)。这两大机制协同作用,实现了从宏观调度到微观计算的全链路优化,高效生成GPU程序,显著提升了张量计算的性能。
Mirage 的编译流程清晰且目标明确:
- 输入:来自预定义算子集合的计算图子图(如 GQA 注意力计算子图),确保输入逻辑的规范性与可优化性;
- 核心优化步骤:包含图重写(Graph Rewrite,调整图结构以适配 GPU 架构)、算子融合(Operator Fusion,减少内存访问次数)等,所有优化均基于 μGraphs 的跨层级表示展开;
- 输出:优化后的 CUDA 程序,直接适配 GPU 硬件执行,可直接JIT嵌入pytorch。
1.2.1 μGraphs:多层次计算图表示
MPK 编译器将 LLM 计算图自动转化为细粒度任务图,最大化暴露并行性。该任务图在子内核级别显式捕获依赖关系,实现更激进的跨层流水线优化。具体而言,在 MPK 任务图中(参见上图):
- 任务(矩形表示):代表分配给单个 GPU 流式多处理器(SM)的计算或通信单元。
- 事件(圆形表示):表示任务间的同步点。
- 触发机制:每个任务发出指向触发事件的边,该事件在关联任务全部完成后激活。
- 依赖机制:每个任务接收来自依赖事件的边,表明事件激活后任务立即启动。
任务图使 MPK 能够发掘计算图中无法实现的 pipeline 优化机会。例如,MPK 可以构建优化任务图 ------ 其中每个 all-reduce 任务仅依赖于生成其输入的对应 matmul 任务,从而实现分块执行与计算通信重叠。
除生成优化任务图外,MPK 还通过 Mirage 内核超优化器自动为每个任务生成高性能 CUDA 实现,确保任务在 GPU 流式多处理器(SM)上高效执行。
1.2.2 归纳式程序合成:优化范式
归纳式程序合成是Mirage的另一大核心机制。与传统的演绎式程序合成(如基于规则的重写系统)不同,归纳式程序合成直接从语法出发构造程序,并借助SMT求解器验证构造程序与原程序的等价性。这种方法能够突破传统优化方法的局限,发现将代数变换、调度变换和新自定义内核生成相结合的创新优化路径。
通过归纳式程序合成,Mirage能够自动生成高性能的GPU内核代码,不仅简化了开发流程,还提升了程序的运行效率,使得开发者能够更专注于高层逻辑的设计,而无需深入底层硬件细节。
传统机器学习编译器(如 TVM、TensorRT)采用演绎式程序合成(Deductive Program Synthesis,又称 Term Rewrite) :从原始程序出发,通过等价重写规则(如图模式匹配、循环调度原语)逐步变换,始终在 "程序等价类" 内搜索更优实现 ------ 这种方式依赖手工设计规则,难以突破现有等价类的性能上限。
Mirage 则采用归纳式程序合成:不依赖原始程序的逐步变换,而是直接基于算子语法构造全新候选程序,再通过 "μGraphs 语义校验 + 概率等价验证"(如有限域随机测试)确认候选程序与原始程序的功能一致性。这种范式无需受限于等价重写规则,可探索更灵活的跨层级优化方案(如 Kernel-Graph 合成算子与 Block-Graph 共享内存复用的协同),同时通过概率验证保障正确性。
下图是Mirage找出的最佳μGraphs。
0x02 多层次计算图表示
Mirage 实现了多层次计算图表示(μGraphs),通过 kernel-graph(内核图)、block-graph(块图)和 thread-graph(线程图)这三层结构化图模型,精确映射 GPU 程序从内核到线程的执行逻辑与存储层级。这种三层结构与 CUDA 程序的执行层级及 GPU 的存储体系紧密对应,每层均清晰定义了 "算子类型 --- 张量存储 --- 核心功能" 的关联关系。
2.1 概念
三层的概念如下:
- kernel-graph(内核图):属于高层次抽象,用于表示整个计算图(即完整的计算任务),包含粗粒度的高层操作(如完整的矩阵乘法、规约运算等)与对应数据。该层负责全局调度,重点关注数据流与任务间的依赖关系,对应 GPU 的全局内存,主要处理宏观层面的任务分配与协同。其包含的算子(举例)类型有:
- 高层操作:KN_INPUT_OP(输入算子)、KN_OUTPUT_OP(输出算子)、KN_MATMUL_OP(矩阵乘法算子);
- 数学操作:KN_EXP_OP(指数运算算子)、KN_ADD_OP(加法算子)、KN_MUL_OP(乘法算子);
- 规约操作:KN_REDUCTION_0_OP(零阶规约算子)等;
- 自定义操作:KN_CUSTOMIZED_OP(自定义算子)等。
- block-graph(块图):属于中等层次抽象,嵌套在 KN_CUSTOMIZED_OP(自定义内核算子)中,定义 threadblock(线程块)级别的计算逻辑。该层包含细粒度操作,负责管理线程块级别的并行计算,重点关注内存访问模式、循环结构等中观细节,对应 GPU 的共享内存,核心目标是优化中观层面的资源利用与数据共享效率。其包含的算子类型(举例)有:
- 输入操作:TB_INPUT_OP(线程块输入算子);
- 内存操作:TB_MATMUL_OP(线程块矩阵乘法算子)、TB_EXP_OP(线程块指数运算算子);
- 特殊操作:TB_FORLOOP_ACCUM_NO_RED_OP(线程块循环累加无规约算子)、TB_RMS_NORM_OP(线程块 RMS 归一化算子)。
- thread-graph(线程图):在 block-graph 的具体操作中体现,定义线程级别的执行细节。该层专注于线程级别的微观计算逻辑,对应 GPU 的寄存器,核心作用是确保每个线程的高效执行,最大化单线程的计算吞吐量。
这种三层结构支持系统在不同抽象层级开展针对性优化:
- 在 kernel-graph 层,主要进行全局任务调度与数据流优化,明确整体计算流程与资源分配方向;
- 在 block-graph 层,侧重线程块级别的并行策略优化,提升中观层面的并行效率与数据共享能力;
- 在 thread-graph 层,聚焦具体的内存访问模式优化与计算指令调度,确保微观执行的高效性。
若用通俗语言概括三层结构的分工:kernel-graph 决定 "要做什么"(明确整体计算任务与目标),block-graph 决定 "该怎么做"(规划线程块级的执行方案),thread-graph 负责 "具体执行"(完成线程级的微观计算)。
这种从宏观到微观的层次化设计,使 μGraphs 能够实现从全局调度到局部执行的全链路优化,有效减少计算冗余与资源浪费,确保 GPU 计算资源的高效利用。
2.2 层级关系
三级图结构的关系如下图所示。
muGraph(Kernel Graph)
│
├────► KNOperator(各种标准操作)
│
│
└────► KNCustomizeOp(自定义操作)
│
└───► block-graph(Threadblock Graph)
│
├────► TBOperator(各种线程块操作)
│
└────► TBInputOp(连接到muGraph的张量)
│
└───► thread-level execution(线程级执行)
2.3 对比
三层的对比如下。
| 计算图层级 | 对应 CUDA 执行层级 | 张量存储位置 | 算子类型与功能 | 核心属性 / 逻辑 |
|---|---|---|---|---|
| Kernel-Graph | 整个 GPU 内核(多流处理器 SM 协同) | 设备全局内存(Device DRAM) | 1. 预定义算子 :直接调用厂商库内核(如 cuBLAS 的 GEMM 矩阵乘、cuDNN 的卷积); 2. 合成算子:需通过更低层级的 Block-Graph 描述,承载算子融合、自定义算法等复杂逻辑 | 无额外属性,核心是 "调度多 SM 协同",通过预定义算子复用成熟库性能,合成算子支持灵活优化 |
| Block-Graph | 单个流处理器 SM(线程块协作) | 共享内存(Shared Memory) | 1. 预定义算子 :调用 CUTLASS、ThunderKittens 等库的共享内存操作(如块内矩阵乘、累加); 2. 合成算子:由 Thread-Graph 描述,实现线程块内细粒度计算 | 1. 并行切分属性 :imap(输入分块,映射 Grid 维度到输入张量维度)、omap(输出拼接,映射 Grid 维度到输出张量维度)、fmap(循环迭代,映射 For-Loop 维度到数据迭代器 / 累加器维度); 2. 执行逻辑:支持线程块循环迭代,通过共享内存复用与 "计算 - 访存重叠",将全局内存读写延迟隐藏在计算过程中 |
| Thread-Graph | 单个线程(寄存器操作) | 线程私有寄存器(Register File) | 仅含预定义算子,描述单个线程内的寄存器级流水操作(如 load 数据→元素级计算→store 结果),支持循环迭代与寄存器累加;默认通过 "规则化融合" 快速生成,避免细粒度层级的冗余搜索 | 核心是 "单线程高效流水",通过寄存器操作最小化内存访问,提升计算密度 |
2.4 执行关系
persistent_kernel.py是 Persistent Kernel的Python接口,本质是Python到CUDA持久化内核系统的桥梁,允许用户用python定义复杂的计算图,然后在GPU上高效执行。
persistent_kernel.py与三层计算图的关系如下:
- Persistent Kernel 创建并管理 Kernel Graph
- Kernel Graph 通过 KN_CUSTOMIZED_OP 包含多个 Block Graph
- 每个 Block Graph 定义线程块内的操作序列
- Kernel Graph 转换为 Task Graph 用于执行
- Task Execution Engine 在 Persistent Kernel 中执行任务
- Event System 管理任务间的依赖和同步
- Thread Graph 在实际GPU线程中执行具体操作
0x03 内核图
每个张量程序对应一个内核图,其中每个节点代表在整個 GPU 上运行的内核,每条边是内核之间共享的张量。内核图中的所有张量都存储在 GPU 设备内存中,因为不同的内核不能在寄存器文件或共享内存中共享数据。内核图中的每个节点都可以是现有内核库(如 cuDNN 的卷积和 cuBLAS 的矩阵乘法)支持的预定义内核操作符。此外,为了启用细粒度的内核间优化(如内核融合),内核图中的节点也可以是图定义的内核操作符,其语义和行为由较低级别的(即块)图定义。下图中的两个内核操作符都是图定义的操作符,每个都由块图指定。
3.1 PersistentKernel调用
在PersistentKernel内部,kn_graph负责实际的计算图构建。
self.kn_graph = KNGraph(CyKNGraph(disable_fingerprint=True))
每个attach_input和new_tensor调用都会在kn_graph中创建张量节点。每个layer调用也会在kn_graph中添加相应的计算节点。最后compile()调用self.kn_graph.generate_task_graph生成任务图。
3.2 Python 代码
内核图在Python中的类是KNGraph。KNGraph用于构建和管理内核计算图。比如,new_input会创建新的输入变量。attach_torch_tensor管理PyTorch变量。attach_cuda_tensor关联CUDA变量。compile会生成最终的执行代码。
KNGraph的特点如下:
-
Kernel graph的节点是:
- 预定义算子(pre-defined operator),比如cuBLAS GEMM、cuDNN Conv
- 合成算子(graph-defined operator),用更低一层的block graph描述,可承载fusion/新算法。
-
Kernel graph的边是:位于全局内存(Device DRAM)的Tensor。
KNGraph 代码举例如下:
python
class KNGraph:
def __init__(self, graph):
self.cygraph = graph
self._is_compiled = False
self.run = None
self._valid_cuda_kernels = False
self._cached_results = None
self.visualizer = None
self.backend = "cuda"
def new_input(
self, dims: tuple, strides: tuple = None, dtype: dtype = float16
) -> DTensor:
# use the default strided layout if strides = None
if strides is None:
total_elements = 1
strides = []
for d in reversed(dims):
strides.append(total_elements)
total_elements *= d
strides = reversed(strides)
return self.cygraph.new_input(dims, tuple(strides), dtype)
def compile(self, async_=False, **kwargs):
if self._is_compiled:
return self._cached_results
input_tensors = kwargs.get("inputs", [])
input_strides = []
for i in range(len(dtensors)):
dims, strides = self.cygraph.get_input_dtensor_shape_and_stride(dtensors[i])
input_strides.append(strides)
target_cc = kwargs.get(
"target_cc",
torch.cuda.get_device_properties(0).major * 10
+ torch.cuda.get_device_properties(0).minor,
)
num_warp_groups = kwargs.get("num_warp_groups", 2)
pipeline_stages = kwargs.get("pipeline_stages", 2)
enable_online_softmax = kwargs.get("enable_online_softmax", False)
result = generate_cuda_program(
self.cygraph,
target_cc=target_cc,
input_strides=input_strides,
num_warp_groups=num_warp_groups,
pipeline_stages=pipeline_stages,
profiling=profiling,
enable_online_softmax=enable_online_softmax,
)
if result["max_smem_size"] > get_shared_memory_capacity(target_cc):
self._is_compiled = True
self._valid_cuda_kernels = False
self._error_message = "shared memory usage exceed limit"
if async_:
return Handle([], None)
else:
return None
MIRAGE_ROOT, INCLUDE_PATH, DEPS_PATH = get_key_paths()
tempdir_obj = tempfile.TemporaryDirectory()
tempdir = tempdir_obj.name
saved_addr = ""
file_id = kwargs.get("file_id", -1)
if file_id != -1:
print(f"file_id: {file_id}")
saved_addr = f"./generated_codes/{file_id}/"
FILE_NAME = os.path.join(tempdir, "test.cu")
so_path = os.path.join(tempdir, "test.cpython-38-x86_64-linux-gnu.so")
with open(FILE_NAME, "w") as f:
f.write(result["code"] + HARD_CODE)
if saved_addr != "":
print(f"saved_addr: {saved_addr}")
os.makedirs(saved_addr, exist_ok=True)
with open(saved_addr + "test" + str(file_id) + ".cu", "w") as f:
f.write(result["code"] + HARD_CODE)
cc = shutil.which("nvcc")
# This function was renamed and made public in Python 3.10
if hasattr(sysconfig, "get_default_scheme"):
scheme = sysconfig.get_default_scheme()
else:
scheme = sysconfig._get_default_scheme()
if scheme == "posix_local":
scheme = "posix_prefix"
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
cc_cmd = get_cc_cmd(
target_cc,
cc,
FILE_NAME,
py_include_dir,
INCLUDE_PATH,
DEPS_PATH,
so_path,
profiling,
)
def remain_op():
import importlib.util
try:
spec = importlib.util.spec_from_file_location(
"__mirage_launcher", so_path
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.run = getattr(mod, "launch")
self._is_compiled = True
self._valid_cuda_kernels = True
self._cached_results = result
self._error_message = "No error"
tempdir_obj.cleanup()
return self._cached_results
except ImportError:
self._is_compiled = True
self._valid_cuda_kernels = False
self._cached_results = None
self._error_message = "CUDA compilation error"
return None
if async_:
if global_config.bypass_compile_errors:
ret = subprocess.Popen(
cc_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT
)
else:
ret = subprocess.Popen(cc_cmd)
return Handle([ret], remain_op)
else:
ret = subprocess.check_call(cc_cmd)
return remain_op()
3.3 桥梁
PersistentKernel 中,通过如下方式进行设置 Kernel Graph。
python
self.kn_graph = KNGraph(CyKNGraph(disable_fingerprint=True))
在python\mirage_cython\core.pyx 文件中,CyKNGraph 中有定义 CppKNGraph。
python
cdef class CyKNGraph:
cdef CppKNGraph *p_kgraph #Hold a CppKNGraph instance
def __cinit__(self, graph = None, bool disable_fingerprint = False):
cdef unsigned long long ptr
cdef dim3 c_gpu_dim
if graph is None:
c_gpu_dim.x = 1
c_gpu_dim.y = 1
c_gpu_dim.z = 1
self.p_kgraph = new CppKNGraph(c_gpu_dim, disable_fingerprint)
else:
ptr = ctypes.cast(graph, ctypes.c_void_p).value
self.p_kgraph = <CppKNGraph*>(ptr)
在 python\mirage_cython\CCore.pxd 文件中,指明 CppKNGraph 对应了 "mirage::kernel::Graph",这便是C++代码中,Kernel Graph 的实现。
python
cdef cppclass CppKNGraph "mirage::kernel::Graph":
CppKNGraph(dim3 gpu_dim, bool disable_fingerprint)
CppDTensor* new_input_ptr(vector[int] dims,
vector[size_t] strides,
DataType data_type,
DmemLayout layout)
void mark_output(const CppDTensor* A, vector[size_t] strides)
CppDTensor* matmul(const CppDTensor* A, const CppDTensor* B)
CppDTensor* reduction(const CppDTensor* input, int dim, int size)
CppDTensor* rms_norm(const CppDTensor* input, vector[int])
CppDTensor* exp(const CppDTensor* input)
CppDTensor* silu(const CppDTensor* input)
CppDTensor* gelu(const CppDTensor* input)
CppDTensor* relu(const CppDTensor* input)
CppDTensor* clamp(const CppDTensor* input, float min_val, float max_val)
CppDTensor* sqrt(const CppDTensor* input)
CppDTensor* square(const CppDTensor* input)
CppDTensor* add(const CppDTensor* op1, const CppDTensor* op2)
CppDTensor* mul(const CppDTensor* op1, const CppDTensor* op2)
CppDTensor* div(const CppDTensor* op1, const CppDTensor* op2)
CppDTensor* pow(const CppDTensor* op1, const CppDTensor* op2)
int customized(vector[const CppDTensor*] inputs,
CppDTensor** outputs,
CppTBGraph* bgraph)
int get_num_input_dtensors()
int get_num_output_dtensors()
int get_input_dtensors(CppDTensor** cinputs)
int get_input_dtensor_shape_and_stride(const CppDTensor *input, int *strides, int *dims)
void generate_triton_program(const char *filepath)
void generate_cuda_program(const char *filepath)
size_t get_owner_independent_hash() const
# Persistent kernel functions
void attach_torch_tensor(const CppDTensor *input,
void *torch_data_ptr,
const char *name)
void attach_cuda_tensor(const CppDTensor *input,
const char *name)
void attach_nvshmem_tensor(const CppDTensor *input,
const char *name)
CppDTensor* fuse_tensors(vector[const CppDTensor*] inputs,
int fused_dim,
int num_groups,
const char *name)
void register_task(const char *task_type,
vector[int] params)
TaskGraphResult generate_task_graph(int num_gpus, int my_gpu_id)
vector[CppKNOperator*] operators
3.4 C++ 代码
muGraph在c++代码中体现为mirage::kernel::Graph类,这是最高层次的计算图。
c++
namespace mirage {
namespace kernel {
class Graph {
private:
struct pair_hash {
size_t operator()(std::pair<int, int> const &p) const;
};
public:
Graph(dim3 gpu_dim = {1, 1, 1}, bool disable_fingerprint = false);
~Graph();
Graph(Graph const &) = delete;
Graph &operator=(Graph const &) = delete;
// input operator
DTensor new_input(std::vector<int> const &dims,
std::vector<size_t> const &strides,
mirage::type::DataType data_type,
mirage::layout::DmemLayout layout);
DTensor elementunary(DTensor const &input,
mirage::type::KNOperatorType _type);
// 忽略其它函数
public:
std::vector<mirage::kernel::KNOperator *> operators; // 操作符列表
dim3 gpu_dim;
off_t dmem_data_offset, dmem_fp_offset;
std::vector<std::pair<off_t, size_t>> allocated_data_tensors,
allocated_fp_tensors;
// Fields for persistent kernels
std::map<mirage::type::GuidType, mirage::runtime::IODesc> io_config;
std::unordered_map<mirage::kernel::KNOperator const *,
std::tuple<int, int, runtime::TaskType, int>>
task_config;
using OpType = KNOperator;
using TensorType = DTensor;
};
mirage::kernel::Graph的主要特征是:
- 操作符类型:使用KNOperatorType 枚举定义操作类型。
- 张量表示:使用DTensor(Device Tensor)表示数据。
- 操作节点:包括输入(KN_INPUT_OP),输出(KN_OUTPUT_OP),矩阵乘法(KN_MATMUL_OP)等。
mirage::kernel::Graph的成员函数以 elementunar 为例,代码如下:
c++
DTensor Graph::elementunary(DTensor const &input,
mirage::type::KNOperatorType type) {
KNOperator *op = create_elementunary_op(input, type);
assert(op != nullptr);
operators.push_back(op);
assert(op->output_tensors.size() == 1);
DTensor output = op->output_tensors[0];
return output;
}
3.5 KNOperator
Graph包含多个KNOperator对象。
KNOperator是内核级别的操作符基类,用于表示计算图中的节点。作为计算图中每个操作的基本单元,可以维护输入和输出张量的信息,提供操作类型表示。而且,通过输入输出张量的连接关系,可以建立操作间的依赖关系,为后续的任务调度和事件管理提供基础。
在runtime.cc中,系统通过遍历Graph中的operators来生成任务图。
c++
class KNOperator {
public:
KNOperator(Graph *graph, mirage::type::KNOperatorType _type);
KNOperator(Graph *graph,
mirage::type::KNOperatorType _type,
DTensor const &input1);
KNOperator(Graph *graph,
mirage::type::KNOperatorType _type,
DTensor const &input1,
DTensor const &input2);
KNOperator(Graph *graph,
mirage::type::KNOperatorType _type,
std::vector<DTensor> const &inputs);
int get_input_dtensors(DTensor **inputs);
int get_output_dtensors(DTensor **inputs);
virtual ~KNOperator();
virtual bool fingerprint(void) = 0;
virtual operator json() const = 0; // 将操作序列转换为JSON格式
// hash related functions
virtual size_t get_owner_independent_hash() const;
public:
Graph *kgraph; // 通过该指针维护与所属计算图的关联
mirage::type::KNOperatorType op_type; // 标识操作类型
std::vector<DTensor> input_tensors; // 存储操作的输入张量
std::vector<DTensor> output_tensors; // 存储操作的输出张量
};
KNCustomizedOp,KNInputOp,KNOutputOp是KNOperator的派生类。KNOperator的派生类举例。
c++
class KNInputOp : public KNOperator {
public:
KNInputOp(Graph *_graph,
std::vector<int> const &dims,
std::vector<size_t> const &strides,
mirage::type::DataType data_type,
mirage::layout::DmemLayout layout,
int3 input_map = {-1, -1, -1});
~KNInputOp();
bool fingerprint(void);
operator json() const override;
public:
std::vector<size_t> input_strides;
int3 input_map;
};
class KNOutputOp : public KNOperator {
public:
KNOutputOp(Graph *_graph,
DTensor const &A,
std::vector<size_t> const &strides,
int3 output_map = {-1, -1, -1});
~KNOutputOp();
bool fingerprint(void);
operator json() const override;
public:
std::vector<size_t> output_strides;
int3 output_map;
};
class KNCustomizedOp : public mirage::kernel::KNOperator {
public:
KNCustomizedOp(Graph *_kgraph,
std::vector<DTensor> const &inputs,
mirage::threadblock::Graph const &_graph);
virtual ~KNCustomizedOp();
bool fingerprint(void);
size_t get_owner_independent_hash() const override;
operator json() const override;
public:
mirage::threadblock::Graph bgraph;
void get_bgraph(mirage::threadblock::Graph **bgraph);
};
KNOperatorType 的全量为:
c++
enum KNOperatorType {
KN_UNKOWN = 1000,
KN_INPUT_OP = 1001,
KN_OUTPUT_OP = 1002,
KN_MATMUL_OP = 1003,
// ElementUnary
KN_EXP_OP = 1100,
KN_SQUARE_OP = 1101,
KN_SQRT_OP = 1102,
KN_MUL_SCALAR_OP = 1103,
KN_SILU_OP = 1104,
KN_SIGMOID_OP = 1105,
KN_GELU_OP = 1106,
// non-lax elementunary ops
KN_RELU_OP = 1150,
KN_CLAMP_OP = 1151,
KN_LOG_OP = 1160,
// ElementBinary
KN_ADD_OP = 1200,
KN_MUL_OP = 1201,
KN_DIV_OP = 1202,
KN_POW_OP = 1203,
// Reduction & Normalization
KN_REDUCTION_0_OP = 1300,
KN_REDUCTION_1_OP = 1301,
KN_REDUCTION_2_OP = 1302,
KN_RMS_NORM_OP = 1350,
// Concat & Split
KN_CONCAT_FIRST_OP_ID = 1400,
KN_CONCAT_0_OP = 1400,
KN_CONCAT_1_OP = 1401,
KN_CONCAT_2_OP = 1402,
KN_CONCAT_LAST_OP_ID = 1409,
KN_SPLIT_FIRST_OP_ID = 1420,
KN_SPLIT_0_OP = 1420,
KN_SPLIT_1_OP = 1421,
KN_SPLIT_2_OP = 1422,
KN_CHUNK_0_OP = 1423,
KN_CHUNK_1_OP = 1424,
KN_CHUNK_2_OP = 1425,
KN_SPLIT_LAST_OP_ID = 1429,
// Communication
KN_ALLREDUCE_OP = 1900,
KN_CUSTOMIZED_OP = 1999,
};
3.6 生成样例
Kernel & block图的生成逻辑如下:
- 从输入节点出发,以x,y,z输入张量为起点,初始化一个空前缀。
- 迭代增长,枚举算子来构造新节点,每次枚举一个算子加入(枚举matmul、add、exp...,合成算子),当枚举到合成算子,马上进入block graph的synthesis,每次扩张会检查合法性:形状、显存/SMEM容量、路径约束。
- 抽象剪枝,计算当前前缀的抽象表达式E,当和canonical form E0不一致时剪枝,生成结束后会得到没有thread graph的kernel/block图候选集合。
下面代码中给出了kernel graph和block graph的生成样例。
python
import mirage as mi
def new_kernel_graph():
kgraph = core.CyKNGraph()
return KNGraph(kgraph)
def get_rms_linear():
graph = mi.new_kernel_graph() # kernel graph
X = graph.new_input(dims=(num_tokens, 4096), dtype=mi.float16)
W = graph.new_input(dims=(4096, n_local_heads * head_dim + 2 * n_local_kv_heads * head_dim), dtype=mi.float16)
# block graph
tb_graph = mi.new_threadblock_graph(grid_dim=(384,1,1), block_dim=(128,1,1), forloop_range=32, reduction_dimx=64)
tX = tb_graph.new_input(dtensor=X, input_map=(-1, -1, -1), forloop_dim=1)
tW = tb_graph.new_input(dtensor=W, input_map=(1, -1, -1), forloop_dim=0)
tM = tb_graph.matmul(tX, tW)
tAccX = tb_graph.forloop_accum(tX, "rms")
tAccM = tb_graph.forloop_accum(tM)
tO = tb_graph.div(tAccM, tAccX)
tb_graph.new_output(stensor=tO, output_map=(1, -1, -1))
O = graph.customized([X, W], tb_graph)
return graph, O
def mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels):
func = kernels[0]
outputs = func(inputs=[X, Wqkv])
Xqkv = outputs[0]
Xq = Xqkv[:, : (n_local_heads * head_dim)]
output_shape = Xq.shape
Xkv = Xqkv[:, (n_local_heads * head_dim) :]
Xk, Xv = Xkv.chunk(2, 1)
Xq = Xq.view(Xq.shape[0], n_local_heads, head_dim)
Xk = Xk.view(Xk.shape[0], n_local_kv_heads, head_dim)
Xv = Xv.view(Xv.shape[0], n_local_kv_heads, head_dim)
output = flashinfer.single_prefill_with_kv_cache(Xq, Kcache, Vcache, causal=True)
output = torch.matmul(output.reshape(output_shape), Wo)
X = output
func = kernels[1]
outputs = func(inputs=[X, W13])
X13 = outputs[0]
X1, X3 = X13.chunk(2, -1)
output = torch.matmul(X1, W2)
return output
if __name__ == "__main__":
X = torch.randn(num_tokens, 4096, dtype=torch.float16, device='cuda:0')
Wqkv = torch.randn(4096, n_local_heads * head_dim + 2 * n_local_kv_heads * head_dim, dtype=torch.float16, device='cuda:0')
Wo = torch.randn(n_local_heads * head_dim, 4096, dtype=torch.float16, device='cuda:0')
W13 = torch.randn(4096, intermediate_size * 2, dtype=torch.float16, device='cuda:0')
W2 = torch.rand(14336, 4096, dtype=torch.float16, device='cuda:0')
Kcache = torch.rand(num_kv_tokens, n_local_kv_heads, head_dim, dtype=torch.float16, device='cuda:0')
Vcache = torch.rand(num_kv_tokens, n_local_kv_heads, head_dim, dtype=torch.float16, device='cuda:0')
k1 = get_rms_linear() # 此处生成计算图
k2 = get_rms_linear2() # 此处生成计算图
kernels = [k1, k2]
for _ in range(16):
mirage_llama(X, Wqkv, Wo, W13, W2, Kcache, Vcache, kernels)
torch.cuda.synchronize()
from_json()函数也会生成。以下是创建操作。g是内核图。
c++
void from_json(json const &j, Graph &g) {
switch (op_type) {
case type::KNOperatorType::KN_INPUT_OP: {
int num_dim, dim[mirage::config::MAX_TENSOR_DIMS];
type::DataType data_type;
layout::DmemLayout layout;
std::vector<size_t> input_strides;
size_t guidO;
jop.at("output_tensors")[0].at("num_dims").get_to(num_dim);
jop.at("output_tensors")[0].at("dim").get_to(dim);
jop.at("input_strides").get_to(input_strides);
jop.at("output_tensors")[0].at("data_type").get_to(data_type);
jop.at("output_tensors")[0].at("layout").get_to(layout);
jop.at("output_tensors")[0].at("guid").get_to(guidO);
std::vector<int> dims = to_vector(num_dim, dim);
// 调用KNGraph的函数
DTensor const &output =
g.new_input(dims, input_strides, data_type, layout);
guid_mapping[output.guid] = guidO;
break;
}
new_input是KNGraph的函数。
c++
class KNGraph:
def new_input(
self, dims: tuple, strides: tuple = None, dtype: dtype = float16
) -> DTensor:
# use the default strided layout if strides = None
if strides is None:
total_elements = 1
strides = []
for d in reversed(dims):
strides.append(total_elements)
total_elements *= d
strides = reversed(strides)
return self.cygraph.new_input(dims, tuple(strides), dtype)
最终到CyTBGraph
c++
cdef class CyTBGraph:
cdef CppTBGraph *p_bgraph #Hold a CppTBGraph instance
def __cinit__(self, tuple grid_dim = (), tuple block_dim = (), int forloop_range = -1, int dimx = -1, bgraph = None):
cdef unsigned long long ptr
cdef dim3 c_grid_dim
cdef dim3 c_block_dim
if bgraph is None:
c_grid_dim.x = grid_dim[0]
c_grid_dim.y = grid_dim[1]
c_grid_dim.z = grid_dim[2]
c_block_dim.x = block_dim[0]
c_block_dim.y = block_dim[1]
c_block_dim.z = block_dim[2]
self.p_bgraph = new CppTBGraph(c_grid_dim, c_block_dim, forloop_range, dimx)
else:
ptr = ctypes.cast(bgraph, ctypes.c_void_p).value
if isinstance(bgraph, int):
self.p_bgraph = <CppTBGraph*>(ptr)
elif isinstance(bgraph, ctypes.c_void_p):
self.p_bgraph = <CppTBGraph*>(ptr)
def new_input(self, DTensor dtensor, tuple input_map, int forloop_dim, bool store_in_dmem = False):
cdef int3 c_input_map
c_input_map.x = input_map[0]
c_input_map.y = input_map[1]
c_input_map.z = input_map[2]
cdef CppDTensor* dtensor_cptr = NULL
if dtensor is not None:
dtensor_cptr = dtensor.c_ptr
cdef CppSTensor* ptr = self.p_bgraph.new_input(dtensor_cptr, c_input_map, forloop_dim, SmemRowMajor, store_in_dmem)
t = ctypes.cast(<unsigned long long>ptr, ctypes.c_void_p)
return STensor(t)
def new_output(self, STensor stensor, tuple output_map, int forloop_dim, str epilogue = None):
cdef int3 c_output_map
c_output_map.x = output_map[0]
c_output_map.y = output_map[1]
c_output_map.z = output_map[2]
epilogue_type = string_to_tbepilogue(epilogue)
self.p_bgraph.new_output(stensor.c_ptr, c_output_map, forloop_dim, epilogue_type)
def matmul(self, STensor A, STensor B):
cdef CppSTensor* ptr = self.p_bgraph.matmul(A.c_ptr, B.c_ptr)
t = ctypes.cast(<unsigned long long>ptr, ctypes.c_void_p)
return STensor(t)
def exp(self, STensor A):
cdef CppSTensor* ptr = self.p_bgraph.exp(A.c_ptr)
t = ctypes.cast(<unsigned long long>ptr, ctypes.c_void_p)
return STensor(t)
def silu(self, STensor A):
cdef CppSTensor* ptr = self.p_bgraph.silu(A.c_ptr)
t = ctypes.cast(<unsigned long long>ptr, ctypes.c_void_p)
return STensor(t)
0x04 线程块图
kernel graph 管理整体计算流,block_graph 管理线程块级别的并行计算,从而实现高效的 GPU 执行。
块图指定与线程块相关的计算,其中每个节点表示一个块操作符,指定线程块内的计算,每条边是线程块操作符之间共享的张量。Mirage 将块图中的所有中间张量保存在 GPU 共享内存中,有两个考虑。首先,GPU 共享内存提供的带宽远高于设备内存,这种设计允许 Mirage 通过最大限度地将中间结果保存在共享内存中来减少设备内存访问。其次,对于大小超过共享内存容量且必须存储在设备内存中的张量,Mirage 使用这些张量将计算分割成多个块图,每个块图仅包含共享内存中的张量。这种分离不会引入对设备内存的额外访问。
4.1 属性
每个块图还与一些属性相关联,以指定其执行。
4.1.1 网格尺寸
内核中的所有线程块都由最多 3 维的网格组织,标识为 x、y 和 z。相应地,块图与最多三个网格尺寸相关联,指定沿 x、y 和 z 尺寸的块数。上图中的两个块图启动了 80(即 8 × 10)和 64(即 8 × 8)个块。
首先,对于图定义的内核操作符(例如内核图中的 Q、K 和 V)的每个输入张量,相关的块图包含一个 imap,它指定如何将输入张量划分为各个块的子张量。对于每个网格尺寸(即 x、y 或 z),imap 将其映射到(1)输入张量的数据维度或(2)特殊的副本维度 𝜙。对于(1),映射的数据维度在网格尺寸上的块之间均匀划分。对于(2),输入张量在这些线程块之间复制。
其次,对于块图的每个输出张量,块图包括一个 omap,它指定所有块的输出如何连接以构建内核操作符的最终输出。在 omap 中,每个网格尺寸必须映射到输出张量的数据维度,因为不同的块必须保存到设备内存中的不相交张量。对于上图中形状为 [h=1, s=8, d=64] 的 B,其 omap={x<->h, y<->d} 表示具有相同 x 索引的块沿 h 维度连接,具有相同 y 索引的块沿 d 维度连接,从而得到形状为 [h=8, s=8, d=640] 的张量 B。
4.1.2 For-loop 尺寸
为了适应大输入张量在共享内存中并允许缓存重用,与每个块图相关的第二个属性是 for-loop 尺寸,它们共同指定块图执行多少次以完成内核。相应地,每个输入张量首先被发送到输入迭代器,该迭代器从设备内存加载张量的一部分到共享内存。每个输入迭代器都与 fmap 关联,以指定每次迭代加载输入张量的哪一部分。形式上,fmap 将每个 for-loop 维度映射到(1)输入张量的数据维度或(2)副本维度 𝜙。与 imap 的语义类似,输入张量沿该维度均匀划分为(1)并在(2)中复制。
此外,块图包含输出累加器,以在共享内存中跨迭代累积其输出,并将最终结果保存回设备内存。与输入迭代器类似,输出累加器也与 fmap 关联,以指定不同迭代的输出张量如何组合以产生最终结果。具体来说,fmap 将每个 for-loop 维度映射到数据维度,这导致输出沿该维度连接,或副本维度 𝜙,这导致输出在共享内存中累积。
4.2 Python 代码
TBGraph 是块图的实现。每个自定义操作(embedding,attention,MLP)都会创建对应的thread block,用于定义该级别的具体执行方式,这些thread block 被编译为CUDA 内核,在GPU上以warp和线程方式并行执行。
TBGraph的特点如下:
-
节点分类如下:
- 预定义算子,对应CUTLASS或者ThunderKittens等CUDA组件库中封装好的共享内存上的一些操作(例如MatMul、Mul、Accum等block ops)
- 合成算子,包含一个thread graph
-
边的特点是:
- Tensor,SEME tensor,所有暂存tensor默认放在共享内存,减少DRAM访问
python
class TBGraph:
def __init__(self, graph):
self.cygraph = graph
def new_input(
self,
dtensor: DTensor,
input_map: tuple,
forloop_dim: int,
store_in_dmem: bool = False,
):
return self.cygraph.new_input(dtensor, input_map, forloop_dim, store_in_dmem)
def new_output(self, stensor: STensor, output_map: tuple, forloop_dim: int = -1):
return self.cygraph.new_output(stensor, output_map, forloop_dim)
def matmul(self, A: STensor, B: STensor):
return self.cygraph.matmul(A, B)
def exp(self, A: STensor):
return self.cygraph.exp(A)
def silu(self, A: STensor):
return self.cygraph.silu(A)
def gelu(self, A: STensor):
return self.cygraph.gelu(A)
def relu(self, A: STensor):
return self.cygraph.relu(A)
def clamp(self, A: STensor, min_val: float, max_val: float):
return self.cygraph.clamp(A, min_val, max_val)
def square(self, A: STensor):
return self.cygraph.square(A)
def sqrt(self, A: STensor):
return self.cygraph.sqrt(A)
def mul_scalar(self, A: STensor, scalar: float):
return self.cygraph.mul_scalar(A, scalar)
def add(self, A: STensor, B: STensor):
return self.cygraph.add(A, B)
def mul(self, A: STensor, B: STensor):
return self.cygraph.mul(A, B)
def div(self, A: STensor, B: STensor):
return self.cygraph.div(A, B)
def sub(self, A: STensor, B: STensor):
return self.cygraph.sub(A, B)
def reduction(self, A: STensor, dim: int):
return self.cygraph.reduction(A, dim)
def reduction_max(self, A: STensor, dim: int):
return self.cygraph.reduction_max(A, dim)
def rms_norm(self, A: STensor):
return self.cygraph.rms_norm(A)
def concat(self, A: STensor, B: STensor, dim: int):
return self.cygraph.concat(A, B, dim)
def forloop_accum(self, A: STensor, acc: str = None):
return self.cygraph.forloop_accum(A, acc)
def forloop_accum_rescale(self, A: STensor, B: STensor, acc: str = None):
return self.cygraph.forloop_accum_rescale(A, B, acc)
def forloop_accum_max(self, A: STensor):
return self.cygraph.forloop_accum_max(A)
TBGraph 构造函数传参 graph 是 CyTBGraph 类型。因此,TBGraph 的所有操作都转交给 CyTBGraph 进行处理。
python
TBGraph(CyTBGraph(grid_dim, block_dim, 1, 64))
生成时候TBGraph,传入
python
grid_dim=(X,Y,Z) // 线程块网格维度
block_dim=(128,1,1) // 线程块内线程维度
这表明每个thread block包含128个线程,按一维方式组织。
grid_dim和block_dim这两个参数被CyTBGraph使用。
4.3 桥梁
new_threadblock_graph函数中,会看到CyTBGraph。
python
def new_threadblock_graph(
grid_dim: tuple, block_dim: tuple, forloop_range: int, reduction_dimx: int
):
bgraph = core.CyTBGraph(grid_dim, block_dim, forloop_range, reduction_dimx)
return TBGraph(bgraph)
CyTBGraph会调用到CppTBGraph。
python
cdef class CyTBGraph:
cdef CppTBGraph *p_bgraph #Hold a CppTBGraph instance
def __cinit__(self, tuple grid_dim = (), tuple block_dim = (), int forloop_range = -1, int dimx = -1, bgraph = None):
cdef unsigned long long ptr
cdef dim3 c_grid_dim
cdef dim3 c_block_dim
if bgraph is None:
c_grid_dim.x = grid_dim[0]
c_grid_dim.y = grid_dim[1]
c_grid_dim.z = grid_dim[2]
c_block_dim.x = block_dim[0]
c_block_dim.y = block_dim[1]
c_block_dim.z = block_dim[2]
self.p_bgraph = new CppTBGraph(c_grid_dim, c_block_dim, forloop_range, dimx)
else:
ptr = ctypes.cast(bgraph, ctypes.c_void_p).value
if isinstance(bgraph, int):
self.p_bgraph = <CppTBGraph*>(ptr)
elif isinstance(bgraph, ctypes.c_void_p):
self.p_bgraph = <CppTBGraph*>(ptr)
else:
assert False, "bgraph must be an integer or ctypes.c_void_p, but got " + str(type(bgraph))
CppTBGraph 对应 "mirage::threadblock::Graph",这就是 C++的实现。
python
cdef cppclass CppTBGraph "mirage::threadblock::Graph"
4.4 C++代码
块图在代码中是mirage::threadblock::Graph类,这是中间层次的计算图。下面是精简版代码。
Block graph主要包含以下属性来表示程序并行切分的信息
- Grid Dims(x, y, z):kernel启动多少block
- imap:作用是输入分块,grid-dims到input tensor dims的映射
- omap:作用是输出拼接,grid-dims到output tensor dims的映射
- For-loop body:允许block多次迭代来复用SMEM,流水线形式来充分计算和访存重叠,把DRAM读写完全隐藏到计算时间里,同时也充分服用SMEM,形如InputIterator->...->Accum->...->OutputSaver
- fmap:决定每次迭代取哪一块数据,比如 fmap={i↔h} 沿 h 维滑窗。
c++
namespace mirage {
namespace threadblock {
class Graph {
private:
struct pair_hash {
size_t operator()(std::pair<int, int> const &p) const;
};
public:
Graph();
Graph(dim3 grid_dim, dim3 block_dim, int forloop_range, int reduction_dimx);
~Graph();
Graph(Graph const &) = delete;
Graph &operator=(Graph const &) = delete;
// input operator
STensor new_input(mirage::kernel::DTensor const &dtensor,
int3 input_map,
int forloop_dim,
mirage::layout::SmemLayout layout,
bool store_in_dmem = false);
STensor *new_input(mirage::kernel::DTensor const *dtensor,
int3 input_map,
int forloop_dim,
mirage::layout::SmemLayout layout,
bool store_in_dmem = false);
TBOperator *create_input_op(mirage::kernel::DTensor const &dtensor,
int3 input_map,
int forloop_dim,
mirage::layout::SmemLayout layout,
bool store_in_dmem = false);
// matmul operator
STensor matmul(STensor const &A, STensor const &B);
STensor *matmul(STensor const *A, STensor const *B);
TBOperator *create_matmul_op(STensor const &A, STensor const &B);
// element unary operator
STensor exp(STensor const &A);
STensor *exp(STensor const *A);
STensor square(STensor const &A);
STensor *square(STensor const *A);
STensor sqrt(STensor const &A);
STensor *sqrt(STensor const *A);
STensor silu(STensor const &A);
STensor *silu(STensor const *A);
STensor gelu(STensor const &A);
STensor *gelu(STensor const *A);
STensor relu(STensor const &A);
STensor *relu(STensor const *A);
// element binary operators
STensor add(STensor const &A, STensor const &B);
STensor *add(STensor const *A, STensor const *B);
STensor mul(STensor const &A, STensor const &B);
STensor *mul(STensor const *A, STensor const *B);
STensor div(STensor const &A, STensor const &B);
STensor *div(STensor const *A, STensor const *B);
STensor sub(STensor const &A, STensor const &B);
STensor *sub(STensor const *A, STensor const *B);
STensor pow(STensor const &A, STensor const &B);
STensor *pow(STensor const *A, STensor const *B);
// reduction operator
STensor reduction(STensor const &A, int dim);
STensor *reduction(STensor const *A, int dim);
TBOperator *create_reduction_op(STensor const &A, int dim);
// reduction_to_dimx operator
STensor reduction_to_dimx(STensor const &A, int dim);
TBOperator *create_reduction_to_dimx_op(STensor const &A, int dim);
// reduction_max operator
std::vector<STensor> reduction_max(STensor const &A, int dim);
std::vector<STensor *> reduction_max(STensor const *A, int dim);
TBOperator *create_reduction_max_op(STensor const &A, int dim);
// rms_norm operator
STensor rms_norm(STensor const &A);
STensor *rms_norm(STensor const *A);
TBOperator *create_rms_norm_op(STensor const &A);
public:
dim3 grid_dim, block_dim, cluster_dim{4, 4, 1};
int forloop_range;
int reduction_dimx;
std::vector<mirage::threadblock::TBOperator *> operators;
// memory allocator
off_t smem_offset;
std::vector<std::pair<off_t, size_t>> allocated_tensors;
using OpType = TBOperator;
using TensorType = STensor;
};
void from_json(json const &j, Graph &g);
} // namespace threadblock
} // namespace mirage
以 reduction_max 为例,代码如下:
c++
std::vector<STensor *> Graph::reduction_max(STensor const *input, int dim) {
TBOperator *op = create_reduction_max_op(*input, dim);
assert(op != nullptr);
operators.push_back(op);
return std::vector<STensor *>{&op->output_tensors[0], &op->output_tensors[1]};
}
TBOperator *Graph::create_reduction_max_op(STensor const &input, int dim) {
TBOperator *op =
new TBReductionOp(this, input, dim, -1 /*size = -1 for max*/);
// Check shmem usage
size_t smem_usage = calculate_shared_memory_usage(op);
if (smem_usage > mirage::config::MAX_SMEM_SIZE) {
delete op;
return nullptr;
} else {
return op;
}
}
4.5 TBOperator
块图在CUDA thread block级别执行,使用TBOperator来表示所包含的操作。也使用TBInputOp连接到上层的mu'Graph的张量。
以 Attention 层为例,其 thread block 可能包含如下结构:
python
Thread Block for Attention:
TB_INPUT_OP(输入QKV张量)
↓
TB_MATMUL_OP(计算QK^T)
↓
TB_REDUCTION_OP(Softmax归一化)
↓
TB_MATMUL_OP(计算Attention输出)
↓
TB_FORLOOP_ACCUM_NO_RED_OP(累积计算)
TBOperator的定义如下:
c++
namespace mirage {
namespace threadblock {
class Graph;
class TBOperator {
public:
TBOperator(Graph *graph, mirage::type::TBOperatorType);
TBOperator(Graph *graph, mirage::type::TBOperatorType, STensor const &input1);
TBOperator(Graph *graph,
mirage::type::TBOperatorType,
STensor const &input1,
STensor const &input2);
TBOperator(Graph *graph,
mirage::type::TBOperatorType,
std::vector<STensor> const &inputs);
int get_input_stensors(STensor **inputs);
int get_output_stensors(STensor **inputs);
virtual ~TBOperator();
virtual operator json() const = 0;
public:
Graph *bgraph;
mirage::type::TBOperatorType op_type;
std::vector<STensor> input_tensors;
std::vector<STensor> output_tensors;
};
TBOperator 的派生类举例。
c++
class TBInputOp : public TBOperator {
public:
TBInputOp(Graph *_graph,
mirage::kernel::DTensor const &dtensor,
int3 input_map,
int forloop_dim,
mirage::layout::SmemLayout layout,
bool store_in_dmem);
~TBInputOp();
operator json() const override;
size_t get_dtensor_guid();
public:
mirage::kernel::DTensor dtensor;
int3 input_map;
int forloop_dim;
};
class TBOutputOp : public TBOperator {
public:
TBOutputOp(Graph *_graph,
STensor const &stensor,
int3 output_map,
int forloop_dim,
mirage::type::TBEpilogueType allreduce);
~TBOutputOp();
operator json() const override;
size_t get_dtensor_guid();
public:
mirage::kernel::DTensor dtensor;
int3 output_map;
int forloop_dim;
mirage::type::TBEpilogueType epilogue;
};
TBOperatorType的类型为:
python
enum TBOperatorType {
TB_UNKOWN = 2000,
TB_INPUT_OP = 2001,
TB_OUTPUT_OP = 2002,
TB_MATMUL_OP = 2003,
// ElementUnary
TB_EXP_OP = 2100,
TB_SQUARE_OP = 2101,
TB_SQRT_OP = 2102,
TB_MUL_SCALAR_OP = 2103,
TB_SILU_OP = 2104,
TB_SIGMOID_OP = 2105,
TB_GELU_OP = 2106,
// non-lax elementunary ops
TB_RELU_OP = 2150,
TB_CLAMP_OP = 2151,
TB_LOG_OP = 2160,
// ElementBinary
TB_ADD_OP = 2200,
TB_MUL_OP = 2201,
TB_DIV_OP = 2202,
TB_SUB_OP = 2203,
TB_POW_OP = 2204,
// Reduction and Normalization
TB_REDUCTION_FIRST_OP_ID = 2300,
TB_REDUCTION_0_OP = 2301,
TB_REDUCTION_1_OP = 2302,
TB_REDUCTION_2_OP = 2303,
TB_REDUCTION_0_TO_DIMX_OP = 2304,
TB_REDUCTION_1_TO_DIMX_OP = 2305,
TB_REDUCTION_2_TO_DIMX_OP = 2306,
TB_REDUCTION_0_MAX_OP = 2307,
TB_REDUCTION_1_MAX_OP = 2308,
TB_REDUCTION_2_MAX_OP = 2309,
TB_REDUCTION_LAST_OP_ID = 2349,
TB_RMS_NORM_OP = 2350,
// Concat & Split
TB_CONCAT_FIRST_OP_ID = 2400,
TB_CONCAT_0_OP = 2400,
TB_CONCAT_1_OP = 2401,
TB_CONCAT_2_OP = 2402,
TB_CONCAT_LAST_OP_ID = 2409,
TB_CONCAT_THEN_MATMUL_OP = 2411,
TB_SPLIT_FIRST_OP_ID = 2420,
TB_SPLIT_0_OP = 2420,
TB_SPLIT_1_OP = 2421,
TB_SPLIT_2_OP = 2422,
TB_SPLIT_LAST_OP_ID = 2429,
// Forloop Accum
// LD indicates last dimension
TB_FORLOOP_ACCUM_FIRST_OP = 2500,
TB_FORLOOP_ACCUM_NO_RED_OP = 2500,
TB_FORLOOP_ACCUM_RED_LD_SUM_OP = 2501,
TB_FORLOOP_ACCUM_RED_LD_MEAN_OP = 2502,
TB_FORLOOP_ACCUM_RED_LD_RMS_OP = 2503,
TB_FORLOOP_ACCUM_REDTOX_LD_SUM_OP = 2504,
TB_FORLOOP_ACCUM_NO_RED_RESCALE_OP = 2505,
TB_FORLOOP_ACCUM_RED_LD_SUM_RESCALE_OP = 2506,
TB_FORLOOP_ACCUM_MAX_OP = 2507,
TB_FORLOOP_ACCUM_LAST_OP = 2599,
TB_CUSTOMIZED_OP = 2999
};
我们用 TBReductionOp 来看看具体实现。
c++
class TBReductionOp : public TBOperator {
public:
TBReductionOp(Graph *graph,
STensor const &_input,
int reduce_dim,
int reduce_size);
~TBReductionOp();
operator json() const override;
public:
int reduce_dim, reduce_size;
};
TBReductionOp::TBReductionOp(Graph *bgraph,
STensor const &input,
int dim,
int size)
: TBOperator(bgraph,
size == 1 ? (mirage::type::TBOperatorType)(
mirage::type::TB_REDUCTION_0_OP + dim)
: size == -1
? (mirage::type::TBOperatorType)(
mirage::type::TB_REDUCTION_0_MAX_OP + dim)
: (mirage::type::TBOperatorType)(
mirage::type::TB_REDUCTION_0_TO_DIMX_OP + dim),
input),
reduce_dim(dim), reduce_size(size) {
STensor output = input;
assert(output.num_dims > reduce_dim);
assert(output.layout == mirage::layout::SmemRowMajor);
output.dim[reduce_dim] = reduce_size == -1 ? 1 : reduce_size;
output.owner_op = this;
output.owner_ts_idx = 0;
output.guid = STensor::next_guid++;
output.after_accum = input.after_accum;
output.smem_offset = bgraph->allocate_fingerprint(output);
output_tensors.push_back(output);
if (reduce_size == -1) {
// For max reduction, we need to allocate another tensor for difference
STensor diff = output;
diff.owner_ts_idx = 1;
diff.guid = STensor::next_guid++;
diff.smem_offset = bgraph->allocate_fingerprint(diff);
output_tensors.push_back(diff);
}
}
4.6 生成样例
在Mirage项目中,block_graph是在创建自定义操作时插入得。
- 可以在Python代码直接通过mi.new_threadblock_graph()直接构建。
- 在 demo.py 中逐层构建模型时,每一层都会插入相应的 block_graph 来定义该层在线程块级别的具体执行方式。即,每个自定义操作的创建过程中:每当调用 PersistentKernel 的 layer 方法时,都会在内部创建一个包含具体线程块级计算的 block_graph。比如,attention_layer(),rmsnorm_linear_layer(), def embed_layer()内部都会构建block_graph。
- 也可以在C++代码直接构建。
4.6.1 Python代码直接构建
原始的rms_linear公式为:
\[ y_i = \frac{ x_i * g_i }{ \sqrt{\frac{1}{n} \sum_{i=1}^{n}{x_i^2}} } \]
逻辑如下:
针对rms_linear,MPK的转换代码如下:
python
def get_rms_linear():
graph = mi.new_kernel_graph() # kernel graph
X = graph.new_input(dims=(num_tokens, 4096), dtype=mi.float16)
W = graph.new_input(dims=(4096, n_local_heads * head_dim + 2 * n_local_kv_heads * head_dim), dtype=mi.float16)
# block graph
tb_graph = mi.new_threadblock_graph(grid_dim=(384,1,1), block_dim=(128,1,1), forloop_range=32, reduction_dimx=64)
tX = tb_graph.new_input(dtensor=X, input_map=(-1, -1, -1), forloop_dim=1)
tW = tb_graph.new_input(dtensor=W, input_map=(1, -1, -1), forloop_dim=0)
tM = tb_graph.matmul(tX, tW)
tAccX = tb_graph.forloop_accum(tX, "rms")
tAccM = tb_graph.forloop_accum(tM)
tO = tb_graph.div(tAccM, tAccX)
tb_graph.new_output(stensor=tO, output_map=(1, -1, -1))
O = graph.customized([X, W], tb_graph)
return graph, O
其中,new_threadblock_graph()内部会直接构建TBGraph(bgraph)。
python
def new_threadblock_graph(
grid_dim: tuple, block_dim: tuple, forloop_range: int, reduction_dimx: int
):
bgraph = core.CyTBGraph(grid_dim, block_dim, forloop_range, reduction_dimx)
return TBGraph(bgraph)
调整之后,其对应的逻辑如下:
4.6.2 PersistentKernel 的 layer 方法间接构建
比如:rmsnorm_linear_layer(),attention_layer()等函数中,都构建了TBGrapattach_inputh(CyTBGraph(grid_dim, block_dim, 1, 64))。
python
mpk.embed_layer(input=x, weight=w_embed, output=embed_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))
mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_attn, weight_linear=w_qkv, output=attn_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
在embed_layer函数内部,会构建 TBGraph(bgraph)。
python
def embed_layer(
self,
input: DTensor, # [batch_size, num_spec_tokens]
weight: DTensor, # [vocab_size, hidden_size]
output: DTensor, # [batch_size, hidden_size]
grid_dim: tuple,
block_dim: tuple,
input_source: int = 0, # 0: all_tokens, 1: input_token
):
tb_graph = TBGraph(CyTBGraph(grid_dim, block_dim, 1, 64))
tb_graph.new_input(input, (-1, 1, -1), -1, True)
tb_graph.new_input(weight, (1, -1, -1), -1, True)
tb_graph.new_input(output, (1, 0, -1), -1, True)
self.kn_graph.customized([input, weight, output], tb_graph)
self.kn_graph.register_task(tb_graph, "embedding", [input_source])
4.6.3 C++代码直接构建
在graph.cc,自定义操作也会构建block graph。这个是把python定义的图进行转换到c++。
c++
void from_json(json const &j, Graph &g) {
case type::KNOperatorType::KN_CUSTOMIZED_OP: {
std::vector<DTensor> inputs;
for (auto const &jinput : jop.at("input_tensors")) {
size_t guid;
jinput.at("guid").get_to(guid);
inputs.push_back(get_tensor_from_guid(guid));
}
threadblock::Graph bgraph;
from_json(jop.at("bgraph"), bgraph);
// 将muGraph的张量连接到block-graph的输入
for (size_t i = 0; i < bgraph.operators.size(); ++i) {
if (bgraph.operators[i]->op_type == type::TB_INPUT_OP) {
static_cast<threadblock::TBInputOp *>(bgraph.operators[i])
->dtensor = inputs[i];
}
}
std::vector<DTensor> outputs = g.customized(inputs, bgraph);
for (size_t i = 0; i < outputs.size(); ++i) {
size_t guidO;
jop.at("output_tensors")[i].at("guid").get_to(guidO);
guid_mapping[outputs[i].guid] = guidO;
}
break;
}
0x05 线程图
线程图进一步将计算范围从块缩小到单个线程。与块图类似,每个线程图也与块尺寸相关联,指定块内线程的组织,以及 for-loop 尺寸,定义完成定义计算的总迭代次数。每个线程图包括输入迭代器,每个迭代器从 GPU 共享内存加载输入张量到寄存器文件,以及输出累加器,每个累加器从寄存器文件保存输出张量回到共享内存。线程图是 uGraph 中的最低级别图,仅包含预定义的线程操作符。
线程图是最底层的计算图,在代码中没有显式定义为独立的图结构,而是在block-graph的操作中体现。
主要特征:
- 执行单位:在CUDA thread warp或者单个thread级别执行
- 操作细节:包含具体的线程级别计算和内存访问模式
-
Thread graph
-
-
边:Tensor,thread graph的张量位于寄存器
-
节点:描述单个thread内寄存器上的流水,load->emelent-wise->store。只包含预定义算子,对应封装好的寄存器上的一些操作,也支持for loop维+寄存器累加,不过mirage默认用规则化融合快速合成,避免在最细层再做大搜索
-
-
对每个候选内的block图,找出符合form的子图(通常是一串element-wise+reduce),把它们融成thread graph节点,表示这段计算可以放在寄存器里完成
-
规则化、无需大搜索。thread只做局部融合和固定模式的for-loop,避免搜索指数爆炸,这样仍能让大多数逐元素算子留在寄存器中,减少shared-memory访问
0xFF 参考
如何评价CMU将LLM转化为巨型内核的Mirage Persistent Kernel(MPK)工作?
Mirage: A Multi-Level Superoptimizer for Tensor Programs 简记 尘伊光
OSDI2025论文笔记:Mirage: A Multi-Level Superoptimizer for Tensor Programs 画饼充饥
Mirage: A Compiler for High-Performance Tensor Programs on GPUs
https://mirage-project.readthedocs.io/en/latest/mugraph.html
https://mirage-project.readthedocs.io/en/latest/transpiler.html