先看使用pytorch实现的softmax
python
import torch
def navie_softmax(x:torch.Tensor):
'''
输入:(M, N)
逻辑:沿着"列"(dim=1)这个轴去寻找。也就是说,对每一行,扫过它所有的列,吐出一个最大值。
结果:产生 M 个最大值。
'''
x_max=x.max(dim=1)[0]#PyTorch 的 max 会返回两个值:(values, indices)。[0] 取的是数值,[1] 取的是对应的索引(类似 argmax)。
'''
这里发生了 Broadcasting(广播)。
x 是 (M, N),x_max 算出来是 (M,)。
Torch 会自动把 $_max 复制 N 列,变成 (M, N),然后进行元素级相减。
'''
z=x-x_max
numerator=torch.exp(z)
denominator=numerator.sum(dim=1)
#None(或者 np.newaxis)的作用是在那个位置插入一个长度为 1 的新维度。
#这样 (M, N) / (M, 1) 就能触发广播,每一行都除以该行对应的和。
ret=numerator/denominator[:,None]
return ret
直接在 PyTorch 中实现时,对于 x∈M×N,计算 y = naive_softmax(x) 需要从 DRAM 中读取 5MN+2M 个元素,并写回 3MN+2M 个元素。
这显然是浪费的;我们更希望有一个自定义的「融合」内核,它只需读取一次 X,并在芯片上进行所有必要的计算。
softmax kernel
softmax 内核工作原理如下:每个Thread Block加载输入矩阵 X 的一组行,按Block数量跨步处理,对其进行归一化,并将结果写回输出 Y。
其实就是一种非常经典的模式:网格跨步循环(Grid Stride Loop)。这在cuda里也很常用。
注意,Triton 的一个重要限制是每个块必须具有 2 的幂次数的元素,因此,如果我们要处理任意可能的输入形状,我们需要在内部「填充」每一行,并适当保护内存操作。
python
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# 程序起始行
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# 步长表示我们需要对指针增加多少以推进 1 行
row_start_ptr = input_ptr + row_idx * input_row_stride
# 块大小是大于 n_cols 的下一个二的幂,因此我们可以适配
# 单个块中的行
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# 将行加载到 SRAM 中,使用掩码,因为 BLOCK_SIZE 可能大于 n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# 为了数值稳定性而减去最大值
row_minus_max = row - tl.max(row, axis=0)
# 请注意,Triton 中的指数运算速度很快,但是是近似的(例如,类似于 CUDA 中的 __expf)。
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# 将输出写回 DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
softmax_kernel 输入参数详解
| 参数名称 | 类型 | 含义 | Infra/CUDA 视角对照 |
|---|---|---|---|
output_ptr |
指针 | 输出张量在显存(DRAM)中的首地址。 | 对应 float* out |
input_ptr |
指针 | 输入张量在显存(DRAM)中的首地址。 | 对应 const float* in |
input_row_stride |
整数 | 输入矩阵相邻两行之间的元素距离。 | 对应 2D 数组的 pitch 或 width |
output_row_stride |
整数 | 输出矩阵相邻两行之间的元素距离。 | 通常与输入相同,除非输出做了 Padding |
n_rows |
整数 | 矩阵的总行数(MMM)。 | 循环终止的边界条件 |
n_cols |
整数 | 矩阵的总列数(NNN)。 | 用于生成 mask 防止越界访问 |
BLOCK_SIZE |
tl.constexpr |
编译时常量。定义每个 Program 一次加载的向量长度。 | 必须是 2 的幂,且 ≥n_cols\ge n\_cols≥n_cols |
num_stages |
tl.constexpr |
编译时常量。控制流水线循环的深度。 | 决定了异步搬运和计算的重叠程度 |
1. 为什么需要 input_row_stride?
在 PyTorch 或 C++ 中,矩阵可能不是连续存储的(例如对大矩阵取了子块 slice)。
-
input_row_stride允许 Kernel 正确跳过内存。 -
物理地址计算公式:
address = base + row * stride + col。
2. BLOCK_SIZE 的特殊性
在 Triton 的这个实现里,BLOCK_SIZE 是一个强约束:
-
因为它在 Kernel 里执行了
row - tl.max(row),这就要求 整行数据必须同时存在于同一个 Block 的寄存器中。 -
所以调用这个 Kernel 前,Host 端通常会计算
BLOCK_SIZE = next_power_of_2(n_cols)。
3. num_stages 对循环的影响
注意代码中的循环:
Python
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
在 Triton 中,当 tl.range 配合 num_stages > 1 使用时,编译器会尝试进行 Software Pipelining(软件流水线化)。
-
Stage 1 : 发起第 i+1i+1i+1 行的异步 Load 请求。
-
Stage 2 : 计算第 iii 行的 Softmax。
这在 NVIDIA Ampere (A100) 及之后的架构上能极大地掩盖 DRAM 访存延迟。
triton.language.num_programs
简单来说,tl.num_programs(axis) 对应的就是 CUDA 里的 gridDim。它告诉你整个计算网格在某个维度上一共有多少个 Block(在 Triton 里叫 Program)。
python
@builtin
def num_programs(axis, _semantic=None):
"""
Returns the number of program instances launched along the given :code:`axis`.
:param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
:type axis: int
"""
axis = _unwrap_if_constexpr(axis)
return _semantic.num_programs(axis)
| 属性 | 内容 |
|---|---|
| 函数作用 | 获取在特定维度上启动的 Program(计算实例/Block)总数。 |
输入参数 (axis) |
维度索引。必须是 0, 1 或 2(对应 3D 网格的 X, Y, Z 轴)。 |
| 返回值 | 一个整数。代表该轴向上的网格大小(Grid Size)。 |
如果你有 CUDA 背景,可以这样秒懂:
| Triton 代码 | CUDA C++ 等价物 | 物理含义 |
|---|---|---|
tl.num_programs(0) |
gridDim.x |
X 轴上的总 Block 数 |
tl.num_programs(1) |
gridDim.y |
Y 轴上的总 Block 数 |
tl.num_programs(2) |
gridDim.z |
Z 轴上的总 Block 数 |
需要注意:num_programs 的值是由 Host 端启动 Kernel 时传入的 grid 参数 决定的。
python
grid = (64, 1, 1) # 启动 64 个 Program
softmax_kernel[grid](...)
triton.language.range
python
class range(base_value):
"""
Iterator that counts upward forever.
.. highlight:: python
.. code-block:: python
@triton.jit
def kernel(...):
for i in tl.range(10, num_stages=3):
...
:note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
:code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler.
:param arg1: the start value.
:param arg2: the end value.
:param step: the step value.
:param num_stages: pipeline the loop into this many stages (so there are
:code:`num_stages` iterations of the loop in flight at once).
Note this is subtly different than passing :code:`num_stages` as a
kernel argument. The kernel argument only pipelines loads that feed
into :code:`dot` operations, while this attribute tries to pipeline most
(though not all) loads in this loop.
:param loop_unroll_factor: Tells the Triton IR level loop unroller how many
times to unroll a for loop that this range is used with. Less than 2 for
this value implies no unrolling.
:param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot
operation in the loop to be multi-buffered, if applicable.
:param flatten: automatically flatten the loop nest starting at this loop to
create a single flattened loop. The compiler will try to pipeline the
flattened loop which can avoid stage stalling.
:param warp_specialize: Enable automatic warp specialization on the loop.
The compiler will attempt to partition memory, MMA, and vector
operations in the loop into separate async partitions. This will
increase the total number of warps required by the kernel.
:param disable_licm: Tells the compiler it shouldn't hoist loop invariant
code outside the loop. This is often useful to avoid creating long liveranges
within a loop.
Note that warp specialization is only supported on Blackwell GPUs and
only works on simple matmul loops. Support for arbitrary loops will be
expanded over time.
"""
def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
if step is None:
self.step = constexpr(1)
else:
self.step = step
if arg2 is None:
self.start = constexpr(0)
self.end = arg1
else:
self.start = arg1
self.end = arg2
self.num_stages = num_stages
self.loop_unroll_factor = loop_unroll_factor
self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
self.flatten = flatten
self.warp_specialize = warp_specialize
self.disable_licm = disable_licm
def __iter__(self):
raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
def __next__(self):
raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
1. 基础迭代控制
-
arg1(start): 迭代的起始值(如果只传一个参数,则为 0)。 -
arg2(end): 迭代的终止值(不包含)。 -
step(step): 迭代步长,默认为 1。
2. 软件流水线 (Software Pipelining)
-
num_stages:-
翻译 :将循环流水线化为指定的级数(即同时有
num_stages个迭代在并行执行)。 -
Infra 解析:这是最重要的参数。这里的参数会尝试流水线化循环内的大多数加载操作。它利用双缓冲或多缓冲技术,在计算当前迭代时,异步搬运后续迭代的数据。
-
3. 循环展开 (Loop Unrolling)
-
loop_unroll_factor:-
翻译:告诉 Triton IR 级别的循环展开器,这个循环需要展开多少次。如果值小于 2,则表示不展开。
-
Infra 解析 :对应 CUDA 里的
#pragma unroll。通过增加寄存器压力来减少循环开销(计数器加法、条件跳转)。
-
4. 累加器多缓冲控制
-
disallow_acc_multi_buffer:-
翻译 :如果为 True,防止循环中
dot操作的累加器(accumulator)使用多缓冲。 -
Infra 解析:通常在流水线化时,为了性能会给累加器开辟多个缓冲区。如果你面临极端的寄存器压力,可以设为 True 来节省空间。
-
5. 循环嵌套打平 (Loop Flattening)
-
flatten:-
翻译:自动将从当前循环开始的嵌套循环打平为单个循环。编译器会尝试对打平后的循环进行流水线化,以避免阶段性停顿(Stage Stalling)。
-
Infra 解析:这在处理 2D 数据块迭代时非常有用,能让指令流水线更平滑。
-
6. Warp 特化 (Warp Specialization)
-
warp_specialize:-
翻译:启用自动 Warp 特化。编译器会将循环中的内存操作、MMA(矩阵乘法)和向量操作分配到不同的异步分区中。这会增加 Kernel 所需的 Warp 总数。
-
Infra 解析 :这是 Blackwell (NVIDIA GB200) 等新架构的高级特性。它让一部分 Warp 专门负责搬运数据(Producer),另一部分专门负责计算(Consumer),实现硬件级别的任务解耦。
-
7. 禁用循环不变式外提 (Disable LICM)
-
disable_licm:-
翻译:告诉编译器不要将循环内不变的代码提到循环外。
-
Infra 解析 :LICM (Loop Invariant Code Motion) 是通用编译优化,但在 GPU 上,外提代码可能会显著增加寄存器生命周期(Live Ranges),导致寄存器压力过大(Register Pressure)。禁用它可以让寄存器回收更快。
-
总结对照表
| 参数 | 对应 C++/编译优化概念 | 目的 |
|---|---|---|
num_stages |
Double/Triple Buffering | 掩盖访存延迟 (Hide Latency) |
loop_unroll_factor |
#pragma unroll |
减少分支指令开销 |
flatten |
Loop Collapse | 优化流水线效率 |
warp_specialize |
Producer-Consumer Model | 硬件资源解耦 (New Architectures) |
disable_licm |
Register Pressure Control | 防止寄存器溢出 (Spilling) |
triton.language.max
python
def max(input, axis=None, return_indices=False,
return_indices_tie_break_left=True, keep_dims=False):
该函数的主要功能是计算输入张量 input 在指定维度(axis)上的最大值。
它不仅支持基础的数值归约,还处理了以下高级特性:
-
精度提升 (Promotion) :自动将低位宽浮点数(如
bfloat16)转为float32以保证计算精度。 -
索引回传 (Argmax):可选地返回最大值对应的索引位置。
-
平局处理 (Tie-break):当存在多个相同的最大值时,决定返回哪一个索引(最左侧或快速模式)。
参数解释
| 参数名 | 类型 | 描述 |
|---|---|---|
input |
Tensor |
需要进行最大值运算的输入张量。 |
axis |
int/tuple/None |
执行归约的轴。如果为 None,通常代表对全量元素进行归约。 |
return_indices |
bool |
关键参数 。若为 True,函数不仅返回最大值,还会返回最大值的索引(类似于 argmax)。 |
return_indices_tie_break_left |
bool |
仅在 return_indices=True 时有效。若为 True,遇到多个最大值时确保返回**第一个(索引最小)**的索引。 |
keep_dims |
bool |
是否保留归约后的维度。若为 True,被归约的轴长度变为 1。 |
triton.language.exp
python
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential")
@core._tensor_member_fn
def exp(x, _semantic=None):
这个 exp(指数函数,即 exe^xex)的实现属于 Triton 的底层原子算子封装。
该函数用于计算输入张量 x 中每个元素的指数值。它是逐元素(Element-wise)运算,不涉及跨维度的归约(Reduction)。
| 参数名 | 描述 |
|---|---|
x |
输入数据。根据装饰器限制,必须是浮点类型(fp32 或 fp64)。 |
triton.language.sum
python
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("sum", dtype_arg="dtype")
def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):
该函数用于计算输入张量 input 在指定维度(axis)上的累加和。
与 max 类似,它也是一个归约(Reduction)算子,但它特别强调了对输出数据类型(dtype)的控制,以防止在累加过程中出现数值溢出。
参数解释
| 参数名 | 类型 | 描述 |
|---|---|---|
input |
Tensor |
需要求和的输入张量。 |
axis |
int/tuple/None |
执行求和的轴。如果为 None,则对张量内所有元素求和。 |
keep_dims |
bool |
是否保留被归约的维度(保持维度秩不变)。 |
dtype |
core.constexpr |
关键参数 。强制指定的输出类型。由于求和极易导致数值超出原类型范围(如 int8 累加很快会溢出),用户可以手动指定更高的精度。 |
辅助函数
python
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
n_rows, n_cols = x.shape
# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 8
# Number of software pipelining stages.
num_stages = 4 if SIZE_SMEM > 200000 else 2
# Allocate output
y = torch.empty_like(x)
# pre-compile kernel to get register usage and compute thread occupancy.
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
if is_hip():
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
# However, this is not always the case. In most cases all registers can be used as regular purpose registers.
# ISA SECTION (3.6.4 for CDNA3)
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
# not required to be equal numbers of both types.
NUM_GPRS = NUM_REGS
if is_cdna():
NUM_GPRS = NUM_REGS * 2
# MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
# When we divide this number with WARP_SIZE we get maximum number of waves that can
# execute on a CU (multi-processor) in parallel.
MAX_NUM_THREADS = properties["max_threads_per_sm"]
max_num_waves = MAX_NUM_THREADS // WARP_SIZE
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
return y
这段代码展示了 Triton 中一个非常经典的高级技巧:持久化算子(Persistent Kernel) 的启动环境配置。
与普通的算子调用不同,这个辅助函数不是简单地启动与数据规模一样多的线程块,而是通过硬件感知(Hardware-aware)的计算,精确控制并行度以最大化利用 GPU 资源。
1. 硬件资源探测 (Hardware Querying)
在代码开头,它通过 driver 接口直接获取了当前 GPU 的物理参数:
-
NUM_SM: GPU 核心(流式多处理器)的数量。 -
NUM_REGS&SIZE_SMEM: 每个 SM 拥有的寄存器总量和共享内存(L1 缓存)总量。 -
这些参数决定了 GPU 的"胃口",即它一次性最高能处理多少计算任务。
2. 计算占用率 (Occupancy Calculation)
这是该函数最精华的部分。它通过"预编译(warmup)"来探测内核的资源消耗,并计算占用率(Occupancy):
-
寄存器约束:计算每个 SM 能容纳多少个线程组(Warps)。
-
共享内存约束:计算 SM 的内存空间够分给几个 Block。
-
最终结果 :
occupancy取两者的最小值。这代表了在硬件不宕机、不排队的前提下,一个物理 SM 同时能跑多少个程序实例。
3. 设置持久化网格 (Persistent Grid)
通常的写法是 grid = (n_rows, ),即一行对应一个 Block。但如果行数非常多(比如 100 万行),GPU 会频繁地调度、创建和销毁 Block,产生巨大的开销。
该函数通过 num_programs = NUM_SM * occupancy 设置了网格大小:
-
它只启动刚好能填满整个 GPU 物理槽位的 Block 数量。
-
这些 Block 启动后不会立即退出,而是通过内部循环"吞掉"所有的行(Persistent 模式)。
-
好处:极大地减少了 GPU 调度器的压力,并提高了 L2 缓存的命中率。
4. 编译优化参数
-
BLOCK_SIZE: 自动寻找大于列数的最小 2n2^n2n,确保内存对齐且符合 Triton 的编译要求。 -
num_stages: 根据共享内存大小决定软件流水线(Software Pipelining)的深度。如果内存够大,就用 4 阶流水线来隐藏指令延迟。 -
num_warps: 固定使用 8 个 Warp(256 个线程)来处理每一行,保证了足够的线程并行度。
triton.runtime.DriverConfig
python
def _create_driver() -> DriverBase:
active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
if len(active_drivers) != 1:
raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
return active_drivers[0]()
class DriverConfig:
def __init__(self) -> None:
self._default: DriverBase | None = None
self._active: DriverBase | None = None
@property
def default(self) -> DriverBase:
if self._default is None:
self._default = _create_driver()
return self._default
@property
def active(self) -> DriverBase:
if self._active is None:
self._active = self.default
return self._active
def set_active(self, driver: DriverBase) -> None:
self._active = driver
def reset_active(self) -> None:
self._active = self.default
driver = DriverConfig()
triton.runtime.DriverConfig 类在 Triton 架构中扮演的是 "后端驱动管理器" 的角色。它的设计非常简洁,核心目的是为了实现 惰性初始化(Lazy Initialization) 和 多后端切换。
由于 Triton 需要支持不同的硬件(如 NVIDIA GPU、AMD GPU 等),它不能在启动时就写死驱动。DriverConfig 就像一个转换插头,确保系统总能找到当前环境下最合适的硬件接口。
1. 核心成员变量
-
_default: 存储系统检测到的默认驱动。例如在 NVIDIA 环境下,它通常指向CUDADriver。 -
_active: 当前正在使用的驱动。在大多数情况下,它和_default是一样的,但在复杂的调试或多硬件共存环境下,你可以手动切换。
2. 核心逻辑:惰性初始化
注意到 default 和 active 都是 @property(属性装饰器):
-
当你运行
driver.active时,它会检查_active是否为空。 -
如果为空,它会去调用
_create_driver()。这个函数会扫描系统环境(比如检查是否有 CUDA 库、是否有 ROCm 环境),然后动态创建一个驱动实例。 -
好处:如果你的代码只是定义了函数而没有实际在 GPU 上运行,Triton 就不会去加载沉重的驱动库,节省开销。
triton.backends.driver.DriverBase和GPUDriver
DriverBase 定义了规则(接口),而 GPUDriver 则是针对具体硬件(目前主要是通过 PyTorch 桥接的 GPU)的具体实现。
python
class DriverBase(metaclass=ABCMeta):
@classmethod
@abstractmethod
def is_active(self):
pass
@abstractmethod
def map_python_to_cpp_type(self, ty: str) -> str:
"""
Converts a Triton type string to its corresponding C++ type string for this backend.
Args:
ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.
Returns:
str: The C++ type string.
"""
pass
@abstractmethod
def get_current_target(self):
pass
@abstractmethod
def get_active_torch_device(self):
pass
@abstractmethod
def get_benchmarker(self) -> Benchmarker:
"""
Return the benchmarking function that this backend should use by default.
"""
raise NotImplementedError
def __init__(self) -> None:
pass
class GPUDriver(DriverBase):
def __init__(self):
# TODO: support other frameworks than torch
import torch
self.get_device_capability = torch.cuda.get_device_capability
try:
from torch._C import _cuda_getCurrentRawStream
self.get_current_stream = _cuda_getCurrentRawStream
except ImportError:
self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
self.get_current_device = torch.cuda.current_device
self.set_current_device = torch.cuda.set_device
# TODO: remove once TMA is cleaned up
def assemble_tensormap_to_arg(self, tensormaps_info, args):
return args
DriverBase 是一个抽象基类 (Abstract Base Class)。它不负责具体干活,而是规定了任何一个 Triton 驱动后端(无论是 NVIDIA、AMD 还是未来的 CPU 后端)必须提供的功能:
-
类型映射 (
map_python_to_cpp_type) : 这是编译器的关键。它把 Triton 的类型(如i32)翻译成底层编译器(如 LLVM 或 NVCC)能懂的 C++ 类型。 -
目标检测 (
get_current_target) : 确定当前的硬件架构(例如:它是sm_80还是sm_90?)。这决定了编译器能否使用 Tensor Core 或 TMA 等高级特性。 -
性能评估 (
get_benchmarker) : 提供一个标准的"跑分"工具,用于autotune(自动调优)过程中对比不同配置的快慢。
GPUDriver 是 DriverBase 的子类。有趣的是,它的实现高度依赖于 PyTorch。
核心任务:绑定 PyTorch 的底层接口
在 __init__ 中,你可以看到它在疯狂"偷" PyTorch 的函数:
-
get_device_capability: 调用torch.cuda.get_device_capability来获取显卡的计算能力(Compute Capability)。 -
get_current_stream: 这是最底层、最关键的部分。它试图获取 CUDA 的原始流句柄(Raw Stream Handle)。-
Triton 需要直接操作 CUDA Stream 来提交任务。
-
代码中做了兼容性处理:如果能找到 PyTorch 内部的
_cuda_getCurrentRawStream(高性能 C++ 接口)就直接用;否则退而求其次使用常规的 Python 接口。
-
-
设备管理: 绑定了获取和设置当前显卡编号(Device ID)的方法。
triton.backends.nvidia.driver.CudaDriver
python
class CudaDriver(GPUDriver):
def __init__(self):
self.utils = CudaUtils() # TODO: make static
self.launcher_cls = CudaLauncher
super().__init__()
def get_current_target(self):
device = self.get_current_device()
capability = self.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
warp_size = 32
return GPUTarget("cuda", capability, warp_size)
def get_active_torch_device(self):
import torch
return torch.device("cuda", self.get_current_device())
def get_device_interface(self):
import torch
return torch.cuda
@staticmethod
def is_active():
try:
import torch
return torch.cuda.is_available() and (torch.version.hip is None)
except ImportError:
return False
def map_python_to_cpp_type(self, ty: str) -> str:
return ty_to_cpp(ty)
def get_benchmarker(self):
from triton.testing import do_bench
return do_bench
def get_empty_cache_for_benchmark(self):
import torch
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2 cache
# doesn't contain any input data before the run
cache_size = 256 * 1024 * 1024
return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
def clear_cache(self, cache):
cache.zero_()
算力计算 (get_current_target)
Python
capability = capability[0] * 10 + capability[1]
这是一个非常经典的设计。NVIDIA 的算力通常表示为 (major, minor),例如 A100 是 (8, 0),RTX 4090 是 (8, 9)。
-
Triton 将其转换为整数(80, 89 等),因为底层编译器(PTXAS)需要根据这个数字来决定使用哪些指令集。
-
Warp Size : 这里硬编码为
32,这是所有 NVIDIA GPU 目前的物理常数。
utils成员
它是CudaUtils类型。
triton.backends.nvidia.driver.CudaUtils
python
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
def __init__(self):
mod = compile_module_from_src(
src=Path(os.path.join(dirname, "driver.c")).read_text(),
name="cuda_utils",
library_dirs=library_dirs(),
include_dirs=include_dirs,
libraries=libraries,
)
global PyCUtensorMap
PyCUtensorMap = mod.PyCUtensorMap
self.load_binary = mod.load_binary
self.get_device_properties = mod.get_device_properties
self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
self.set_printf_fifo_size = mod.set_printf_fifo_size
self.fill_tma_descriptor = mod.fill_tma_descriptor
CudaUtils 是 Triton 驱动层中最接近 NVIDIA Driver API (CUDA Driver API) 的部分。
这段代码使用了典型的单例模式(Singleton) ,通过 compile_module_from_src 实时编译了一个名为 driver.c 的 C 语言模块。它返回的成员变量实际上是直接映射到 CUDA 驱动层(如 libcuda.so 或 nvcuda.dll)的 C 函数。
以下是这些成员变量的详细含义:
1. self.get_device_properties
-
对应 CUDA API :
cuDeviceGetAttribute或cuDeviceGetProperties。 -
含义: 获取显卡的物理参数。
-
用途 : 返回一个字典,包含你在 Softmax 辅助函数中看到的
multiprocessor_count(SM 数量)、max_shared_mem(共享内存大小) 等。这是进行硬件感知优化(如计算 Occupancy)的基础。
2. self.load_binary
-
对应 CUDA API :
cuModuleLoadData/cuModuleGetFunction。 -
含义: 加载二进制内核。
-
用途 : Triton 编译器会将 Python 代码编译成汇编代码(PTX)或二进制机器码(CUBIN)。这个函数负责把这些二进制数据交给 GPU 驱动,并返回一个可以被调用的内核句柄。
3. self.cuOccupancyMaxActiveClusters
-
对应 CUDA API :
cuOccupancyMaxActiveClusters(NVIDIA Hopper 架构及以上)。 -
含义: 计算最大活跃集群占用率。
-
用途 : 专门用于 NVIDIA Hopper (H100) 架构。它能计算在特定集群配置下,GPU 物理上能同时并行跑多少个线程块集群。这比传统的 Occupancy 计算更高级,考虑了多 SM 之间的协同。
4. self.set_printf_fifo_size
-
对应 CUDA API :
cuCtxSetLimit(带CU_LIMIT_PRINTF_FIFO_SIZE参数)。 -
含义 : 设置
printf打印缓冲队列的大小。 -
用途 : 当你在 Triton Kernel 里使用
tl.static_print或调试打印时,GPU 需要一块内存来缓存这些字符串。如果打印内容太多,默认缓存会溢出,这个函数用来调大这个缓冲区。
5. self.fill_tma_descriptor 与 PyCUtensorMap
-
对应 CUDA API: CUDA Tensor Memory Accelerator (TMA) 相关接口。
-
含义: 填充 TMA 描述符。
-
用途 : 这是 NVIDIA Hopper (H100) 最核心的特性之一。
-
TMA 允许硬件自动在全局显存和共享内存之间搬运多维张量数据,而无需 CPU 或传统指令干预。
-
PyCUtensorMap定义了张量在内存中的布局(维度、步长、对齐),fill_tma_descriptor负责把这些配置写进硬件能读懂的"配置卡"中。
-
triton.backends.nvidia.driver.CudaUtils.get_device_properties
看一下这个字典都有哪些字段
| Python 字典键名 | CUDA 属性枚举值 (Attribute) | 物理含义说明 |
|---|---|---|
max_shared_mem |
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN |
最大共享内存 :每个 Block 能够申请的最大 Shared Memory(字节)。注意这里用了 OPTIN,表示包含用户手动开启的高级共享内存配置。 |
max_num_regs |
CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK |
最大寄存器数:每个 Block 可用的 32-bit 寄存器总量。这直接决定了线程束(Warp)的占用率。 |
multiprocessor_count |
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT |
SM 数量:显卡上流多处理器的总数。Triton 用它来决定 Grid 的大小(即一次能铺满多少个 SM)。 |
warpSize |
CU_DEVICE_ATTRIBUTE_WARP_SIZE |
线程束大小:NVIDIA 硬件上目前固定为 32。 |
sm_clock_rate |
CU_DEVICE_ATTRIBUTE_CLOCK_RATE |
SM 主频:GPU 核心的工作频率(单位:kHz)。 |
mem_clock_rate |
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE |
显存频率:显存的工作频率(单位:kHz)。 |
mem_bus_width |
CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH |
显存位宽:显存接口的位宽(单位:bit,如 384-bit)。 |
triton.next_power_of_2
python
@constexpr_function
def next_power_of_2(n: int):
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n
这是一个非常经典的位运算算法,用于计算大于或等于 nnn 的最小 2 的幂(2k2^k2k)。
triton.runtime.JITFunction.warmup
python
def warmup(self, *args, grid, **kwargs):
return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
简单来说,warmup 就是一次 "只编译,不执行" 的 run 调用。
MockTensor.wrap_dtype: 这是一个非常聪明的做法。在 warmup 阶段,Triton 并不需要真实的显存数据,它只需要知道数据的 Dtype(类型) 和 Shape(形状)。因此它用"模拟张量"代替真实张量,避免了不必要的内存开销。
warmup=True: 告诉后面的 run 方法,编译完就停下,别往 GPU 上发指令。
triton.runtime.JITFunction.run
python
def run(self, *args, grid, warmup, **kwargs):
kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug
# parse options
device = driver.active.get_current_device()
stream = driver.active.get_current_stream(device)
# Execute pre run hooks with args and kwargs
for hook in self.pre_run_hooks:
hook(*args, **kwargs)
kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
# specialization is list[tuple[str, Any]], where first element of tuple is
# the type and the second parameter is the 'specialization' value.
bound_args, specialization, options = binder(*args, **kwargs)
key = compute_cache_key(kernel_key_cache, specialization, options)
kernel = kernel_cache.get(key, None)
# Kernel is not cached; we have to compile.
if kernel is None:
options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
options)
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
if kernel is None:
return None
# Check that used global values have not changed.
not_present = object()
for (name, _), (val, globals_dict) in self.used_global_vals.items():
if (newVal := globals_dict.get(name, not_present)) != val:
raise RuntimeError(
f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")
if not warmup:
# canonicalize grid
assert grid is not None
if callable(grid):
grid = grid(bound_args)
grid_size = len(grid)
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
if hasattr(kernel, "result"):
kernel = kernel.result()
# launch kernel
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
return kernel
第一阶段:缓存查找 (The Cache Lookup)
python
kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
bound_args, specialization, options = binder(*args, **kwargs)
key = compute_cache_key(kernel_key_cache, specialization, options)
kernel = kernel_cache.get(key, None)
-
binder: 将你传入的参数映射到 Kernel 定义的签名上。 -
specialization(特化) : 这是高性能的关键。如果你的stride是 1,或者某个常数是 16,Triton 会根据这些具体数值生成"特化版本"的代码。 -
compute_cache_key: 根据这些特化参数生成一个唯一的哈希值(Key)。如果之前编译过完全一样的配置,直接从kernel_cache拿结果。
第二阶段:编译 (The Compilation)
python
if kernel is None:
options, signature, constexprs, attrs = self._pack_args(...)
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
-
如果缓存没中,就会调用
_do_compile。这是最重的一步,会将 Python AST 转换成 TTIR -> LLVM IR -> PTX -> CUBIN(二进制机器码)。 -
编译完成后,生成的
kernel对象里就包含了我们需要的物理信息:寄存器数量 (n_regs) 和 共享内存大小 (shared)。
第三阶段:安全检查 (Safety Check)
python
for (name, _), (val, globals_dict) in self.used_global_vals.items():
if (newVal := globals_dict.get(name, not_present)) != val:
raise RuntimeError(...)
- 这是为了防止你修改了 Kernel 引用到的 Python 全局变量。如果变量变了,之前的编译结果可能失效,Triton 会报错强制你重新检查逻辑。
第四阶段:发射 (The Launch)
python
if not warmup:
# 真正的 Kernel 发射逻辑
kernel.run(grid_0, grid_1, grid_2, stream, ...)
-
如果是真正的调用(
warmup=False),它会解析grid(网格大小),然后调用底层驱动将二进制代码推送到 GPU 流中。 -
对于
warmup调用 ,由于warmup=True,这一段会被跳过。 -
最后返回
CompiledKernel类型的kernel对象。
triton.runtime.jit.CompiledKernel
1. kernel._init_handles() 是在干什么?
这个接口是连接 Python 内存 和 GPU 驱动句柄 的最后一步。
-
背景 :在
_do_compile之后,Triton 已经拿到了二进制代码(CUBIN),但这只是内存中的一段字节流。 -
功能 :
_init_handles()调用了底层驱动(也就是你之前看的CudaUtils.load_binary),将这段二进制代码真正加载到 GPU 上。 -
副作用 :它会初始化
self.function(GPU 上的函数指针)和self.module(CUDA 模块句柄)。只有初始化了句柄,下面的n_regs等物理属性才是准确可读的。
2. n_regs 和 metadata.shared 是哪来的?
这两个属性不是算出来的,而是 "读"出来的。
kernel.n_regs (Registers)
-
来源:由 NVIDIA 的反汇编工具或驱动 API 提供。
-
含义:每个线程(Thread)在执行这个内核时,物理上需要占据多少个 32-bit 寄存器。
-
重要性:GPU 的寄存器堆是所有线程共享的资源。如果一个线程用 64 个寄存器,而 SM 总共只有 64K 个,那么能同时并行的线程数就被锁死了。
kernel.metadata.shared (Shared Memory)
-
来源 :Triton 编译器在分配
tl.static_shared空间时记录的数值。 -
含义:这个内核静态申请的共享内存字节数(Bytes)。
-
重要性 :它是除了寄存器之外,限制 GPU 占用率 (Occupancy) 的第二大瓶颈。
启动
python
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch', 'naive_softmax'], # possible values for `line_arg``
line_names=["Triton", "Torch", "Naive Softmax"], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('red', '-')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
stream = getattr(torch, DEVICE.type).Stream()
getattr(torch, DEVICE.type).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms = triton.testing.do_bench(lambda: softmax(x))
if provider == 'naive_softmax':
ms = triton.testing.do_bench(lambda: naive_softmax(x))
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)
benchmark.run(show_plots=True, print_data=True)