Triton编程技术背诵核心概念

1、Triton是什么?

一个能让Python代码在GPU上飞速运行的翻译器和优化器。

2、Program是什么?

它是Triton里最小的独立施工队,相当于CUDA里的线程块(Block),它负责并行处理数据的一小块。

(1)身份标识的写法及含义

写法:pid = tl.program_id(axis=0)

含义:每个施工队都有一个从0开始的唯一编号

(2)自动并行逻辑

只需要编写好一个施工队的逻辑,Triton会自动派生出成百上千个施工队同时干活

(3)分工逻辑

利用pid计算各自负责的数据范围,例如第pid个施工队负责处理[pid * BLOCK_SIZE : (pid + 1) * BLOCK_SIZE]这段数据。

3、1D Launch网格是什么?

它是给GPU派活的排班表,它用一行队列规划了要启动多少个并行的Program。

(1)Grid哪里定义&含义

它在CPU端定义

含义:决定了并行任务的总体蓝图。例如grid=(10,)表示排了10个施工队同时开工。

4、axis是什么?

axis用于定义索引的维度。

(1)三维轴分别是什么?

axis=0:线,数据线性排列的未知,计算偏移

axis=1:行,矩阵或二维结构的行号

axis=2:层,三维或复杂结构的深度、层数

(2)计算起始位置的公式

通过tl.program_id(axis=0)返回当前程序在线性序列中的ID,

起始位置=ID * 块大小。

比如:block_start = pid * BLOCK_SIZE

5、利用arange生成块内偏移量公式

全局偏移 = 块起始位置 + tl.arange(块内索引模板)

(1)为何这么写?

向量化,GPU一次性加载整个向量(比如[8,9,10,11]),而非循环4次,效率极高

6、边界检查掩码作用?

防止数据总量非BLOCK_SIZE整倍数时的越界访问。

(1)核心公式

掩码 = 全局偏移 < 总元素数

(2)工作流程

1)计算当前块的全局索引offsets,如[8,9,10,11]

2)比如offsets < n_elements,生成布尔向量

3)将掩码传入tl.load或tl.store,硬件仅对True位置执行读写,忽略False位置

7、JIT即时编译的定义

运行时动态编译,追求快与灵活的平衡。

(1)与静态编译、解释执行的对比

-静态编译:先编译(比如.exe, .jar),后运行。快但不灵活

-解释执行:边读边执行。灵活但慢

-JIT编译:运行时把频繁执行的热点代码抓出来,编译后直接执行机器码。既快又灵活。

(2)Triton中的JIT流程

1)用@triton.jit装饰函数

2)首次调用时,根据参数和硬件,动态生成GPU机器码并缓存

3)后续调用时,直接执行缓存的机器码,速度极快

8、GPU核心概念:SM/线程块/线程/Warp/寄存器

SM:Stream Multiprocessor,流多处理器。硬件核心,含计算单元与资源

线程块:软件逻辑单元,资源分配最小单位

线程:最小执行实体

Warp:硬件调度单元(32线程/组),强制同步

寄存器:SM内高速存储,线程私有

(1)层级关系

GPU由多个SM组成。

每个SM同时运行一个或多个线程块

每个线程块由多个线程组成

硬件将线程块切分为多个Warp进行调度

每个线程拥有私有的寄存器,而线程块内的线程共享 共享内存。

9、TMA内核的定义

利用NVIDIA Hopper架构的张量内存加速器(TMA)硬件,实现高效异步数据搬运的内核。

(1)核心特点

-硬件加速:专门的硬件,独立于计算核心,自动处理多维张量布局(步长、转置等)

-描述符驱动:通过tl.make_tensor_descriptor定义内存布局(地址、形状、步长),硬件自动计算地址并传输

-异步流水线:数据搬运与计算完全重叠,提升利用率

(2)工作流程

1)主机端定义张量布局,生成TMA描述符

2)内核调用load/store,TMA硬件异步执行搬运

3)通过异步屏障(mbarrier)协调计算与搬运的依赖关系

10、持久化内核定义

线程块处理完一个任务后不退出,而是循环获取新任务,直到所有工作完成。

(1)设计动机

-负载均衡:快的线程块自动处理更多任务,避免SM资源闲置

-隐藏延迟:连续处理多个块,重叠计算与内存访问

-减少开销:降低内核启动和GPU调度频率

(2)核心:网格大小/循环处理/步长调度

网格大小:grid = min(SM数量,总块数)

循环处理:每个线程块通过for循环处理多个数据块

步长调度:步长=SM数量,均匀分配任务

(3)代码模板

python 复制代码
grid = min(NUM_SMS, total_tiles)
tile_id_c = start_pid - NUM_SMS
for tile_id in range(start_pid, total_tiles, NUM_SMS):
    # 1. 计算当前块
    result = compute(tile_id)
    # 2. 存储上一块(流水线)
    store(result_prev, tile_id_c)
    tile_id_c += NUM_SMS
    result_prev = result

11、流水线级数num_stages定义

软件流水线的并发阶段数,实现计算与数据加载的重叠。

(1)提升效果

-延迟隐藏:更多预加载操作掩盖内存访问延迟

-计算密度:计算单元持续工作,减少空闲等待

-性能提升:最终实现更高的TFLOPS

(2)缺点

每个stage需在共享内存中保留一份数据副本,级数过高可能导致共享内存不足。

12、常用函数功能

(1)@triton.jit

即时编译,将Python函数编译为GPU机器码

(2)@triton.autotune

自动调优,运行时测量并选择最优配置

(3)tl.program_id(axis)

获取ID,获取当前线程块在指定维度的ID

(4)tl.num_programs(axis)

获取总数,获取指定维度上启动的线程块总数

(5)tl.arange(s, e)

生成索引,生成[s, e)的连续向量,用于计算偏移

(6)tl.load(ptr, mask)

加载数据,从显存加载数据到寄存器

(7)tl.store(ptr, val, mask)

存储数据,将数据写回显存

(8)tl.constexpr

编译时常量,标记编译期已知的常量,用于优化

(9)tl.where(cond, x, y)

条件选择,根据条件从x或y中选择值

(10)tl.dot(a, b)

矩阵乘法,执行矩阵乘法

(11)tl.make_tensor_descriptor()

创建TMA描述符,定义张量布局,用于TMA硬件加速传输

(12)tl.atomic_cas(ptr, cmp, val)

原子操作,比较并交换。读取ptr,若等于cmp则写入val

(13)tl.debug_barrier()

同步屏障:阻塞线程块直到所有线程到达此点

(14)tl.multiple_of(ptr, dims)

对齐提示:提示指针在指定维度上对齐

(15)tl.max_contiguous(ptr, dims)

连续性提示:提示指针元素是连续的

(16)triton.cdiv(a, b)

向上取整除法

(17)triton.next_power_of_2(n)

计算超过n的最小2的幂,用于对齐内存或确定缓冲区大小

(18)triton.Config(meta, num_stages, num_warps)

定义内核配置候选,用于自动调优,指定不同编译参数组合

(19)triton.set_allocator(alloc_fn)

设置自定义内存分配器,覆盖默认分配行为,用于与PyTorch兼容或支持TMA描述符

(20)torch.allclose(input, other, rtol, atol)

判断张量数值近似相等,用于单元测试验证计算结果。

通过相对容差(rotl)和绝对容差(atol)控制精度要求。

13、@triton.jit和@torch.jit.script区别

@triton.jit是写GPU内核,操作线程、内存块

@torch.jit.script是写模型计算图,操作张量、网络层

14、warmup功能

提前编译内核,获取元数据,但不执行实际计算。

(1)3个作用

-触发编译,让Triton把Python代码编译成GPU机器码

-生成元数据,算出内核的寄存器、共享内存使用量

-只编译不计算,节省时间

(2)关键参数num_stages/num_warps/grid

num_stages:流水线级数,预取数据份数,越大隐藏延迟越好

num_warps:线程束数量,每个Block的Warp数

grid:网格数

15、L1/L2缓存定义

L1缓存:每个SM私有的私人文档架,容量几十KB,速度最快

L2缓存:所有SM共享的中央档案室,容量几MB到几十MB,速度较快

(1)数据查找路径

GPU要数据时,先翻L1,没有再去L2,再没有才去显存。

16、L2优化策略

计算C= A * B,从行主序换成分组的列主序的策略。

(1)核心操作

1)变序,将计算顺序改为列主序

2)分组,在列主序的基础上,将行划分为小组(GROUP_SIZE_M)

3)执行:组内竖着算,锁定B的一个列块,一次性算完组内所有行与该列的乘积,再切换B的下一个列块

(2)性能优化原理

A的行数据的访问是线性连续的,GPU内存控制器擅长处理此类流式数据,带宽利用率高。

B的列数据的访问是跨步跳跃的,这种非连续访问模式会导致带宽利用率极低,成为性能瓶颈。

所以从行主序的B访问多次 换成 列主序的A访问多次。

17、Meta变量的定义

meta是Triton运行时自动传入的一个字典,它包含了当前正在尝试的所有编译器常量参数,如BLOCK_SIZE。

它是自动调优(AutoTuning)的核心载体,编译器会通过切换不同的meta配置组合,找出性能最优的一组。

(1)meta值的时效性

meta中具体数值只有在运行时才会被确定。

每当自动调优机制测试一组新的配置参数时,运行时系统才会生成对应的meta字典并传入。在代码定义阶段,这些值是未知的。

(2)grid=(meta['NUM_SM'],)错误的原因

python解释器在代码定义阶段立即尝试读取meta['NUM_SM'],但由于此时meta尚未生成,所以会报错。

改成grid = lambda meta: meta['NUM_SM']的形式后,只定义了计算逻辑但不执行。只有当Triton运行真正调用该函数时,才会将当前meta字典传入并计算出具体的网格大小。这种延迟计算机制确保了grid能随自动调优的配置动态变化。

18、Pytorch中的CTX的定义

ctx是torch.autograd.Function中的上下文对象,专门用于在前向传播(forward)和反向传播(backward)之间安全地传递数据。

它是连接计算图两个阶段的唯一通道,确保反向传播时能获取前向计算所需的中间变量或配置参数。

(1)保存张量的作用

它告诉Pytorch别释放这个张量,确保其在计算图中存活,以便后续计算导数。读取时通过ctx.saved_tensor获取

(2)保存任意对象的作用

这个数据完全脱离计算图,不参与梯度追踪,仅作为普通Python属性在前向后向间透传。读取时直接访问ctx.变量名。

19、智能优化内存合并的原理

当多个线程访问连续的内存地址时,GPU硬件会将这些零散请求合并为少数几次大宽度传输。

这样极大提升显存带宽利用率。

(1)multiple_of的作用

对齐提示。

告知编译器索引地址是某数的倍数(如16,32)。

这样编译器推断出内存对齐信息,启用更快的对齐加载指令,避免非对齐访问的性能惩罚。

(2)max_contiguous的作用

连续提示。

告知编译器索引是连续递增的,且最大连续长度是多少。

这样可以触发向量化加载,一条指令一次性搬运多个连续元素,显著减少指令数量,提高吞吐量。

(3)向量化加载的定义

利用硬件并行通道,单条指令同时从连续内存中加载多个同类型数据到寄存器。

20、后处理子块Epilogue Subtiling技术定义/目的

定义:将计算完成的输出大块逻辑拆分为多个小子块,利用TMA硬件 Gather能力直接写回全局内存。

目的:实现计算(类型转换)与通信(内存写入)的流水线重叠,并消除共享内存缓冲

(1)3个关键步骤

-算:完整块累加至寄存器(数据物理分散)

-拆:reshape -> permute -> split,仅改变逻辑视图,不移动物理数据

-存:子块分别to(dtype)转换 -> TMA.store 异步写入,TMA自动从分散寄存器Gather数据,组装成连续内存块写入

(2)为什么要拆分

-并行:拆除独立任务后,让计算和存储同时跑

-适配硬件:将大任务拆分为小任务,给TMA发独立指令,激活异步能力

-省共享内存:跳过共享内存中转,直接写显存

(3)能拆更多快吗?

拆更多块,调度开销大,管理复杂,收益递减。

相关推荐
前端摸鱼匠2 小时前
面试题4:多头注意力(MHA)相比单头注意力的优势是什么?Head数如何影响模型?
人工智能·ai·面试·职场和发展·求职招聘
yhdata2 小时前
车载图像处理芯片发展按下“快进键”:至2032年市场规模将逼近27.29亿元,产业动能强劲
图像处理·人工智能
NOCSAH2 小时前
统好AI数智平台CRM:智能驱动客户管理新体验
人工智能·数智化一体平台·统好ai
视***间2 小时前
2026:AI算力元年的加冕与思辨
人工智能·microsoft·机器人·边缘计算·智能硬件·视程空间
径硕科技JINGdigital2 小时前
B2B工业制造企业GEO供应商排名审视:以专业交付能力为核心的选型指南
大数据·人工智能·科技
Westward-sun.2 小时前
PyTorch入门实战:MNIST手写数字识别(全连接神经网络详解)
人工智能·pytorch·神经网络
大傻^2 小时前
Spring AI Alibaba Agent开发:基于ChatClient的智能体构建模式
java·数据库·人工智能·后端·spring·springaialibaba
F_U_N_2 小时前
轻量化开源知识库落地路径研究:AI赋能、多端集成及合规管理指引
人工智能·开源
丝斯20112 小时前
AI学习笔记整理(75)——Python学习4
人工智能·笔记·学习