[Triton笔记2]自动调优和共享内存

本文将用Triton重写以下的cuda代码,并重点关注自动调优和Triton中Thread Block相关的内容,包括共享内存,显式同步。

代码是一个简单的矢量点乘。甚至都没有Block间的规约,只有Block内的规约。

cpp 复制代码
#ifndef __CUDACC__
#define __CUDACC__
#endif
#include <iostream>
#include <cuda_runtime.h>
#include "device_launch_parameters.h"
 
#define threadsPerBlock 256
const int Blocks = 32;
const int N = Blocks * threadsPerBlock;
 
__global__ void dot(float* a, float* b, float* c) {
    __shared__ float cache[threadsPerBlock];
 
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int cacheIndex = threadIdx.x;
 
    float temp = 0;
    if (tid < N) {
        temp = a[tid] * b[tid];
    }
    cache[cacheIndex] = temp;
    __syncthreads();
 
    // 并行归约求和
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (cacheIndex < stride) {
            cache[cacheIndex] += cache[cacheIndex + stride];
        }
        __syncthreads();
    }
 
    // 将每个 block 的结果写入全局内存
    if (cacheIndex == 0) {
        c[blockIdx.x] = cache[0];
    }
}
 
int main() {
    // 1. 在主机分配内存
    float* a = new float[N];
    float* b = new float[N];
 
    // 初始化向量数据
    for (int i = 0; i < N; ++i) {
        a[i] = 1.0f;  // 举例,全1
        b[i] = 2.0f;  // 举例,全2
    }
 
    // 2. 设备内存指针
    float* dev_a, * dev_b, * dev_c;
 
    // 3. 分配设备内存
    cudaMalloc((void**)&dev_a, N * sizeof(float));
    cudaMalloc((void**)&dev_b, N * sizeof(float));
    cudaMalloc((void**)&dev_c, Blocks * sizeof(float));  // 每个 block 一个结果
 
    // 4. 复制数据到设备
    cudaMemcpy(dev_a, a, N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dev_b, b, N * sizeof(float), cudaMemcpyHostToDevice);
 
    // 5. 启动核函数
    dot << <Blocks, threadsPerBlock >> > (dev_a, dev_b, dev_c);
 
    // 等待 GPU 完成
    cudaDeviceSynchronize();
 
    // 6. 从设备拷回部分和
    float* partial_sums = new float[Blocks];
    cudaMemcpy(partial_sums, dev_c, Blocks * sizeof(float), cudaMemcpyDeviceToHost);
 
    // 7. CPU 端归约
    float final_sum = 0;
    for (int i = 0; i < Blocks; ++i) {
        final_sum += partial_sums[i];
    }
 
    // 输出结果
    std::cout << "Dot product result: " << final_sum << std::endl;
 
    // 8. 释放内存
    delete[] a;
    delete[] b;
    delete[] partial_sums;
    cudaFree(dev_a);
    cudaFree(dev_b);
    cudaFree(dev_c);
 
    return 0;
}

triton.autotune

python 复制代码
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
             warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False):
    """
    Decorator for auto-tuning a :code:`triton.jit`'d function.

    .. highlight:: python
    .. code-block:: python

        @triton.autotune(configs=[
            triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
            triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
          ],
          key=['x_size'] # the two above configs will be evaluated anytime
                         # the value of x_size changes
        )
        @triton.jit
        def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
            ...
    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensor to `zero` before running any configuration.

    If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to
    :code:`"1"`, Triton will print a message to stdout after autotuning each
    kernel, including the time spent autotuning and the best configuration.

    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
        'perf_model': performance model used to predicate running time with different configs, returns running time
        'top_k': number of configs to bench
        'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
    :type restore_value: list[str]
    :param pre_hook: a function that will be called before the kernel is called.
        This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'.
        'kwargs': a dict of all arguments passed to the kernel.
        'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook.
    :type pre_hook: lambda args, reset_only
    :param post_hook: a function that will be called after the kernel is called.
        This overrides the default post_hook used for 'restore_value'.
        'kwargs': a dict of all arguments passed to the kernel.
        'exception': the exception raised by the kernel in case of a compilation or runtime error.
    :type post_hook: lambda args, exception
    :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
    :type warmup: int
    :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
    :type rep: int
    :param do_bench: a benchmark function to measure the time of each run.
    :type do_bench: lambda fn, quantiles
    :param cache_results: whether to cache autotune timings to disk.  Defaults to False.
    "type cache_results: bool
    """

    def decorator(fn):
        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                         post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
                         use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results)

    return decorator

它的作用是:针对同一个 Kernel,给定一组可选的硬件参数组合(Configs) ,Triton 会在程序运行时真实地测量每一组配置的耗时,然后动态选择最快的那一个

  • 编译期缓存:一旦选出最佳配置,它会把结果存起来,下次遇到同样规模的输入直接秒开。

  • 解耦硬件 :你不需要在代码里写死 BLOCK_SIZE = 128。你的代码在 A100 上可能选 1024,在 3090 上可能选 256。

参数一一解释

根据提供的源码注释,我们可以把参数分为 "核心配置"、"触发条件"和"副作用清理" 三类:

A. 核心配置类
参数名 解释
configs 候选名单 。一个 triton.Config 列表。每个 Config 包含 BLOCK_SIZEnum_warpsnum_stages 等参数。编译器会逐个测试这些组合。
do_bench 测量工具。允许你传入自定义的计时函数。默认情况下,Triton 会用内部的计时器跑很多次取平均值。
cache_results 持久化 。设为 True 会把调优结果存到硬盘上,这样你下次重启 Python 脚本时就不需要重新测一遍了。
B. 触发逻辑类
参数名 解释
key 重新调优的触发点 。这是一个参数名列表。例如 key=['N']。如果这次输入的 ,下次,Triton 会认为"环境变了",从而重新为新的规模进行一轮自动调优。
prune_configs_by 剪枝过滤。如果你的候选 Config 有 100 多个,全部跑一遍太慢。你可以传一个性能模型(perf_model)来预估并剔除那些明显不靠谱的配置。
C. 副作用清理类(非常关键!)

调优时 Kernel 会跑很多遍。如果你的 Kernel 是 C += A * B,跑 10 遍调优,C 里的结果就翻了 10 倍,这会导致数据错误。

参数名 解释
reset_to_zero 归零保护 。在测试每一组配置前,自动把指定的 Tensor(如 ['C'])清零。防止多次运行导致数值累加。
restore_value 现场还原。调优前备份数据,测试完每组配置后再还原回去。
pre_hook / post_hook 自定义钩子。如果你有更复杂的清理逻辑(比如重置随机数种子),可以写 lambda 传进来。

看这个源码时有三个细节需要特别注意:

1. 调优是"有损"的

注释里那句 :note: the kernel will run multiple times 非常重要。如果你正在写一个原子加法(Atomic Add)的 Kernel,千万别忘了配置 reset_to_zero,否则你在调优阶段就会把输出张量弄脏。

2. key 的粒度

不要把所有参数都放进 key

  • 如果 ptr 变化也触发重新调优,那你的程序会一直在调优,因为指针地址每次都在变。

  • 通常只放 N (数据总量)、MK 这种影响计算密度和网格划分的形状参数

3. CUDA Graphs

use_cuda_graph=True 是针对高频、小任务算子的。它能减少 Python 调用 GPU 的发射开销(Launch Overhead),让测出来的性能更接近真实生产环境。

triton.Config

python 复制代码
class Config:
    """
    An object that represents a possible kernel configuration for the auto-tuner to try.

    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
    :type kwargs: dict[Str, Any]
    :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
                      `num_warps=8`, then each kernel instance will be automatically parallelized to
                      cooperatively execute using `8 * 32 = 256` threads.
    :type num_warps: int
    :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                       Mostly useful for matrix multiplication workloads on SM80+ GPUs.
    :type num_stages: int
    :ivar num_ctas: number of blocks in a block cluster. SM90+ only.
    :type num_ctas: int
    :type maxnreg: Optional[int]
    :ivar maxnreg: maximum number of registers one thread can use.  Corresponds
                       to ptx .maxnreg directive.  Not supported on all platforms.
    :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
                    function are args.
    :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}).
    """

    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None):
        self.kwargs = kwargs
        self.num_warps = num_warps
        self.num_ctas = num_ctas
        self.num_stages = num_stages
        self.maxnreg = maxnreg
        self.pre_hook = pre_hook
        self.ir_override = ir_override
    #略

1.kwargs (核心逻辑参数)

  • 解释 :这是一个字典,包含了你传给 Kernel 的所有 tl.constexpr 参数。

  • 示例{'BLOCK_SIZE': 128, 'GROUP_SIZE_M': 8}

  • 底层映射 :这些值决定了 Triton 编译器如何进行循环展开(Unrolling)静态索引计算

2. num_warps (线程束数量)

  • 解释:定义一个 Program (CTA/Block) 使用多少个 Warp。

  • 硬件映射:1 个 Warp = 32 个线程。

    • 如果你设置 num_warps=4,该 Block 有 128 线程。

    • 如果你设置 num_warps=8,该 Block 有 256 线程。

  • 影响 :Warp 越多,计算吞吐量可能越高,但会消耗更多的寄存器,可能导致活跃的 Block 数量减少(Occupancy 下降)。

3. num_stages (软件流水线级数)

  • 解释 :这是专门为 SM80+ (Ampere, Hopper) 架构设计的参数,用于控制软件流水线(Software Pipelining)

  • 硬件背景:在 A100/H100 上,支持异步拷贝指令(Async Copy)。

    • num_stages=1:同步读,读完计算。

    • num_stages=3+:一边从显存异步读第 N+2 块数据,一边在共享内存里处理第 N 块数据。

  • 影响 :级数越多,越能通过双缓冲(Double Buffering)机制掩盖访存延迟,但会占用更多的 Shared Memory

4. num_ctas (CTA 聚簇大小)

  • 解释 :仅限 SM90+ (Hopper/H100)

  • 硬件映射 :控制 Thread Block Clusters。它允许不同的 Block 之间通过特殊的硬件链路直接交换 Shared Memory 中的数据。

  • 影响:在处理非常大的矩阵乘法(如大规模隐层)时,增加聚簇大小可以极大提升 L2 缓存的命中率。

5. maxnreg (寄存器上限)

  • 解释 :对应 PTX 中的 .maxnreg 指令。

  • 底层映射:手动限制每个线程能用的最大寄存器数量。

  • 应用场景 :当你希望强制提高 Occupancy(占用率) 时使用。例如,编译器默认给每个线程分配 64 个寄存器,导致一个 SM 只能跑 2 个 Block;你限制它只准用 32 个,SM 就能跑 4 个 Block,从而提升整体并发度。

6. pre_hook (预处理钩子)

  • 解释:一个函数指针(或 lambda)。

  • 作用 :在执行这个特定的 Config 之前调用。通常用于复杂的调试,或者配合 autotune 进行一些特定的内存初始化。

7. ir_override(提前定义的IR)

简单来说,ir_override 允许你跳过 Triton 的高级编译阶段,直接喂给它你写好的中间表示(IR)或汇编代码

通常情况下,Triton 的编译流程是: Python (JIT) -> Triton IR -> Triton GPU IR -> LLVM IR -> PTX -> CUBIN (Binary)

如果你设置了 ir_override,你就是在告诉编译器:"别管我 Python 代码里写了什么,直接拿我这个现成的 IR 文件去跑测试。"

它支持多种不同阶段的"中间产物":

  • .ttgir: Triton GPU IR (Triton 自己的中间表示)

  • .llir: LLVM IR (通用的底层编译器中间表示)

  • .ptx: NVIDIA GPU 汇编代码

  • .amdgcn: AMD GPU 汇编代码

共享内存

Triton 逻辑层面上没有显式的"共享内存"关键字,但物理层面上高度依赖共享内存。

你可以把 Triton 的编译器看作是一个"超级自动搬运工"。它把原本需要你手动在 CUDA 里写的搬运逻辑(Load to Shared Memory -> Sync -> Compute),变成了它在后台根据 IR 转换规律自动推导出的结果。

以下是 Triton 将 IR 对应到共享内存操作的三个关键步骤:

1. 从 Tensor 布局(Layout)说起

在 Triton IR 中,每个 Tensor 都有一个属性叫 Layout

  • Blocked Layout:通常对应寄存器中的数据分布。

  • Shared Layout:专门对应共享内存中的数据分布。

当 IR 逻辑中出现"数据交换"的需求时,编译器就会自动触发 Layout 的转换。

2. 转换触发点:ConvertLayout 指令

这是最核心的操作。在 Triton GPU IR 阶段,你会看到大量的 ttg.convert_layout 指令。

场景 A:必须跨线程通信时

当你调用 tl.sumtl.max 进行归约操作,或者执行 tl.dot(矩阵乘法)时,不同的线程需要看到彼此的数据。

  • 编译器行为 :它发现当前的 Blocked Layout(寄存器分布)无法满足这个跨线程计算的需求,于是它会插入一条指令,将数据从 Blocked Layout 转换为 Shared Layout

  • 底层实现 :这行 IR 会被降级(Lowering)为 PTX 中的 st.shared 指令,把数据从寄存器写到共享内存。

场景 B:矩阵乘法(TMA 和 Pipeline)

在执行 tl.dot 时,为了利用 Tensor Core 的高性能,数据必须以特定的格式排布在共享内存在。

  • 编译器行为:它会自动申请一块共享内存空间,把 AB 块搬进去。

3. 内存分配与同步:AllocationBarriers

既然 Triton 替你管了共享内存,它就必须解决两个 CUDA 程序员最头疼的问题:

A. 地址计算(Offsetting)

在 CUDA 里,你需要算 shared_ptr[threadIdx.x + ...]。 在 Triton IR 里,编译器会维护一个 Shared Memory Allocation 模块。它会静态地为每个需要进入共享内存的 Tensor 计算出偏移量(Offset)。这样,不同的 Tensor 块就能共享同一块物理 Shared Memory(通过生命周期分析进行内存重用)。

B. 同步指令(Barrier Insertion)

在 CUDA 里,你得手动写 __syncthreads()。 在生成 PTX 时,Triton 编译器会进行数据流分析

  1. 如果发现第 15 行指令在写共享内存(st.shared),而第 17 行指令在读(ld.shared)。

  2. 编译器会自动在中间插入 bar.sync 指令。

  3. 对于 SM80+ 架构,它甚至会利用 cp.async.wait_group 来实现更细粒度的异步同步,这比手动写要高效得多。

相关推荐
qq_429499572 小时前
RK3566 linux编译成功笔记
笔记
Purple Coder3 小时前
项目一支撑文档阅读笔记-《Handbook on Battery Energy Storage System》
笔记
宵时待雨3 小时前
linux笔记归纳4:进程概念
linux·运维·服务器·c++·笔记
jinyuya4 小时前
[UVM] uvm_reg学习
笔记
magic_now5 小时前
FAT文件系统:嵌入式设备的极简选择
笔记·嵌入式硬件
Hammer_Hans5 小时前
DFT笔记45
java·jvm·笔记
handler015 小时前
速通蓝桥杯省一:二分算法
c语言·开发语言·c++·笔记·算法·职场和发展·蓝桥杯
Hammer_Hans6 小时前
DFT笔记44
笔记
ZOE^V16 小时前
springcloud笔记
笔记·spring cloud·github
QFIUNE6 小时前
【文献阅读】化学空间边缘的分子深度学习
论文阅读·人工智能·笔记·深度学习