[Triton笔记3]融合 Softmax (Fused Softmax)

先看使用pytorch实现的softmax

python 复制代码
import torch
def navie_softmax(x:torch.Tensor):
    '''
    输入:(M, N)
    逻辑:沿着"列"(dim=1)这个轴去寻找。也就是说,对每一行,扫过它所有的列,吐出一个最大值。
    结果:产生 M 个最大值。
    '''
    x_max=x.max(dim=1)[0]#PyTorch 的 max 会返回两个值:(values, indices)。[0] 取的是数值,[1] 取的是对应的索引(类似 argmax)。

    '''
    这里发生了 Broadcasting(广播)。
    x 是 (M, N),x_max 算出来是 (M,)。
    Torch 会自动把 $_max 复制 N 列,变成 (M, N),然后进行元素级相减。
    '''
    z=x-x_max

    numerator=torch.exp(z)

    denominator=numerator.sum(dim=1)

    #None(或者 np.newaxis)的作用是在那个位置插入一个长度为 1 的新维度。
    #这样 (M, N) / (M, 1) 就能触发广播,每一行都除以该行对应的和。
    ret=numerator/denominator[:,None]

    return ret

直接在 PyTorch 中实现时,对于 x∈M×N,计算 y = naive_softmax(x) 需要从 DRAM 中读取 5MN+2M 个元素,并写回 3MN+2M 个元素。

这显然是浪费的;我们更希望有一个自定义的「融合」内核,它只需读取一次 X,并在芯片上进行所有必要的计算。

softmax kernel

softmax 内核工作原理如下:每个Thread Block加载输入矩阵 X 的一组行,按Block数量跨步处理,对其进行归一化,并将结果写回输出 Y。

其实就是一种非常经典的模式:网格跨步循环(Grid Stride Loop)。这在cuda里也很常用。

注意,Triton 的一个重要限制是每个块必须具有 2 的幂次数的元素,因此,如果我们要处理任意可能的输入形状,我们需要在内部「填充」每一行,并适当保护内存操作。

python 复制代码
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # 程序起始行
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # 步长表示我们需要对指针增加多少以推进 1 行
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # 块大小是大于 n_cols 的下一个二的幂,因此我们可以适配
        # 单个块中的行
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # 将行加载到 SRAM 中,使用掩码,因为 BLOCK_SIZE 可能大于 n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # 为了数值稳定性而减去最大值
        row_minus_max = row - tl.max(row, axis=0)
        # 请注意,Triton 中的指数运算速度很快,但是是近似的(例如,类似于 CUDA 中的 __expf)。
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # 将输出写回 DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

softmax_kernel 输入参数详解

参数名称 类型 含义 Infra/CUDA 视角对照
output_ptr 指针 输出张量在显存(DRAM)中的首地址。 对应 float* out
input_ptr 指针 输入张量在显存(DRAM)中的首地址。 对应 const float* in
input_row_stride 整数 输入矩阵相邻两行之间的元素距离。 对应 2D 数组的 pitchwidth
output_row_stride 整数 输出矩阵相邻两行之间的元素距离。 通常与输入相同,除非输出做了 Padding
n_rows 整数 矩阵的总行数(MMM)。 循环终止的边界条件
n_cols 整数 矩阵的总列数(NNN)。 用于生成 mask 防止越界访问
BLOCK_SIZE tl.constexpr 编译时常量。定义每个 Program 一次加载的向量长度。 必须是 2 的幂,且 ≥n_cols\ge n\_cols≥n_cols
num_stages tl.constexpr 编译时常量。控制流水线循环的深度。 决定了异步搬运和计算的重叠程度
1. 为什么需要 input_row_stride

在 PyTorch 或 C++ 中,矩阵可能不是连续存储的(例如对大矩阵取了子块 slice)。

  • input_row_stride 允许 Kernel 正确跳过内存。

  • 物理地址计算公式:address = base + row * stride + col

2. BLOCK_SIZE 的特殊性

在 Triton 的这个实现里,BLOCK_SIZE 是一个强约束

  • 因为它在 Kernel 里执行了 row - tl.max(row),这就要求 整行数据必须同时存在于同一个 Block 的寄存器中

  • 所以调用这个 Kernel 前,Host 端通常会计算 BLOCK_SIZE = next_power_of_2(n_cols)

3. num_stages 对循环的影响

注意代码中的循环:

Python 复制代码
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):

在 Triton 中,当 tl.range 配合 num_stages > 1 使用时,编译器会尝试进行 Software Pipelining(软件流水线化)

  • Stage 1 : 发起第 i+1i+1i+1 行的异步 Load 请求。

  • Stage 2 : 计算第 iii 行的 Softmax。

    这在 NVIDIA Ampere (A100) 及之后的架构上能极大地掩盖 DRAM 访存延迟。

triton.language.num_programs

简单来说,tl.num_programs(axis) 对应的就是 CUDA 里的 gridDim。它告诉你整个计算网格在某个维度上一共有多少个 Block(在 Triton 里叫 Program)。

python 复制代码
@builtin
def num_programs(axis, _semantic=None):
    """
    Returns the number of program instances launched along the given :code:`axis`.

    :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
    :type axis: int
    """
    axis = _unwrap_if_constexpr(axis)
    return _semantic.num_programs(axis)
属性 内容
函数作用 获取在特定维度上启动的 Program(计算实例/Block)总数
输入参数 (axis) 维度索引。必须是 0, 12(对应 3D 网格的 X, Y, Z 轴)。
返回值 一个整数。代表该轴向上的网格大小(Grid Size)。

如果你有 CUDA 背景,可以这样秒懂:

Triton 代码 CUDA C++ 等价物 物理含义
tl.num_programs(0) gridDim.x X 轴上的总 Block 数
tl.num_programs(1) gridDim.y Y 轴上的总 Block 数
tl.num_programs(2) gridDim.z Z 轴上的总 Block 数

需要注意:num_programs 的值是由 Host 端启动 Kernel 时传入的 grid 参数 决定的。

python 复制代码
grid = (64, 1, 1) # 启动 64 个 Program
softmax_kernel[grid](...)

triton.language.range

python 复制代码
class range(base_value):
    """
    Iterator that counts upward forever.

    .. highlight:: python
    .. code-block:: python

        @triton.jit
        def kernel(...):
            for i in tl.range(10, num_stages=3):
                ...
    :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
        :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler.
    :param arg1: the start value.
    :param arg2: the end value.
    :param step: the step value.
    :param num_stages: pipeline the loop into this many stages (so there are
        :code:`num_stages` iterations of the loop in flight at once).

        Note this is subtly different than passing :code:`num_stages` as a
        kernel argument.  The kernel argument only pipelines loads that feed
        into :code:`dot` operations, while this attribute tries to pipeline most
        (though not all) loads in this loop.
    :param loop_unroll_factor: Tells the Triton IR level loop unroller how many
        times to unroll a for loop that this range is used with. Less than 2 for
        this value implies no unrolling.
    :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot
        operation in the loop to be multi-buffered, if applicable.
    :param flatten: automatically flatten the loop nest starting at this loop to
        create a single flattened loop. The compiler will try to pipeline the
        flattened loop which can avoid stage stalling.
    :param warp_specialize: Enable automatic warp specialization on the loop.
        The compiler will attempt to partition memory, MMA, and vector
        operations in the loop into separate async partitions. This will
        increase the total number of warps required by the kernel.
    :param disable_licm: Tells the compiler it shouldn't hoist loop invariant
        code outside the loop. This is often useful to avoid creating long liveranges
        within a loop.

        Note that warp specialization is only supported on Blackwell GPUs and
        only works on simple matmul loops. Support for arbitrary loops will be
        expanded over time.
    """

    def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
                 disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
        if step is None:
            self.step = constexpr(1)
        else:
            self.step = step
        if arg2 is None:
            self.start = constexpr(0)
            self.end = arg1
        else:
            self.start = arg1
            self.end = arg2
        self.num_stages = num_stages
        self.loop_unroll_factor = loop_unroll_factor
        self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
        self.flatten = flatten
        self.warp_specialize = warp_specialize
        self.disable_licm = disable_licm

    def __iter__(self):
        raise RuntimeError("tl.range can only be used in @triton.jit'd functions")

    def __next__(self):
        raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
1. 基础迭代控制
  • arg1 (start): 迭代的起始值(如果只传一个参数,则为 0)。

  • arg2 (end): 迭代的终止值(不包含)。

  • step (step): 迭代步长,默认为 1。

2. 软件流水线 (Software Pipelining)
  • num_stages:

    • 翻译 :将循环流水线化为指定的级数(即同时有 num_stages 个迭代在并行执行)。

    • Infra 解析:这是最重要的参数。这里的参数会尝试流水线化循环内的大多数加载操作。它利用双缓冲或多缓冲技术,在计算当前迭代时,异步搬运后续迭代的数据。

3. 循环展开 (Loop Unrolling)
  • loop_unroll_factor:

    • 翻译:告诉 Triton IR 级别的循环展开器,这个循环需要展开多少次。如果值小于 2,则表示不展开。

    • Infra 解析 :对应 CUDA 里的 #pragma unroll。通过增加寄存器压力来减少循环开销(计数器加法、条件跳转)。

4. 累加器多缓冲控制
  • disallow_acc_multi_buffer:

    • 翻译 :如果为 True,防止循环中 dot 操作的累加器(accumulator)使用多缓冲。

    • Infra 解析:通常在流水线化时,为了性能会给累加器开辟多个缓冲区。如果你面临极端的寄存器压力,可以设为 True 来节省空间。

5. 循环嵌套打平 (Loop Flattening)
  • flatten:

    • 翻译:自动将从当前循环开始的嵌套循环打平为单个循环。编译器会尝试对打平后的循环进行流水线化,以避免阶段性停顿(Stage Stalling)。

    • Infra 解析:这在处理 2D 数据块迭代时非常有用,能让指令流水线更平滑。

6. Warp 特化 (Warp Specialization)
  • warp_specialize:

    • 翻译:启用自动 Warp 特化。编译器会将循环中的内存操作、MMA(矩阵乘法)和向量操作分配到不同的异步分区中。这会增加 Kernel 所需的 Warp 总数。

    • Infra 解析 :这是 Blackwell (NVIDIA GB200) 等新架构的高级特性。它让一部分 Warp 专门负责搬运数据(Producer),另一部分专门负责计算(Consumer),实现硬件级别的任务解耦。

7. 禁用循环不变式外提 (Disable LICM)
  • disable_licm:

    • 翻译:告诉编译器不要将循环内不变的代码提到循环外。

    • Infra 解析 :LICM (Loop Invariant Code Motion) 是通用编译优化,但在 GPU 上,外提代码可能会显著增加寄存器生命周期(Live Ranges),导致寄存器压力过大(Register Pressure)。禁用它可以让寄存器回收更快。

总结对照表

参数 对应 C++/编译优化概念 目的
num_stages Double/Triple Buffering 掩盖访存延迟 (Hide Latency)
loop_unroll_factor #pragma unroll 减少分支指令开销
flatten Loop Collapse 优化流水线效率
warp_specialize Producer-Consumer Model 硬件资源解耦 (New Architectures)
disable_licm Register Pressure Control 防止寄存器溢出 (Spilling)

triton.language.max

python 复制代码
def max(input, axis=None, return_indices=False, 
return_indices_tie_break_left=True, keep_dims=False):

该函数的主要功能是计算输入张量 input 在指定维度(axis)上的最大值

它不仅支持基础的数值归约,还处理了以下高级特性:

  • 精度提升 (Promotion) :自动将低位宽浮点数(如 bfloat16)转为 float32 以保证计算精度。

  • 索引回传 (Argmax):可选地返回最大值对应的索引位置。

  • 平局处理 (Tie-break):当存在多个相同的最大值时,决定返回哪一个索引(最左侧或快速模式)。

参数解释

参数名 类型 描述
input Tensor 需要进行最大值运算的输入张量。
axis int/tuple/None 执行归约的轴。如果为 None,通常代表对全量元素进行归约。
return_indices bool 关键参数 。若为 True,函数不仅返回最大值,还会返回最大值的索引(类似于 argmax)。
return_indices_tie_break_left bool 仅在 return_indices=True 时有效。若为 True,遇到多个最大值时确保返回**第一个(索引最小)**的索引。
keep_dims bool 是否保留归约后的维度。若为 True,被归约的轴长度变为 1。

triton.language.exp

python 复制代码
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential")
@core._tensor_member_fn
def exp(x, _semantic=None):

这个 exp(指数函数,即 exe^xex)的实现属于 Triton 的底层原子算子封装。

该函数用于计算输入张量 x 中每个元素的指数值。它是逐元素(Element-wise)运算,不涉及跨维度的归约(Reduction)。

参数名 描述
x 输入数据。根据装饰器限制,必须是浮点类型(fp32fp64)。

triton.language.sum

python 复制代码
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("sum", dtype_arg="dtype")
def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):

该函数用于计算输入张量 input 在指定维度(axis)上的累加和。

与 max 类似,它也是一个归约(Reduction)算子,但它特别强调了对输出数据类型(dtype)的控制,以防止在累加过程中出现数值溢出。

参数解释

参数名 类型 描述
input Tensor 需要求和的输入张量。
axis int/tuple/None 执行求和的轴。如果为 None,则对张量内所有元素求和。
keep_dims bool 是否保留被归约的维度(保持维度秩不变)。
dtype core.constexpr 关键参数 。强制指定的输出类型。由于求和极易导致数值超出原类型范围(如 int8 累加很快会溢出),用户可以手动指定更高的精度。

辅助函数

python 复制代码
DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor)  in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y

这段代码展示了 Triton 中一个非常经典的高级技巧:持久化算子(Persistent Kernel) 的启动环境配置。

与普通的算子调用不同,这个辅助函数不是简单地启动与数据规模一样多的线程块,而是通过硬件感知(Hardware-aware)的计算,精确控制并行度以最大化利用 GPU 资源。

1. 硬件资源探测 (Hardware Querying)

在代码开头,它通过 driver 接口直接获取了当前 GPU 的物理参数:

  • NUM_SM: GPU 核心(流式多处理器)的数量。

  • NUM_REGS & SIZE_SMEM: 每个 SM 拥有的寄存器总量和共享内存(L1 缓存)总量。

  • 这些参数决定了 GPU 的"胃口",即它一次性最高能处理多少计算任务。

2. 计算占用率 (Occupancy Calculation)

这是该函数最精华的部分。它通过"预编译(warmup)"来探测内核的资源消耗,并计算占用率(Occupancy)

  • 寄存器约束:计算每个 SM 能容纳多少个线程组(Warps)。

  • 共享内存约束:计算 SM 的内存空间够分给几个 Block。

  • 最终结果occupancy 取两者的最小值。这代表了在硬件不宕机、不排队的前提下,一个物理 SM 同时能跑多少个程序实例

3. 设置持久化网格 (Persistent Grid)

通常的写法是 grid = (n_rows, ),即一行对应一个 Block。但如果行数非常多(比如 100 万行),GPU 会频繁地调度、创建和销毁 Block,产生巨大的开销。

该函数通过 num_programs = NUM_SM * occupancy 设置了网格大小:

  • 它只启动刚好能填满整个 GPU 物理槽位的 Block 数量。

  • 这些 Block 启动后不会立即退出,而是通过内部循环"吞掉"所有的行(Persistent 模式)。

  • 好处:极大地减少了 GPU 调度器的压力,并提高了 L2 缓存的命中率。

4. 编译优化参数

  • BLOCK_SIZE : 自动寻找大于列数的最小 2n2^n2n,确保内存对齐且符合 Triton 的编译要求。

  • num_stages: 根据共享内存大小决定软件流水线(Software Pipelining)的深度。如果内存够大,就用 4 阶流水线来隐藏指令延迟。

  • num_warps: 固定使用 8 个 Warp(256 个线程)来处理每一行,保证了足够的线程并行度。

triton.runtime.DriverConfig

python 复制代码
def _create_driver() -> DriverBase:
    active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
    if len(active_drivers) != 1:
        raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.")
    return active_drivers[0]()


class DriverConfig:

    def __init__(self) -> None:
        self._default: DriverBase | None = None
        self._active: DriverBase | None = None

    @property
    def default(self) -> DriverBase:
        if self._default is None:
            self._default = _create_driver()
        return self._default

    @property
    def active(self) -> DriverBase:
        if self._active is None:
            self._active = self.default
        return self._active

    def set_active(self, driver: DriverBase) -> None:
        self._active = driver

    def reset_active(self) -> None:
        self._active = self.default


driver = DriverConfig()

triton.runtime.DriverConfig 类在 Triton 架构中扮演的是 "后端驱动管理器" 的角色。它的设计非常简洁,核心目的是为了实现 惰性初始化(Lazy Initialization) 和 多后端切换。

由于 Triton 需要支持不同的硬件(如 NVIDIA GPU、AMD GPU 等),它不能在启动时就写死驱动。DriverConfig 就像一个转换插头,确保系统总能找到当前环境下最合适的硬件接口。

1. 核心成员变量

  • _default : 存储系统检测到的默认驱动。例如在 NVIDIA 环境下,它通常指向 CUDADriver

  • _active : 当前正在使用的驱动。在大多数情况下,它和 _default 是一样的,但在复杂的调试或多硬件共存环境下,你可以手动切换。

2. 核心逻辑:惰性初始化

注意到 defaultactive 都是 @property(属性装饰器):

  • 当你运行 driver.active 时,它会检查 _active 是否为空。

  • 如果为空,它会去调用 _create_driver()。这个函数会扫描系统环境(比如检查是否有 CUDA 库、是否有 ROCm 环境),然后动态创建一个驱动实例。

  • 好处:如果你的代码只是定义了函数而没有实际在 GPU 上运行,Triton 就不会去加载沉重的驱动库,节省开销。

triton.backends.driver.DriverBase和GPUDriver

DriverBase 定义了规则(接口),而 GPUDriver 则是针对具体硬件(目前主要是通过 PyTorch 桥接的 GPU)的具体实现。

python 复制代码
class DriverBase(metaclass=ABCMeta):

    @classmethod
    @abstractmethod
    def is_active(self):
        pass

    @abstractmethod
    def map_python_to_cpp_type(self, ty: str) -> str:
        """
        Converts a Triton type string to its corresponding C++ type string for this backend.

        Args:
            ty (str): The Triton type string. e.g., 'i32', '*fp16', 'fp32'.

        Returns:
            str: The C++ type string.
        """
        pass

    @abstractmethod
    def get_current_target(self):
        pass

    @abstractmethod
    def get_active_torch_device(self):
        pass

    @abstractmethod
    def get_benchmarker(self) -> Benchmarker:
        """
        Return the benchmarking function that this backend should use by default.
        """
        raise NotImplementedError

    def __init__(self) -> None:
        pass


class GPUDriver(DriverBase):

    def __init__(self):
        # TODO: support other frameworks than torch
        import torch
        self.get_device_capability = torch.cuda.get_device_capability
        try:
            from torch._C import _cuda_getCurrentRawStream
            self.get_current_stream = _cuda_getCurrentRawStream
        except ImportError:
            self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream
        self.get_current_device = torch.cuda.current_device
        self.set_current_device = torch.cuda.set_device

    # TODO: remove once TMA is cleaned up
    def assemble_tensormap_to_arg(self, tensormaps_info, args):
        return args

DriverBase 是一个抽象基类 (Abstract Base Class)。它不负责具体干活,而是规定了任何一个 Triton 驱动后端(无论是 NVIDIA、AMD 还是未来的 CPU 后端)必须提供的功能:

  • 类型映射 (map_python_to_cpp_type) : 这是编译器的关键。它把 Triton 的类型(如 i32)翻译成底层编译器(如 LLVM 或 NVCC)能懂的 C++ 类型。

  • 目标检测 (get_current_target) : 确定当前的硬件架构(例如:它是 sm_80 还是 sm_90?)。这决定了编译器能否使用 Tensor Core 或 TMA 等高级特性。

  • 性能评估 (get_benchmarker) : 提供一个标准的"跑分"工具,用于 autotune(自动调优)过程中对比不同配置的快慢。

GPUDriverDriverBase 的子类。有趣的是,它的实现高度依赖于 PyTorch

核心任务:绑定 PyTorch 的底层接口

__init__ 中,你可以看到它在疯狂"偷" PyTorch 的函数:

  1. get_device_capability : 调用 torch.cuda.get_device_capability 来获取显卡的计算能力(Compute Capability)。

  2. get_current_stream : 这是最底层、最关键的部分。它试图获取 CUDA 的原始流句柄(Raw Stream Handle)

    • Triton 需要直接操作 CUDA Stream 来提交任务。

    • 代码中做了兼容性处理:如果能找到 PyTorch 内部的 _cuda_getCurrentRawStream(高性能 C++ 接口)就直接用;否则退而求其次使用常规的 Python 接口。

  3. 设备管理: 绑定了获取和设置当前显卡编号(Device ID)的方法。

triton.backends.nvidia.driver.CudaDriver

python 复制代码
class CudaDriver(GPUDriver):

    def __init__(self):
        self.utils = CudaUtils()  # TODO: make static
        self.launcher_cls = CudaLauncher
        super().__init__()

    def get_current_target(self):
        device = self.get_current_device()
        capability = self.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
        warp_size = 32
        return GPUTarget("cuda", capability, warp_size)

    def get_active_torch_device(self):
        import torch
        return torch.device("cuda", self.get_current_device())

    def get_device_interface(self):
        import torch
        return torch.cuda

    @staticmethod
    def is_active():
        try:
            import torch
            return torch.cuda.is_available() and (torch.version.hip is None)
        except ImportError:
            return False

    def map_python_to_cpp_type(self, ty: str) -> str:
        return ty_to_cpp(ty)

    def get_benchmarker(self):
        from triton.testing import do_bench
        return do_bench

    def get_empty_cache_for_benchmark(self):
        import torch

        # We maintain a buffer of 256 MB that we clear
        # before each kernel call to make sure that the L2 cache
        # doesn't contain any input data before the run
        cache_size = 256 * 1024 * 1024
        return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')

    def clear_cache(self, cache):
        cache.zero_()

算力计算 (get_current_target)

Python 复制代码
capability = capability[0] * 10 + capability[1]

这是一个非常经典的设计。NVIDIA 的算力通常表示为 (major, minor),例如 A100 是 (8, 0),RTX 4090 是 (8, 9)

  • Triton 将其转换为整数(80, 89 等),因为底层编译器(PTXAS)需要根据这个数字来决定使用哪些指令集。

  • Warp Size : 这里硬编码为 32,这是所有 NVIDIA GPU 目前的物理常数。

utils成员

它是CudaUtils类型。

triton.backends.nvidia.driver.CudaUtils

python 复制代码
class CudaUtils(object):

    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        mod = compile_module_from_src(
            src=Path(os.path.join(dirname, "driver.c")).read_text(),
            name="cuda_utils",
            library_dirs=library_dirs(),
            include_dirs=include_dirs,
            libraries=libraries,
        )
        global PyCUtensorMap
        PyCUtensorMap = mod.PyCUtensorMap
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
        self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
        self.set_printf_fifo_size = mod.set_printf_fifo_size
        self.fill_tma_descriptor = mod.fill_tma_descriptor

CudaUtils 是 Triton 驱动层中最接近 NVIDIA Driver API (CUDA Driver API) 的部分。

这段代码使用了典型的单例模式(Singleton) ,通过 compile_module_from_src 实时编译了一个名为 driver.c 的 C 语言模块。它返回的成员变量实际上是直接映射到 CUDA 驱动层(如 libcuda.sonvcuda.dll)的 C 函数。

以下是这些成员变量的详细含义:

1. self.get_device_properties

  • 对应 CUDA API : cuDeviceGetAttributecuDeviceGetProperties

  • 含义: 获取显卡的物理参数。

  • 用途 : 返回一个字典,包含你在 Softmax 辅助函数中看到的 multiprocessor_count (SM 数量)、max_shared_mem (共享内存大小) 等。这是进行硬件感知优化(如计算 Occupancy)的基础。

2. self.load_binary

  • 对应 CUDA API : cuModuleLoadData / cuModuleGetFunction

  • 含义: 加载二进制内核。

  • 用途 : Triton 编译器会将 Python 代码编译成汇编代码(PTX)或二进制机器码(CUBIN)。这个函数负责把这些二进制数据交给 GPU 驱动,并返回一个可以被调用的内核句柄

3. self.cuOccupancyMaxActiveClusters

  • 对应 CUDA API : cuOccupancyMaxActiveClusters (NVIDIA Hopper 架构及以上)。

  • 含义: 计算最大活跃集群占用率。

  • 用途 : 专门用于 NVIDIA Hopper (H100) 架构。它能计算在特定集群配置下,GPU 物理上能同时并行跑多少个线程块集群。这比传统的 Occupancy 计算更高级,考虑了多 SM 之间的协同。

4. self.set_printf_fifo_size

  • 对应 CUDA API : cuCtxSetLimit (带 CU_LIMIT_PRINTF_FIFO_SIZE 参数)。

  • 含义 : 设置 printf 打印缓冲队列的大小。

  • 用途 : 当你在 Triton Kernel 里使用 tl.static_print 或调试打印时,GPU 需要一块内存来缓存这些字符串。如果打印内容太多,默认缓存会溢出,这个函数用来调大这个缓冲区。

5. self.fill_tma_descriptorPyCUtensorMap

  • 对应 CUDA API: CUDA Tensor Memory Accelerator (TMA) 相关接口。

  • 含义: 填充 TMA 描述符。

  • 用途 : 这是 NVIDIA Hopper (H100) 最核心的特性之一。

    • TMA 允许硬件自动在全局显存和共享内存之间搬运多维张量数据,而无需 CPU 或传统指令干预。

    • PyCUtensorMap 定义了张量在内存中的布局(维度、步长、对齐),fill_tma_descriptor 负责把这些配置写进硬件能读懂的"配置卡"中。

triton.backends.nvidia.driver.CudaUtils.get_device_properties

看一下这个字典都有哪些字段

Python 字典键名 CUDA 属性枚举值 (Attribute) 物理含义说明
max_shared_mem CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN 最大共享内存 :每个 Block 能够申请的最大 Shared Memory(字节)。注意这里用了 OPTIN,表示包含用户手动开启的高级共享内存配置。
max_num_regs CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK 最大寄存器数:每个 Block 可用的 32-bit 寄存器总量。这直接决定了线程束(Warp)的占用率。
multiprocessor_count CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT SM 数量:显卡上流多处理器的总数。Triton 用它来决定 Grid 的大小(即一次能铺满多少个 SM)。
warpSize CU_DEVICE_ATTRIBUTE_WARP_SIZE 线程束大小:NVIDIA 硬件上目前固定为 32。
sm_clock_rate CU_DEVICE_ATTRIBUTE_CLOCK_RATE SM 主频:GPU 核心的工作频率(单位:kHz)。
mem_clock_rate CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE 显存频率:显存的工作频率(单位:kHz)。
mem_bus_width CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH 显存位宽:显存接口的位宽(单位:bit,如 384-bit)。

triton.next_power_of_2

python 复制代码
@constexpr_function
def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n

这是一个非常经典的位运算算法,用于计算大于或等于 nnn 的最小 2 的幂(2k2^k2k)。

triton.runtime.JITFunction.warmup

python 复制代码
def warmup(self, *args, grid, **kwargs):
        return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)

简单来说,warmup 就是一次 "只编译,不执行" 的 run 调用。

MockTensor.wrap_dtype: 这是一个非常聪明的做法。在 warmup 阶段,Triton 并不需要真实的显存数据,它只需要知道数据的 Dtype(类型) 和 Shape(形状)。因此它用"模拟张量"代替真实张量,避免了不必要的内存开销。

warmup=True: 告诉后面的 run 方法,编译完就停下,别往 GPU 上发指令。

triton.runtime.JITFunction.run

python 复制代码
def run(self, *args, grid, warmup, **kwargs):
        kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug

        # parse options
        device = driver.active.get_current_device()
        stream = driver.active.get_current_stream(device)

        # Execute pre run hooks with args and kwargs
        for hook in self.pre_run_hooks:
            hook(*args, **kwargs)

        kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
        # specialization is list[tuple[str, Any]], where first element of tuple is
        # the type and the second parameter is the 'specialization' value.
        bound_args, specialization, options = binder(*args, **kwargs)

        key = compute_cache_key(kernel_key_cache, specialization, options)
        kernel = kernel_cache.get(key, None)

        # Kernel is not cached; we have to compile.
        if kernel is None:
            options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
                                                                    options)

            kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
            if kernel is None:
                return None

        # Check that used global values have not changed.
        not_present = object()
        for (name, _), (val, globals_dict) in self.used_global_vals.items():
            if (newVal := globals_dict.get(name, not_present)) != val:
                raise RuntimeError(
                    f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")

        if not warmup:
            # canonicalize grid
            assert grid is not None
            if callable(grid):
                grid = grid(bound_args)
            grid_size = len(grid)
            grid_0 = grid[0]
            grid_1 = grid[1] if grid_size > 1 else 1
            grid_2 = grid[2] if grid_size > 2 else 1
            if hasattr(kernel, "result"):
                kernel = kernel.result()
            # launch kernel
            launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
            kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                       knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
        return kernel
第一阶段:缓存查找 (The Cache Lookup)
python 复制代码
kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
bound_args, specialization, options = binder(*args, **kwargs)
key = compute_cache_key(kernel_key_cache, specialization, options)
kernel = kernel_cache.get(key, None)
  • binder: 将你传入的参数映射到 Kernel 定义的签名上。

  • specialization (特化) : 这是高性能的关键。如果你的 stride 是 1,或者某个常数是 16,Triton 会根据这些具体数值生成"特化版本"的代码。

  • compute_cache_key : 根据这些特化参数生成一个唯一的哈希值(Key)。如果之前编译过完全一样的配置,直接从 kernel_cache 拿结果。

第二阶段:编译 (The Compilation)
python 复制代码
if kernel is None:
    options, signature, constexprs, attrs = self._pack_args(...)
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
  • 如果缓存没中,就会调用 _do_compile。这是最重的一步,会将 Python AST 转换成 TTIR -> LLVM IR -> PTX -> CUBIN(二进制机器码)。

  • 编译完成后,生成的 kernel 对象里就包含了我们需要的物理信息:寄存器数量 (n_regs)共享内存大小 (shared)

第三阶段:安全检查 (Safety Check)
python 复制代码
for (name, _), (val, globals_dict) in self.used_global_vals.items():
    if (newVal := globals_dict.get(name, not_present)) != val:
        raise RuntimeError(...)
  • 这是为了防止你修改了 Kernel 引用到的 Python 全局变量。如果变量变了,之前的编译结果可能失效,Triton 会报错强制你重新检查逻辑。
第四阶段:发射 (The Launch)
python 复制代码
if not warmup:
    # 真正的 Kernel 发射逻辑
    kernel.run(grid_0, grid_1, grid_2, stream, ...)
  • 如果是真正的调用(warmup=False),它会解析 grid(网格大小),然后调用底层驱动将二进制代码推送到 GPU 流中。

  • 对于 warmup 调用 ,由于 warmup=True,这一段会被跳过。

  • 最后返回CompiledKernel类型的kernel对象。

triton.runtime.jit.CompiledKernel

1. kernel._init_handles() 是在干什么?

这个接口是连接 Python 内存GPU 驱动句柄 的最后一步。

  • 背景 :在 _do_compile 之后,Triton 已经拿到了二进制代码(CUBIN),但这只是内存中的一段字节流。

  • 功能_init_handles() 调用了底层驱动(也就是你之前看的 CudaUtils.load_binary),将这段二进制代码真正加载到 GPU 上。

  • 副作用 :它会初始化 self.function(GPU 上的函数指针)和 self.module(CUDA 模块句柄)。只有初始化了句柄,下面的 n_regs 等物理属性才是准确可读的。


2. n_regsmetadata.shared 是哪来的?

这两个属性不是算出来的,而是 "读"出来的

kernel.n_regs (Registers)
  • 来源:由 NVIDIA 的反汇编工具或驱动 API 提供。

  • 含义:每个线程(Thread)在执行这个内核时,物理上需要占据多少个 32-bit 寄存器。

  • 重要性:GPU 的寄存器堆是所有线程共享的资源。如果一个线程用 64 个寄存器,而 SM 总共只有 64K 个,那么能同时并行的线程数就被锁死了。

kernel.metadata.shared (Shared Memory)
  • 来源 :Triton 编译器在分配 tl.static_shared 空间时记录的数值。

  • 含义:这个内核静态申请的共享内存字节数(Bytes)。

  • 重要性 :它是除了寄存器之外,限制 GPU 占用率 (Occupancy) 的第二大瓶颈。

启动

python 复制代码
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg``
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
相关推荐
岑梓铭2 小时前
考研408《操作系统》复习笔记,第二章《2.3.3 + 2.3.4 经典同步问题、管程》
笔记·考研·操作系统·408·os
前端小马2 小时前
PTE考试笔记
笔记·英语
m0_46644103詹湛3 小时前
FPGA时序优化与高速接口实战手册
笔记·学习·fpga开发·硬件架构·verilog
问心无愧05133 小时前
ctf show web 入门39
android·前端·笔记
Yeh2020583 小时前
Mybatis笔记一
java·笔记·mybatis
羊群智妍3 小时前
2026 AI搜索优化技术:GEO监测工具选型与应用
笔记
半导体守望者4 小时前
MKS elite 300 600 750W RF Plasma Generator 射频电源 OPERATIONMANUAL
经验分享·笔记·机器人·自动化·制造
05候补工程师4 小时前
【线性代数笔记】初等变换、正交化与特殊矩阵性质核心总结
经验分享·笔记·线性代数·考研·矩阵
Heartache boy5 小时前
野火STM32_HAL库版课程笔记-I2C介绍
笔记·stm32·单片机