[Triton笔记7]融合注意力 (Fused Attention)

接口铺垫

triton.language.multiple_of

python 复制代码
@builtin
def multiple_of(input, values, _semantic=None):
    """
    Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
    """
    if isinstance(values, constexpr):
        values = [values]
    for i, d in enumerate(values):
        if not isinstance(d, constexpr):
            raise TypeError(f"values element {i} must have type `constexpr`")
        if not isinstance(d.value, int):
            raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
    values = [x.value for x in values]
    return _semantic.multiple_of(input, values)
条目 名称 作用 / 含义 类型约束
输入参数 input 需要标记对齐属性的张量或指针。通常是计算出的内存地址偏移量。 TensorPointer
输入参数 values 声明 input 能够被哪些数值整除。告诉编译器:"这里的数都是 X 的倍数"。 tl.constexpr 或其列表
输入参数 _semantic Triton 内部语义解析对象(开发者通常无需手动干预)。 内部对象
返回值 output 返回与 input 数值完全相同的张量。区别在于该返回对象在编译器内部携带了对齐元数据。 input 类型相同

multiple_of 本质上是一个编译器提示(Compiler Hint) ,它在逻辑层面上是一个"恒等函数"(即 f ( x ) = x f(x) = x f(x)=x),但在物理建模层面它改变了编译器对数据的认知。

该函数并不会在运行时对数据执行取模运算或检查,它只是在编译阶段给 input 贴上了一个"标签"。编译器看到这个标签后,会意识到:"既然这些索引都是 16 的倍数,那我可以用最快的 128-bit 矢量指令来搬运数据,而不用担心跨行或非对齐访问的问题。"

在 GPU 体系结构中,内存合并(Memory Coalescing)是性能的生命线。

  • 没有提示:编译器可能采取保守策略,生成多个标量加载指令,导致带宽利用率低下。

  • 使用提示:编译器敢于生成宽矢量指令(如一次性读取 16 字节),大幅减少指令数量并提升吞吐量。

python 复制代码
# 假设我们处理的是 float32 数据(4字节)
# 如果 BLOCK_SIZE 是 32,那么 offset 应该是 32 * 4 = 128 字节的倍数
# 这符合大多数 GPU 的内存对齐要求
offset = tl.program_id(0) * BLOCK_SIZE
aligned_offset = tl.multiple_of(offset, BLOCK_SIZE)

# 使用带标签的 aligned_offset 进行访存,编译器会生成优化后的加载指令
ptr = base_ptr + aligned_offset
data = tl.load(ptr)

风险提示 :如果你对编译器"撒谎"(例如 input 实际上不是 values 的倍数),编译器依然会按照对齐的方式生成指令,这通常会导致非法内存访问(Illegal Memory Access) 或得到错误的数据结果

triton.language.advance

triton.language.advance 是专门用于 Block Pointer(块指针) 操作的函数。它允许你在多维空间中移动一个预先定义好的块指针。

python 复制代码
@must_use_result(
    "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
)
@_tensor_member_fn
@builtin
def advance(base, offsets, _semantic=None):
    """
    Advance a block pointer

    :param base: the block pointer to advance
    :param offsets: the offsets to advance, a tuple by dimension
    """
    return _semantic.advance(base, offsets)
条目 名称 作用 / 含义 类型约束
输入参数 base 现有的块指针(Block Pointer)。它是移动操作的起点。 block_ptr
输入参数 offsets 在每个维度上需要移动的偏移量。例如 (16, 32) 表示在第 0 维移动 16 个元素,第 1 维移动 32 个元素。 tuple of inttl.constexpr
输入参数 _semantic Triton 内部语义解析对象。 内部对象
返回值 new_ptr 返回一个新的块指针 ,指向移动后的目标位置。原指针 base 不会发生改变。 block_ptr

函数功能详解

advance 的核心功能是平移(Offset)一个块指针,类似于在多维数组上滑动一个窗口。

1. 纯函数属性(无副作用)

正如源码中 @must_use_result 警告所言,tl.advance 不会修改原指针 。在底层,块指针是一个包含基地址、形状、步长和当前位置的结构体。执行 advance 只是计算出一个位置更新后的新结构体。

错误写法tl.advance(ptr, (16, 16)) (指针没动)

正确写法ptr = tl.advance(ptr, (16, 16))

2. 与传统指针算术的区别

在 Triton 中,如果你使用常规指针(Scalar Pointer),你会手动加减偏移量;但当你使用 tl.make_block_ptr 创建了高级块指针后,必须使用 tl.advance 来切换到下一个数据块。

3. 典型应用场景:循环迭代

在矩阵乘法(GEMM)等算法中,我们需要沿着某个维度(如 K K K 维度)不断移动读取数据块。advance 是实现这种"滑动窗口"读取的标准方式。

4. 代码示例
python 复制代码
# 创建一个指向矩阵 A 起始位置的块指针
a_ptr = tl.make_block_ptr(
    base=A,
    shape=(M, K),
    strides=(stride_am, stride_ak),
    offsets=(0, 0),
    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
    order=(1, 0)
)

# 在循环中沿着 K 维度移动指针
for k in range(0, K, BLOCK_SIZE_K):
    # 加载当前块的数据
    a_block = tl.load(a_ptr)
    
    # 将指针向 K 维度(第 1 维)移动 BLOCK_SIZE_K 个单位
    # 注意:必须重新赋值给 a_ptr
    a_ptr = tl.advance(a_ptr, (0, BLOCK_SIZE_K))

总结

  • 功能:移动块指针的位置。

  • 关键点:必须接收返回值,否则移动无效。

  • 优势:配合块指针使用,能够自动处理复杂的边界逻辑和步长计算,让代码更简洁且易于优化。

triton.language.make_block_ptr

python 复制代码
@builtin

def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None):

    """
    Returns a pointer to a block in a parent tensor

    :param base: The base pointer to the parent tensor
    :param shape: The shape of the parent tensor
    :param strides: The strides of the parent tensor
    :param offsets: The offsets to the block
    :param block_shape: The shape of the block
    :param order: The order of the original data format
    """

    warn("tl.make_block_ptr is deprecated. Use TensorDescriptor or tl.make_tensor_descriptor instead.")

    return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order) 

triton.language.make_block_ptr 是 Triton 中分块指针(Block Pointer)机制的核心。它将原本扁平的内存地址空间抽象为多维的"张量视图",使得在 Kernel 中处理多维数据块变得极其直观。

虽然源码中提示该函数已被标记为 deprecated(推荐使用 TensorDescriptor),但在目前的 Triton 生态中,它依然是理解和编写高性能 Kernel 的基石。

条目 名称 作用 / 含义 类型约束
输入参数 base 父张量的起始地址。通常是传入 Kernel 的全局内存指针。 Pointer
输入参数 shape 完整父张量的逻辑形状(如整个矩阵的行数和列数)。 tuple of int
输入参数 strides 指向张量各维度在内存中跨越的步长(通常用于处理非连续内存)。 tuple of int
输入参数 offsets 当前块在父张量中的起始坐标偏移量。 tuple of int
输入参数 block_shape 该指针所指向的数据块的大小(即每次 tl.load 读入的大小)。 tuple of int
输入参数 order 数据的存储布局(维度优先级)。例如 (1, 0) 表示列优先。 tuple of int
输入参数 _semantic Triton 内部语义解析对象。 内部对象
返回值 block_ptr 返回一个封装好的块指针对象。它包含了访问该多维块所需的所有元数据。 block_ptr
make_block_ptr 的核心功能是创建一个高级指针,定义如何从物理内存中"切"出一个逻辑上的多维块
1. 自动边界处理

传统的指针计算需要手动判断 if (offset < max_size)。而使用 make_block_ptr 后,搭配 tl.load(ptr, boundary_check=(...)),Triton 编译器会自动生成高效的边界检查逻辑,防止非法越界访问。

2. 硬件加速优化(TMA)

在 NVIDIA Hopper 架构(如 H100)上,块指针能直接映射到硬件级的 TMA (Tensor Memory Accelerator) 单元。这允许硬件自动异步地完成多维数据块的搬运,极大地降低了软件计算开销。

3. 布局(Order)的意义

order 参数决定了数据在内存中移动的"最快"维度。

  • (1, 0):行优先(Row-major),相邻地址在第 1 维移动。

  • (0, 1):列优先(Column-major),相邻地址在第 0 维移动。 正确配置 order 对实现合并访存(Coalesced Access)至关重要。

4. 典型用法示例
python 复制代码
# 定义一个 2D 矩阵的块指针
a_block_ptr = tl.make_block_ptr(
    base=a_ptr,                      # 矩阵首地址
    shape=(M, K),                    # 矩阵总大小
    strides=(stride_am, stride_ak),  # 步长
    offsets=(0, 0),                  # 从 (0,0) 开始读取
    block_shape=(BLOCK_M, BLOCK_K),  # 每次读取一个 BLOCK_M * BLOCK_K 的小块
    order=(1, 0)                     # 内存中是行优先存储
)

# 直接加载整个块,无需手动计算 offset 向量
a_tile = tl.load(a_block_ptr, boundary_check=(0, 1))

tl.load加载

Triton笔记1核心概念-CSDN博客中,我们介绍load的时候没有介绍Block Pointer的情况,现在来看一下

python 复制代码
"""

        (3) If `pointer` is a block pointer defined by `make_block_ptr`, a

            tensor is loaded.  In this case:

            - `mask` and `other` must be `None`, and

            - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.


    :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check

    :type boundary_check: tuple of ints, optional

    :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.

    """
条目 名称 作用 / 含义 类型/取值约束
输入参数 pointer 加载数据的源地址。在此场景下必须是 block_ptr block_ptr
输入参数 mask 逐元素的加载开关。 必须为 None (块指针自带范围信息)
输入参数 other 掩码为假时的填充值。 必须为 None (由 padding_option 代替)
输入参数 boundary_check 指明哪些维度需要进行越界检查。 tuple of int (如 (0, 1))
输入参数 padding_option 越界时的填充策略。 "", "zero", "nan"
在这种模式下,你不再需要手动创建一个布尔矩阵(Mask)来告诉编译器哪些位置合法,而是直接告诉编译器:"请检查第 N 维的边界"。

boundary_check 接收一个元组,里面的数字代表 make_block_ptr 中定义的维度索引。

  • boundary_check=(0,) :只检查行是否越界。如果块的某一部分超出了矩阵的总行数 M M M,则触发填充。

  • boundary_check=(0, 1):同时检查行和列。这在处理非对齐的矩阵末尾块时非常有用。

boundary_check 发现越界时,padding_option 参数决定补什么值:

  • "zero":越界部分填 0(最常用,尤其在矩阵乘法中,填 0 不影响加法累加)。

  • "nan":越界部分填 NaN。

  • "" (默认):不确定的值(通常是内存中原有的垃圾数据),速度最快但安全性最低。

triton.language.tensor_descriptor

triton.language.make_tensor_descriptor 是 Triton 为了更好地支持 NVIDIA Hopper 架构(及后续架构) 中的 TMA (Tensor Memory Accelerator) 硬件特性而引入的高级接口。

它是 make_block_ptr 的继任者和增强版。虽然它们都用于描述多维数据块,但 tensor_descriptor 在设计上更加贴合底层硬件的搬运协议。

python 复制代码
@builtin
def make_tensor_descriptor(
    base: tensor,
    shape: List[tensor],
    strides: List[tensor],
    block_shape: List[constexpr],
    padding_option="zero",
    _semantic=None,
) -> tensor_descriptor:
    """Make a tensor descriptor object

    :param base: the base pointer of the tensor, must be 16-byte aligned
    :param shape: A list of non-negative integers representing the tensor shape
    :param strides: A list of tensor strides. Leading dimensions must be multiples
        of 16-byte strides and the last dimension must be contiguous.
    :param block_shape: The shape of block to be loaded/stored from global memory

    Notes
    *****
    On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object
    and loads and stores from the descriptor will be backed by the TMA hardware.

    Currently only 2-5 dimensional tensors are supported.

    Example
    *******
    .. code-block:: python

        @triton.jit
        def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
            desc = tl.make_tensor_descriptor(
                in_out_ptr,
                shape=[M, N],
                strides=[N, 1],
                block_shape=[M_BLOCK, N_BLOCK],
            )

            moffset = tl.program_id(0) * M_BLOCK
            noffset = tl.program_id(1) * N_BLOCK

            value = desc.load([moffset, noffset])
            desc.store([moffset, noffset], tl.abs(value))

        # TMA descriptors require a global memory allocation
        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
            return torch.empty(size, device="cuda", dtype=torch.int8)

        triton.set_allocator(alloc_fn)

        M, N = 256, 256
        x = torch.randn(M, N, device="cuda")
        M_BLOCK, N_BLOCK = 32, 32
        grid = (M / M_BLOCK, N / N_BLOCK)
        inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)

    """

    padding_option = _unwrap_if_constexpr(padding_option)
    return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option)
条目 名称 作用 / 含义 约束条件
输入参数 base 张量的基地址指针。 必须 16 字节对齐(由 TMA 硬件要求)。
输入参数 shape 完整父张量的形状列表(如 [M, N])。 支持 2 到 5 维。
输入参数 strides 张量的步长列表。 前导维度须 16 字节对齐,最后维度通常连续(stride 为 1)。
输入参数 block_shape 每次加载/存储的数据块形状。 必须是 tl.constexpr 类型。
输入参数 padding_option 越界时的填充策略。默认 "zero" 目前主要支持 "zero"
返回值 desc 返回一个张量描述符对象 tensor_descriptor 类型。

函数功能详解

make_tensor_descriptor 的核心目的是创建一个可以被硬件直接理解的"访存模板"

1. 硬件级抽象:TMA 驱动

在之前的 make_block_ptr 模式下,地址计算往往还是由 GPU 的计算单元(CUDA Cores)完成的。而 make_tensor_descriptor 生成的对象会尽可能映射到 NVIDIA 的 TMA 硬件

  • 异步性:TMA 可以在不占用计算线程的情况下,在后台自动完成全局内存(Global Memory)到共享内存(Shared Memory)的数据搬运。

  • 零开销寻址:多维索引到线性地址的转换由硬件电路完成,不消耗寄存器和指令。

2. "解耦"的思想

make_block_ptr 不同,描述符(Descriptor)将数据的几何结构(Shape/Stride)具体的访问坐标(Offsets)彻底解耦了:

  • make_tensor_descriptor 定义了"地块有多大,怎么排布"。

  • 在实际使用时,调用 desc.load([offset_m, offset_n]) 才指定"我现在要搬哪一块"。

3. 严格的对齐要求

由于 TMA 硬件是为极致带宽设计的,它对内存地址非常挑剔:

  • 对齐base 指针如果不满足 16 字节对齐,程序会直接报错或崩溃。

  • 步长:行与行之间的跨度(Strides)也需要满足对齐要求,以确保硬件搬运效率。

使用方法对比:Descriptor vs Block Pointer

当你拿到一个描述符 desc 后,它的用法比以前更简洁:

python 复制代码
# 1. 创建描述符 (类似定义一个模板)
desc = tl.make_tensor_descriptor(
    base=ptr, 
    shape=[M, N], 
    strides=[N, 1], 
    block_shape=[BLOCK_M, BLOCK_N]
)

# 2. 计算当前 Block 的起点坐标
moffset = tl.program_id(0) * BLOCK_M
noffset = tl.program_id(1) * BLOCK_N

# 3. 直接加载,传入坐标即可
# 不需要像以前那样用 tl.advance 手动挪指针
tile = desc.load([moffset, noffset])

# 4. 同样支持直接存储
desc.store([moffset, noffset], tl.abs(tile))

总结与建议

  • 如果你在 H100 (SM90) 或更新的显卡上开发 :强烈建议优先使用 make_tensor_descriptor。它能解锁 TMA 的异步搬运能力,显著降低计算单元的访存压力。

  • 如果你需要多维(3D-5D)支持:该函数提供了比以往更规范的 N-Dim 支持。

  • 注意点 :使用此函数通常需要配合自定义的内存分配器(如示例中的 alloc_fn),以确保分配出的地址满足硬件对齐要求。

triton.language.maximum

triton.language.maximum 是一个逐元素(Element-wise)比较函数,用于从两个输入张量中挑选出较大的值。在深度学习中,它常用于实现 ReLU 激活函数或处理数值边界。

条目 名称 作用 / 含义 类型/取值约束
输入参数 x 第一个输入张量。 Tensor (Block)
输入参数 y 第二个输入张量。 Tensor (Block),形状须与 x 匹配或可广播
输入参数 propagate_nan 控制如何处理 NaN(非数字)值。 tl.PropagateNan 枚举值
输入参数 _semantic Triton 内部语义解析对象。 内部对象
返回值 output 返回逐元素比较后的最大值张量 与输入形状相同的 Tensor

函数功能详解

tl.maximum 执行的是 C i j = max ⁡ ( A i j , B i j ) C_{ij} = \max(A_{ij}, B_{ij}) Cij=max(Aij,Bij) 的操作。它与 Python 原生的 max() 不同,它是针对整个数据块(Block)在硬件级并行执行的。

1. NaN 传播策略 (propagate_nan)

这是该函数最专业的一个参数。在 GPU 计算中,当遇到 NaN 时,不同的处理方式会影响结果的稳定性:

  • tl.PropagateNan.NONE (默认) :如果其中一个是 NaN,结果通常倾向于返回另一个非 NaN 的数字(具体行为取决于底层硬件指令,通常遵循 IEEE 754 标准的 fmax)。

  • tl.PropagateNan.ALL :只要 xy 中有一个是 NaN,结果就一定是 NaN

2. 硬件映射

在 NVIDIA GPU 上,对于 float 类型,这个函数通常映射为 FMNMX 指令或类似的算术单元操作,速度极快。对于 int 类型,则映射为整数比较指令。

典型应用场景

场景 A:实现 ReLU 激活函数

ReLU 的公式是 y = max ⁡ ( 0 , x ) y = \max(0, x) y=max(0,x)。在 Triton 中这非常高效:

python 复制代码
# 假设 x 是输入张量
zero = 0.0
# 这里的 0.0 会被广播到和 x 相同的形状
relu_out = tl.maximum(x, zero)
场景 B:数值裁剪(Clamping)

虽然 Triton 有 tl.clamp,但你也可以组合使用 maximumminimum

Python

python 复制代码
# 将数据限制在 [0, 1] 之间
out = tl.maximum(0.0, tl.minimum(1.0, x))

注意:tl.maximum vs tl.max

这是初学者最容易搞混的地方:

  • tl.maximum(x, y)逐元素比较。输入两个张量,输出一个张量。

    例子:max([1, 5], [2, 3]) -> [2, 5]

  • tl.max(x, axis=...)规约(Reduction)操作。在一个张量内部找最大值,会减少维度。

    例子:max([1, 5]) -> 5

triton.language.max

与上一个 maximum 不同,它是一个规约(Reduction)函数 。它的作用不是对比两个张量,而是在同一个张量内部沿着指定的维度寻找最大值。

条目 名称 作用 / 含义 类型约束
输入参数 input 需要进行求最大值操作的输入张量。 Tensor
输入参数 axis 进行规约的维度索引。如果不传,则对全量数据求最大值。 intNone
输入参数 return_indices 是否同时返回最大值所在的索引(即类似 argmax 的功能)。 bool
输入参数 return_indices_tie_break_left 当有多个相同的最大值时,是否返回坐标最左边(较小)的索引。 bool
输入参数 keep_dims 是否保留原张量的维度(设为 True 则被规约的维度长度为 1)。 bool
返回值 res / (res, idx) 返回最大值张量 。如果 return_indices 为 True,则返回包含最大值和索引的元组。 Tensortuple(Tensor, Tensor)

函数功能详解

tl.max 的核心功能是维度压缩。它通过比较指定轴上的所有元素,最终只保留该轴上的最大值。

1. 自动精度提升(Precision Promotion)

从源码中可以看到一个细节:

  • 如果输入是 bfloat16,它会自动转为 float32

  • 如果输入的位宽小于 32 位(如 float16int16),它会强制提升到 float32int32 进行计算。

  • 原因 :规约操作涉及大量两两比较,提升精度可以减少累积误差(虽然对于 max 影响较小,但对于 sum 等规约至关重要,Triton 对此类操作保持了统一的精度策略)。

2. Argmax 功能集成

通过设置 return_indices=True,这个函数可以一次性完成"找最大值"和"找最大值位置"两个任务。

  • 这在深度学习中非常有用,例如在 Softmax 层找 max(x),或者在分类任务中提取预测标签。
3. 维度控制 (keep_dims)
  • 如果输入形状是 [16, 32]

    • tl.max(input, axis=1, keep_dims=False) -> 输出形状 [16]

    • tl.max(input, axis=1, keep_dims=True) -> 输出形状 [16, 1]

典型应用场景

场景 A:Softmax 中的数值稳定性

为了防止计算 e x e^x ex 时溢出,通常会减去行内的最大值。

python 复制代码
# 假设 x 形状为 [BLOCK_M, BLOCK_N]
# 沿着第 1 维(列方向)找每行的最大值
row_max = tl.max(x, axis=1) # 结果形状 [BLOCK_M]
# 利用广播机制减去最大值
safe_x = x - row_max[:, None] 
场景 B:获取分类预测结果
python 复制代码
# logits 形状为 [BATCH, CLASSES]
# 同时获取预测的概率最大值和所属类别索引
max_val, labels = tl.max(logits, axis=1, return_indices=True)

固定Q块,遍历KV

python 复制代码
# 这个函数不是完整 attention kernel
# 它负责固定一块 Q,也就是 BLOCK_M 行 query,然后分块遍历 K/V 的 BLOCK_N 列,
# 用 online softmax 的方式逐块更新 acc、l_i、m_i。
@triton.jit
def _attn_fwd_inner(
    acc,          # [BLOCK_M, HEAD_DIM],当前已经累积的 attention 输出分子,即 sum(P_ij * V_j),使用 float32 累加。
    l_i,          # [BLOCK_M],online softmax 的归一化分母,表示每一行目前累计的 exp 和。
    m_i,          # [BLOCK_M],online softmax 的每一行当前最大值,用于数值稳定。
    q,            # [BLOCK_M, HEAD_DIM],当前 Q block,外层 kernel 已经加载进来,会在整个 inner 循环中复用。
    K_block_ptr,  # Triton block pointer,指向 K 矩阵当前要读取的块;K 在这里按 [HEAD_DIM, N_CTX] 的逻辑形状读取。
    V_block_ptr,  # Triton block pointer,指向 V 矩阵当前要读取的块;V 按 [N_CTX, HEAD_DIM] 的逻辑形状读取。
    start_m,      # 当前 Q block 的 block id,对应第 start_m 个 BLOCK_M query 块。
    qk_scale,     # softmax 前的缩放系数,通常是 sm_scale * 1/log(2),因为这里用 exp2 而不是 exp。
    BLOCK_M: tl.constexpr,  # 一个 program 处理多少行 Q;编译期常量,影响 tile shape 和编译优化。
    HEAD_DIM: tl.constexpr, # 每个 head 的维度,也就是 Q/K/V 最后一维大小;编译期常量。
    BLOCK_N: tl.constexpr,  # 每次循环处理多少列 K/V,也就是多少个 key/value token;编译期常量。
    STAGE: tl.constexpr,    # 当前 inner 处理哪一段注意力区域:1 表示 causal 对角线左侧,2 表示 causal 对角线所在块,其他表示非 causal 全范围。
    offs_m: tl.constexpr,   # [BLOCK_M],当前 Q block 内每一行对应的全局 token 下标。
    offs_n: tl.constexpr,   # [BLOCK_N],当前 K/V block 内的局部列偏移,通常是 0 到 BLOCK_N-1。
    N_CTX: tl.constexpr,    # 序列长度,也就是 token 数;Q/K/V 在序列维度上的长度。
    fp8_v: tl.constexpr     # V 是否是 float8e5 类型;如果是,P 也会转成 fp8 以匹配 dot 的输入类型。
):
    # 根据 STAGE 决定当前 inner loop 需要遍历 K/V 的哪一段。
    # 对 causal attention 来说,每个 Q block 可以分成:
    # - off-band:严格在当前 Q block 左边的 K/V,可以无 mask 计算;
    # - on-band:和当前 Q block 对角线相交的 K/V,需要 causal mask。
    if STAGE == 1:
        # STAGE == 1:处理当前 Q block 左边的所有 K/V token。
        # 对第 start_m 个 Q block,它的起始行是 start_m * BLOCK_M。
        # 因此它左侧可见的 K/V 范围是 [0, start_m * BLOCK_M)。
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        # STAGE == 2:处理当前 Q block 自己对应的对角线区域。
        # 范围是 [start_m * BLOCK_M, (start_m + 1) * BLOCK_M)。
        # 这一段需要 causal mask,因为 block 内有些 query 不能看到未来 key。
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        # 告诉 Triton 编译器 lo 是 BLOCK_M 的倍数,帮助生成更好的地址计算代码。
        lo = tl.multiple_of(lo, BLOCK_M)
    else:
        # 非 causal 情况:每个 Q 都可以看见全部 K/V,所以范围是 [0, N_CTX)。
        lo, hi = 0, N_CTX
    # 把 K block pointer沿着K的序列维度移动到lo的位置
    # K 的逻辑 shape 是 [HEAD_DIM, N_CTX],所以第二维是 token 维,advance 参数是 (行偏移, 列偏移)。
    K_block_ptr = tl.advance(K_block_ptr,(0,lo))
    # 把 V block pointer 沿着 V 的序列维度移动到 lo 位置。
    # V 的逻辑 shape 是 [N_CTX, HEAD_DIM],所以第一维是 token 维,advance 参数是 (行偏移, 列偏移)。
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # 从 lo 到 hi,每次处理 BLOCK_N 个 K/V token。
    for start_n in tl.range(lo,hi,BLOCK_N):
        # 告诉编译器 start_n 是 BLOCK_N 的倍数,有利于优化内存访问和向量化。
        start_n = tl.multiple_of(start_n,BLOCK_N)
        # 加载当前 K tile。
        # K_block_ptr 的 block_shape 是 [HEAD_DIM, BLOCK_N]。
        # 所以 k 的形状是 [HEAD_DIM, BLOCK_N]。
        k = tl.load(K_block_ptr)
        # 计算 QK^T 的一个 tile。
        # q 形状:[BLOCK_M, HEAD_DIM]
        # k 形状:[HEAD_DIM, BLOCK_N]
        # qk 形状:[BLOCK_M, BLOCK_N]
        # qk[i, j] 表示当前 query 行 i 和当前 key 列 j 的点积。
        qk = tl.dot(q, k)
        if STAGE == 2:
            # causal 对角线块需要 mask。
            # offs_m[:, None] 形状是 [BLOCK_M, 1],表示 query 的全局 token 下标。
            # start_n + offs_n[None, :] 形状是 [1, BLOCK_N],表示 key 的全局 token 下标。
            # query_index >= key_index 表示当前 query 可以看到该 key。
            # 该mask的形状是 [BLOCK_M,BLOCK_N],是一个下三角矩阵的部分列区域
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            # 先乘 qk_scale,把点积缩放到 softmax logits。
            # 对不满足 causal 条件的位置加上 -1e6,近似 -inf,使 softmax 后概率接近 0。
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)

            # m_ij 是合并当前旧最大值 m_i 和当前 qk tile 每行最大值之后的新最大值。
            # tl.max(qk, 1) 沿列方向取最大值,即每行的最大值,得到 [BLOCK_M]。
            m_ij = tl.maximum(m_i, tl.max(qk, 1))

            # 把 qk 减去每行最大值 m_ij,保证 exp2 的输入不会太大,提升数值稳定性。
            # 这是softmax经典手法
            qk -= m_ij[:, None]
        else:
            # 非对角线块直接计算,无需 mask
            m_ij=tl.maximum(m_i,tl.max(qk,1)*qk_scale)
            qk=qk*qk_scale-m_ij[:,None]
        # 计算 P = exp2(QK^T - max)
        # 注意:Triton 中 exp2 通常比 exp 更快,qk_scale 在外部预先乘了 1/log(2)
        p=tl.math.exp2(qk)
        # 计算当前块的归一化分母
        l_ij=tl.sum(p,1)
        # --- Online Softmax 核心更新逻辑 ---
        # 计算新旧最大值之间的修正系数 alpha = 2^(m_old - m_new)
        alpha=tl.math.exp2(m_i-m_ij)
        # 更新分母:新分母 = 旧分母 * alpha + 当前块的分母
        l_i=l_i*alpha+l_ij
        # 更新分子(累加器):旧累加值也需要缩放以对齐新的最大值基准
        acc=acc*alpha[:,None]
        # 加载 V 块并更新累加器
        v=tl.load(V_block_ptr)
        # 将 P 转换为 V 的类型(fp16 或 fp8)以进行点积
        if fp8_v:
            p=p.to(tl.float8e5)
        else:
            p=p.to(tl.float16)
        # acc = acc + P * V
        acc=tl.dot(p,v,acc)
        # 更新当前行的状态变量,进入下一个 K 块
        m_i=m_ij
        V_block_ptr=tl.advance(V_block_ptr,(BLOCK_N,0))
        K_block_ptr=tl.advance(K_block_ptr,(0,BLOCK_N))
    return acc,l_i,m_i

Causal 模式"(因果模式)

Causal 模式"(因果模式)是自注意力机制(Self-Attention)在处理生成式任务(如 GPT 聊天机器人)时的核心逻辑。

简单来说:Causal 模式就是"不许看答案"模式。

在处理序列数据时,每个 Token(单词或字符)都在尝试关注其他 Token。

  • 非 Causal 模式(如 BERT):每个单词都能看到整个句子。比如"我 爱 你",计算"我"的时候,它知道后面跟着"爱"和"你"。

  • Causal 模式(如 GPT) :每个单词只能看到它自己和它之前的单词。计算"爱"的时候,它只能看到"我",绝对不能看到后面的"你"。

矩阵几何表示:下三角矩阵

如果我们将 Q × K T Q \times K^T Q×KT 的结果看作一个方阵,行表示 Query(当前单词),列表示 Key(被关注的单词):

  • 对角线(Diagonal):单词关注它自己。

  • 下三角(Lower Triangle):单词关注它之前的单词(允许访问)。

  • 上三角(Upper Triangle) :单词关注它之后的单词(被 Mask 掉,设为 − ∞ -\infty −∞)。

在内部函数里,stage有三个状态

STAGE 对应区域 说明
STAGE 1 Off-band (全下三角) 这块 K K K 块完全在当前 Q Q Q 块的左边。所有 Key 都在过去,无需 Mask,直接全速计算。
STAGE 2 On-band (对角线块) 这块 K K K 块正好和 Q Q Q 块的时间线重叠。必须进行逐元素的 Causal Mask
STAGE 3 Future (上三角) 这块 K K K 块完全在未来

调用接口

python 复制代码
# 尝试不同的分块策略 (BM, BN) 以及并行度 (num_warps)
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    for BM in [64, 128]\
    for BN in [32, 64]\
    for s in ([1] if is_hip() else [3, 4, 7])\
    for w in [4, 8]\
]




def keep(conf):
    """
    过滤掉不合理的配置。
    例如:如果 Tile 太小但 Warp 太多,会导致计算资源浪费和同步开销过大。
    """
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    return True

@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
@triton.jit
def _attn_fwd(
    Q, K, V, sm_scale, M, Out, 
    stride_qz, stride_qh, stride_qm, stride_qk, # Q 的步长(Batch, Head, Seq, Dim)
    stride_kz, stride_kh, stride_kn, stride_kk, # K 的步长
    stride_vz, stride_vh, stride_vk, stride_vn, # V 的步长
    stride_oz, stride_oh, stride_om, stride_on, # Output 的步长
    Z, H, N_CTX,                                # Batch, Heads, 全局序列长度
    HEAD_DIM: tl.constexpr, 
    BLOCK_M: tl.constexpr, 
    BLOCK_N: tl.constexpr, 
    STAGE: tl.constexpr                         # 这里的 STAGE 决定了是否是 Causal 模式
):
    tl.static_assert(BLOCK_N<=HEAD_DIM)
    # --- 1. 计算当前程序(Program)在网格中的位置 ---
    start_m = tl.program_id(0)      # 处理第几个 Q 块
    off_hz = tl.program_id(1)       # 处理哪一个 Batch 和哪一个 Head
    off_z = off_hz // H             # 当前 Batch ID
    off_h = off_hz % H              # 当前 Head ID
    # 计算该 Batch/Head 对应的基础内存偏移
    qkv_offset=off_z.to(tl.int64)*stride_qz+off_h.to(tl.int64)
    # --- 2. 初始化 Block Pointer (块指针) ---
    # 块指针是 Triton 2.0+ 的特性,它能自动处理边界检查和复杂的步长计算
    # Q 块:从 (start_m * BLOCK_M, 0) 开始,形状为 (BLOCK_M, HEAD_DIM)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qkv_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # V 块:形状 (BLOCK_N, HEAD_DIM)
    # 对于 fp8 格式,V 的存储顺序(order)可能不同以适配硬件加速
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qkv_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    # K 块:注意这里 K 是被转置读取的 (HEAD_DIM, BLOCK_N),方便与 Q 做点积
    K_block_ptr = tl.make_block_ptr(
        base=K + qkv_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    # Output 块指针
    O_block_ptr = tl.make_block_ptr(
        base=Out + qkv_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # --- 3. 初始化 Online Softmax 状态变量 ---
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)# 每一行的全局索引
    offs_n = tl.arange(0, BLOCK_N)# 每一列的局部索引

    # 初始化 m 和 l 的指针
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")# 最大值初始化为负无穷
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 # 分母初始化(稍后会被修正)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)# 结果累加器
    # 加载系数
    qk_scale = sm_scale
    # FlashAttention 使用 base 2 指数运算
    qk_scale *= 1.44269504  # 1/log(2)
    # 预加载 Q 块到 SRAM,因为它在遍历 K/V 时保持不变
    q = tl.load(Q_block_ptr)
    # --- 5. 分阶段调用内层循环 ---
    # 使用位运算判断 STAGE。
    # 如果 STAGE & 1 为真(例如非 Causal 模式 STAGE=1,或者 Causal 模式 STAGE=3),处理常规块。
    if STAGE & 1:
        # 在 Causal 模式下(STAGE=3),这里执行的是 STAGE=1 的逻辑(左侧全量块)
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # 如果 STAGE & 2 为真(Causal 模式 STAGE=2 或 3),处理对角线块
    if STAGE & 2:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # --- 6. 结尾处理(Epilogue) ---
    # 最终输出为 acc / l_i
    # 计算最终的 log-sum-exp,存入 M 矩阵,后续 Backward 阶段会用到
    # 这里的m_i最终存的是一个教log-sum-exp(LSE)的东西,它在反向传播的时候有用
    m_i += tl.math.log2(l_i)#LSE=max(x)+log(li),这里li是2为底的,所以用log2
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))

模型层定义

python 复制代码
class _attention(torch.autograd.Function):


    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # 形状约束
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # 当 v 在 float8_e5m2 格式下时,进行转置。
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        # 为 AMD 设备调整
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}


        grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            STAGE=stage,  #
            **extra_kern_args)


        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o

测试代码

python 复制代码
try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False

TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
BATCH, N_HEADS, HEAD_DIM = 4, 32, 64
# vary seq length for fixed head and batch=4
# 对于固定的 head 和 batch 为 4,变化序列长度。
configs = []
for mode in ["fwd"]:
    for causal in [True, False]:
        configs.append(
            triton.testing.Benchmark(
                x_names=["N_CTX"],
                x_vals=[2**i for i in range(10, 15)],
                line_arg="provider",
                line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
                (["flash"] if HAS_FLASH else []),
                line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
                (["Flash-2"] if HAS_FLASH else []),
                styles=[("red", "-"), ("blue", "-")],
                ylabel="ms",
                plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}",
                args={
                    "H": N_HEADS,
                    "BATCH": BATCH,
                    "HEAD_DIM": HEAD_DIM,
                    "mode": mode,
                    "causal": causal,
                },
            ))




@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
    assert mode in ["fwd"]
    warmup = 25
    rep = 100
    dtype = torch.float16
    if "triton" in provider:
        q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        if mode == "fwd" and "fp8" in provider:
            q = q.to(torch.float8_e5m2)
            k = k.to(torch.float8_e5m2)
            v = v.permute(0, 1, 3, 2).contiguous()
            v = v.permute(0, 1, 3, 2)
            v = v.to(torch.float8_e5m2)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, causal, sm_scale)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
        fn = lambda: flash_attn_func(qkv, causal=causal)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    return total_flops / ms * 1e-9




if __name__ == "__main__":
    # only works on post-Ampere GPUs right now
    # 目前只适用于安培架构 GPU。
    bench_flash_attention.run(save_path=".", print_data=True)

测试结果

causal=False

causal=True

相关推荐
HZ·湘怡17 小时前
数据结构之排序算法 (1)--插入排序
c语言·数据结构·算法·排序算法
开源Z17 小时前
LeetCode 238 · 除自身以外数组的乘积:左右两遍扫描,不用除法
算法·leetcode
人工智能培训17 小时前
用知识图谱重构搜索引擎
大数据·人工智能·3d·重构·知识图谱·agent
Aipollo17 小时前
【知识图谱】语义本体的演进之路:从严谨到敏捷的范式转变
人工智能·知识图谱
制造业的搬运工17 小时前
低价PCB板藏隐患,如何选对线路板厂家?
人工智能·制造·pcb工艺·pcb
润乾软件17 小时前
从 GUI 到 LUI 的进化,报表工具也有了 Copilot
人工智能·ai·copilot·报表工具
小程故事多_8017 小时前
从工程视角拆解Claude Code,读懂工业级AI智能体Harness架构
人工智能·架构
DS随心转插件17 小时前
AI导出鸭深度评测:DeepSeek 转 Word 实战表现与边界分析
人工智能·ai·word·文心一言·deepseek·ai导出鸭
花间相见17 小时前
【AI工作流搭建】—— n8n 自部署完全指南:从 Docker 安装到第一个自动化工作流
人工智能·ai编程·n8n