目前 Triton 主要支持 Linux 系统,并且需要拥有 NVIDIA GPU(通常要求 Compute Capability 7.0 及以上,即 Volta 架构以后,如 V100, RTX 20/30/40 系列)。
你可以使用 pip 快速安装:
bash
pip install triton
这里我们看Triton官方的第一个示例代码:向量加法
python
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr
):
pid=tl.program_id(axis=0)
block_start=pid*BLOCK_SIZE
offsets=block_start+tl.arange(0,BLOCK_SIZE)
mask=offsets<n_elements
x:tl.tensor = tl.load(x_ptr + offsets, mask=mask)
y:tl.tensor = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM.
# 将 x + y 写回 DRAM。
tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output.
# 需要预分配输出。
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel.
# SPMD 启动网格表示并行运行的内核实例的数量。
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
# 它类似于 CUDA 启动网格。它可以是 Tuple[int],也可以是 Callable(metaparameters) -> Tuple[int]。
# In this case, we use a 1D grid where the size is the number of blocks:
# 在这种情况下,使用 1D 网格,其中大小是块的数量:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
# NOTE:
# 注意:
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - 每个 torch.tensor 对象都会隐式转换为其第一个元素的指针。
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
# - `triton.jit` 函数可以通过启动网格索引来获得可调用的 GPU 内核。
# - Don't forget to pass meta-parameters as keywords arguments.
# - 不要忘记以关键字参数传递元参数。
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still running asynchronously at this point.
# 返回 z 的句柄,但由于 `torch.cuda.synchronize()` 尚未被调用,此时内核仍在异步运行。
return output
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
从 C++ CUDA 的视角来看,Triton 的核心逻辑其实是将 Thread-level (线程级) 的显式控制隐藏了,转而使用 SIMD-style (单指令多数据) 的向量化操作。
我们可以通过下面这个表格和详细拆解,将这段代码与 CUDA C++ 的概念一一对应:
| Triton 代码 / 概念 | CUDA C++ 对应概念 | 说明 |
|---|---|---|
@triton.jit |
__global__ void ... |
定义一个在 GPU 上启动的内核(Kernel)。 |
x_ptr |
float* x_ptr |
传入的指针,在内核内部被视为指向内存起始地址的基础。 |
BLOCK_SIZE: tl.constexpr |
template <int BLOCK_SIZE> |
编译期常量,类似于 CUDA 中的模板参数或宏定义。 |
tl.program_id(0) |
blockIdx.x |
每一个 Triton 程序实例(Instance)对应 CUDA 中的一个 Block。 |
tl.arange(0, BLOCK_SIZE) |
(无直接对应) | 创建一个长度为 BLOCK_SIZE 的局部索引序列。 |
tl.load / tl.store |
内存 Load/Store | Triton 会自动优化为内存对齐的向量化加载(Coalesced Access)。 |
@triton.jit
这本身是python的装饰器语法。语法部分暂时就不深究了,主要了解一下作用
对于习惯 C++ 编译流程(预处理 -> 编译 -> 汇编 -> 链接)的开发者来说,@triton.jit 是 Triton 的核心引擎 ,它把 Python 变成了一个真正的硬件编译器前端。
简单来说,它的作用是将你的 Python 函数拦截下来,将其转换成中间表示(IR),然后针对 GPU 硬件即时编译成二进制代码。
当你运行一个被 @triton.jit 装饰的函数时,发生的不是简单的 Python 解释执行,而是一套完整的编译链:
-
AST 解析:Triton 遍历该函数的 Python 抽象语法树(AST)。
-
生成 TTIR (Triton Intermediate Representation):将 Python 逻辑转换为 Triton 特有的中间表示。这一步是硬件无关的。
-
生成 LLVM IR:将 TTIR 转换为 LLVM 的中间表示,进行通用的编译器优化(如死代码消除、循环展开)。
-
生成 PTX / CUBIN :针对你的 NVIDIA GPU 架构(如 RTX 3090 的
sm_86),生成底层的汇编指令(PTX)和二进制镜像(CUBIN)。
为什么要用装饰器?(动态编译的威力)
这种 JIT 方式比 CUDA C++ 的 AOT(Ahead-of-Time)编译有巨大的灵活性优势:
-
特化编译 (Specialization) : 内核会根据你传入的常量参数 (如
BLOCK_SIZE)和数据类型 (如fp32还是fp16)生成不同的二进制版本。-
如果你调用
add_kernel[grid](..., BLOCK_SIZE=256),它编一个版本。 -
如果你调用
add_kernel[grid](..., BLOCK_SIZE=1024),它又编一个优化过的版本。
-
-
消除参数传递开销 : 因为
BLOCK_SIZE是tl.constexpr,编译器在编译时就直接把它硬编码到指令里了,不需要像 C++ 那样在运行时从寄存器读取。
装饰器内部持有的"超能力"
当你装饰了一个函数后,原本的 Python 函数对象就被包装成了一个 JITFunction 对象。这个对象拥有几个开发者非常喜欢的属性:
| 属性/方法 | 作用 |
|---|---|
.get_ir() |
查看生成的中间表示(TTIR),这类似于查看 C++ 的汇编中间件。 |
.get_ptx() |
这是最直接的! 你可以直接打印出 PTX 代码,看看它生成的指令是不是你想要的(比如有没有用到 ld.global.nc)。 |
| 缓存机制 | Triton 会把编译好的 .cubin 缓存在磁盘(通常是 ~/.triton/cache)。下次运行相同参数时,秒开,无需重新编译。 |
为什么不能在里面写普通的 Python?
由于 @triton.jit 负责把代码翻译成 GPU 指令,你不能在被装饰的函数内部调用非 Triton 提供的 Python 库:
-
❌ 不行 :
import numpy; numpy.sin(x)(CPU 库无法在 GPU 内核运行)。 -
❌ 不行 :使用
list或dict(GPU 显存里没有这些高级数据结构)。 -
✅ 可以 :使用
triton.language(即tl) 提供的函数。这些函数都有对应的硬件映射。
总结:CUDA 视角下的 @triton.jit
你可以把它理解为 "带自动 JIT 触发功能的 C++ 模板"。
cpp
// 它的效果等价于:
template <typename T, int BLOCK_SIZE>
__global__ void kernel(T* ptr) { ... }
// 当你在 Python 里调用时:
// kernel[grid](tensor, BLOCK_SIZE=512)
// 相当于:
// 1. 检查是否有 kernel<float, 512> 的实例?
// 2. 没有则立即调用 nvcc 现场编一个。
// 3. 运行它。
这种设计让你既能享受 Python 快速迭代的快感(改个参数直接运行),又能获得 C++ 级别的执行效率。
triton.language.tensor
以下简写为tl.tensor
在 Triton 中,tl.tensor 是一个逻辑抽象 。它代表的是一个存储在寄存器(Register)中的多维数据块。
对于习惯 C++ 和 CUDA 的开发者,理解 tl.tensor 的关键在于:它不是指向显存的指针,而是已经被加载到计算单元内部的一组数据值。
以下是它的核心特性:
它是"分块"的(Block-based)
在 CUDA C++ 中,你的基本操作对象通常是 float(标量)。而在 Triton 中,基本操作对象就是 tl.tensor。
-
特性:它必须是 2 的幂次方大小(如 128, 256, 512...)。这是为了匹配 GPU 硬件的 SIMD/SIMT 架构。
-
C++ 类比 :你可以把它想象成 C++ 里的
std::array<float, 1024>,但在硬件层面,它直接映射到了一组向量寄存器上。
它是"隐式并行"的(Implicitly Parallel)
当你对两个 tl.tensor 执行加法时:
python
output = x + y # x 和 y 都是 tl.tensor
这行代码在硬件上会自动并行化。Triton 编译器会将这个操作分配给该 Program(即 CTA)内部的所有线程。
- 特性 :你不需要写
for循环,也不需要写threadIdx。所有的算术运算(+,-,*,/,tl.exp,tl.dot)都是在整个 Tensor 块上同步执行的。
强力类型转换与精度控制
作为 AI Infra 开发者,你肯定关心精度(FP32, FP16, BF16)。tl.tensor 支持非常方便的类型转换,这对编写高性能 Kernels(如混合精度算子)至关重要。
python
x = tl.load(ptr) # 假设读取的是 FP32
x_half = x.to(tl.float16) # 类似于 static_cast<half>(x)
广播机制(Broadcasting)
tl.tensor 遵循类似 NumPy 的广播规则。这在处理 Bias(偏置)或者归一化(Normalization)时极其好用。
- 例子:如果你有一个 [128, 128] 的矩阵 Tensor 和一个 [128, 1] 的列向量 Tensor 相加,Triton 会自动在寄存器层面处理广播逻辑。
它是"无地址"的(Addressless)
这一点对 C++ 开发者来说最容易混淆:
-
x_ptr:是一个地址。 -
x = tl.load(x_ptr):执行完这行后,x就是一个tl.tensor。它此时已经离开了显存(DRAM),住进了 寄存器 里。 -
特性 :你不能对
x进行取地址操作,因为它已经不是内存里的东西了。
静态维度约束(Static Shapes)
虽然 Python 是动态的,但 Triton 的 tl.tensor 维度在编译时(JIT 阶段)必须是确定的(即 tl.constexpr)。
- 如果你的
BLOCK_SIZE是 1024,那么tl.arange(0, BLOCK_SIZE)生成的 Tensor 形状就是固定的。这使得编译器可以生成极其精简的汇编代码,省去了运行时的形状检查。
triton.language.constexpr
对应 C++ CUDA 的概念里,tl.constexpr 对应的不是变量,而是 模板参数(Template Parameters) 或 宏定义(Macros)。
tl.constexpr 告诉 Triton 编译器:"这个值在编译期就已经是确定的常量了。"
-
编译期求值 :当 Triton 的 JIT 编译器在处理
@triton.jit装饰的函数时,它会为不同的constexpr值生成不同的硬件二进制文件 (CUBIN)。 -
死代码消除 :如果你在代码里写
if BLOCK_SIZE > 512:,而传入的constexpr是 256,编译器在生成 PTX 指令时会直接把这个分支删掉,完全没有运行时的if开销。 -
寄存器分配的依据:编译器必须知道具体的数值(如 128, 256),才能决定一个 Thread Block 到底要占用多少个寄存器,以及如何分配 Shared Memory。
cpp
// Triton
BLOCK_SIZE: tl.constexpr
// 对应 C++ CUDA
template <int BLOCK_SIZE>
__global__ void kernel(...) { ... }
BLOCK_SIZE 指定的是什么大小?
在 Triton 的典型语境下,BLOCK_SIZE 指定的是 一个 Program (即一个 CUDA CTA/Block) 一次性处理的元素个数。
我们可以从三个层面来解剖这个"大小":
A. 逻辑层面:Tile Size(分块大小)
它定义了你那个"超级向量"的长度。
-
如果你设置
BLOCK_SIZE = 1024,那么tl.arange(0, BLOCK_SIZE)就会生成一个长度为 1024 的索引向量。 -
这意味着你接下来的
tl.load、tl.store和算术运算,都是以 1024 个元素为一组进行的。
B. 内存层面:Memory Coalescing(访存对齐)
它决定了访存的粒度。
- GPU 的显存带宽在连续访问时效率最高。
BLOCK_SIZE通常设为 128 或更大,这样 Triton 编译器就能一次性发出长向量加载指令(如LDG.E.128),完美对齐显存位宽。
C. 硬件映射层面(最关键)
虽然你指定的是元素个数,但它间接决定了硬件线程的负载。
-
线程数推导 :Triton 还有一个默认参数叫
num_warps(通常是 4 或 8)。-
如果
BLOCK_SIZE = 1024且num_warps = 8: -
总线程数 =
个线程。
-
每个线程的负载 =
。即:每个线程的寄存器里存了 4 个元素。
-
triton.language.program_id
tl.program_id(axis) 在底层语义上几乎完全等同于 CUDA 中的 blockIdx。
python
@builtin
def program_id(axis, _semantic=None):
"""
Returns the id of the current program instance along the given :code:`axis`.
:param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
:type axis: int
"""
# if axis == -1:
# pid0 = _semantic.program_id(0)
# pid1 = _semantic.program_id(1)
# pid2 = _semantic.program_id(2)
# npg0 = _semantic.num_programs(0)
# npg1 = _semantic.num_programs(1)
# return pid0 + pid1*npg0 + pid2*npg0*npg1
axis = _unwrap_if_constexpr(axis)
return _semantic.program_id(axis)
| Triton | CUDA C++ | 说明 |
|---|---|---|
tl.program_id(0) |
blockIdx.x |
获取当前程序实例在 X 轴的 ID |
tl.program_id(1) |
blockIdx.y |
获取当前程序实例在 Y 轴的 ID |
tl.program_id(2) |
blockIdx.z |
获取当前程序实例在 Z 轴的 ID |
tl.num_programs(0) |
gridDim.x |
获取该轴上一共有多少个程序实例 |
如果axis==-1,那么返回的该CTA在全局的blockIdx编号
triton.language.arange
python
@builtin
def arange(start, end, _semantic=None):
start = _unwrap_if_constexpr(start)
end = _unwrap_if_constexpr(end)
return _semantic.arange(start, end)
arange.__doc__ = f"""
Returns contiguous values within the half-open interval :code:`[start,
end)`. :code:`end - start` must be less than or equal to
:code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}`
:param start: Start of the interval. Must be a power of two.
:type start: int32
:param end: End of the interval. Must be a power of two greater than
:code:`start`.
:type end: int32
"""
| 维度 | 解释 |
|---|---|
| 函数作用 | 在寄存器中生成一个包含连续整数的 1D Tensor。这个 Tensor 通常被用作"基础索引",后续通过加上偏移量来访问内存。 |
参数 start |
区间的起点(闭区间)。在实际硬件映射中,通常设为 0。注意:源码要求其为 2 的幂。 |
参数 end |
区间的终点(开区间)。它决定了生成的 Tensor 的长度(即 BLOCK_SIZE)。注意:源码要求其为 2 的幂且大于 start。 |
| 返回值 | 一个形状为 (end - start,) 的 tl.tensor ,类型通常为 int32。 |
按我们之前理解的,tl.tensor 是由每个线程寄存器合成的一个向量。所以arange返回的向量中,每个线程寄存器保存了本线程要处理的(end-start)范围中元素所对应的下标索引。同理,一个线程可能保存多个所要处理的元素下标。
triton.language.load
python
@builtin
def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
volatile=False, _semantic=None):
"""
Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
(1) If `pointer` is a single element pointer, a scalar is be loaded. In
this case:
- `mask` and `other` must also be scalars,
- `other` is implicitly typecast to `pointer.dtype.element_ty`, and
- `boundary_check` and `padding_option` must be empty.
(2) If `pointer` is an N-dimensional tensor of pointers, an
N-dimensional tensor is loaded. In this case:
- `mask` and `other` are implicitly broadcast to `pointer.shape`,
- `other` is implicitly typecast to `pointer.dtype.element_ty`, and
- `boundary_check` and `padding_option` must be empty.
(3) If `pointer` is a block pointer defined by `make_block_ptr`, a
tensor is loaded. In this case:
- `mask` and `other` must be `None`, and
- `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.
:param pointer: Pointer to the data to be loaded
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
:param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
(must be `None` with block pointers)
:type mask: Block of `triton.int1`, optional
:param other: if `mask[idx]` is false, return `other[idx]`
:type other: Block, optional
:param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
:type boundary_check: tuple of ints, optional
:param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
:param cache_modifier: changes cache option in NVIDIA PTX
:type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
and ".cv" means don't cache and fetch again. see
`cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
:param eviction_policy: changes eviction policy in NVIDIA PTX
:type eviction_policy: str, optional
:param volatile: changes volatile option in NVIDIA PTX
:type volatile: bool, optional
"""
# `mask` and `other` can be constexpr
mask = _unwrap_if_constexpr(mask)
other = _unwrap_if_constexpr(other)
if mask is not None:
mask = _semantic.to_tensor(mask)
if other is not None:
other = _semantic.to_tensor(other)
padding_option = _unwrap_if_constexpr(padding_option)
cache_modifier = _unwrap_if_constexpr(cache_modifier)
eviction_policy = _unwrap_if_constexpr(eviction_policy)
volatile = _unwrap_if_constexpr(volatile)
return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
volatile)
1. 三种加载模式(The Three Modes)
注释里提到的 (1)(2)(3) 对应了从底层到高层的三种用法:
| 模式 | 描述 | CUDA 对应概念 | 适用场景 |
|---|---|---|---|
| (1) 标量加载 | pointer 是一个单地址 |
*ptr (直接取值) |
加载全局配置、单个阈值。 |
| (2) 指针张量加载 | pointer 是一组地址 (N维 Tensor) |
Vectorized Load (LDG) | 最常用 。结合 arange 实现连续或不连续的向量加载。 |
| (3) 块指针加载 | 使用 make_block_ptr 定义的块 |
TMA (Tensor Memory Accelerator) | 高级模式,专门针对矩阵运算优化,自动处理越界。 |
add_kernel 用的是模式 (2) :x_ptr + offsets 产生了一个"指针 Tensor",然后 tl.load 一次性把这组地址对应的值全取回来。
2. 核心控制参数:安全与逻辑
-
mask(掩码):-
作用 :决定哪些位置需要加载。如果
mask[idx]为False,对应的内存地址不会被访问。 -
重要性 :这不仅是逻辑问题,更是安全性 问题。在 CUDA 里访问
ptr[N+1]会报错(Illegal Memory Access),在 Triton 里我们靠mask来挡住这些非法访问。
-
-
other(填充值):-
作用 :当
mask为False时,寄存器里该位置存什么? -
典型用法 :在做
tl.sum时,越界位置可以填0;做tl.max时,可以填-inf。
-
3. 专家级性能开关(NVIDIA PTX 特性映射)
这些参数直接暴露了 Triton 作为"显卡编译器"的本质,它们会直接改写生成的 PTX 指令:
-
cache_modifier(缓存策略):-
.ca: 缓存到所有级别(L1 + L2)。 -
.cg: 仅缓存至全局级别(跳过 L1,直接进 L2)。当你确定数据不会被重复使用时,用这个可以防止 L1 缓存污染。 -
.cv: 不缓存(Volatile),每次都从显存读。
-
-
eviction_policy(逐出策略):- 控制缓存行(Cache Line)被换出的优先级。例如,可以设置为
evict_first或evict_last来优化复杂的流式数据处理。
- 控制缓存行(Cache Line)被换出的优先级。例如,可以设置为
-
volatile:- 类似于 C++ 的
volatile关键字,确保不从寄存器或缓存缓存读取旧值,强制内存同步。
- 类似于 C++ 的
triton.language.store
python
@_tensor_member_fn
@builtin
def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None):
"""
Store a tensor of data into memory locations defined by `pointer`.
(1) If `pointer` is a single element pointer, a scalar is stored. In
this case:
- `mask` must also be scalar, and
- `boundary_check` and `padding_option` must be empty.
(2) If `pointer` is an N-dimensional tensor of pointers, an
N-dimensional block is stored. In this case:
- `mask` is implicitly broadcast to `pointer.shape`, and
- `boundary_check` must be empty.
(3) If `pointer` is a block pointer defined by `make_block_ptr`, a block
of data is stored. In this case:
- `mask` must be None, and
- `boundary_check` can be specified to control the behavior of out-of-bound access.
`value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
:param pointer: The memory location where the elements of `value` are stored
:type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
:param value: The tensor of elements to be stored
:type value: Block
:param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]`
:type mask: Block of triton.int1, optional
:param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
:type boundary_check: tuple of ints, optional
:param cache_modifier: changes cache option in NVIDIA PTX
:type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for
cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt"
stands for cache write-through, see `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
:param eviction_policy: changes eviction policy in NVIDIA PTX
:type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
"""
# `value` can be constexpr
value = _semantic.to_tensor(value)
mask = _unwrap_if_constexpr(mask)
if mask is not None:
mask = _semantic.to_tensor(mask)
cache_modifier = _unwrap_if_constexpr(cache_modifier)
eviction_policy = _unwrap_if_constexpr(eviction_policy)
return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
1. 三种存储模式(与 Load 一一对应)
| 模式 | 描述 | 硬件行为 |
|---|---|---|
| (1) 标量存储 | pointer 是单地址 |
向特定地址写入一个数值。 |
| (2) 指针张量存储 | pointer 是一组地址 Tensor |
Vectorized Store 。这是最常用的模式,配合 offsets 实现一整块寄存器数据的写回。 |
| (3) 块指针存储 | 配合 make_block_ptr |
使用 TMA (Tensor Memory Accelerator) 硬件单元进行写回,自动处理 2D 边界。 |
2. 写操作的"守卫":mask
-
作用 :在 CUDA 中,如果你尝试往
ptr[N+1]写数据,程序会直接 Crash(或者静默报错导致数据污染)。 -
Triton 逻辑 :如果
mask[idx]为False,对应的内存地址绝对不会被改写。
3. 写回策略:cache_modifier(性能调优的核心)
tl.store 暴露了 PTX 级别的控制:
-
.wb(Write-Back, 回写):默认模式。数据先写到缓存(L1/L2),等缓存行被替换时再写回显存。 -
.cg(Cache Global):跳过 L1,直接写到 L2。如果你确信这些输出数据在当前 Kernel 结束后不会被立刻读取,用这个可以保护 L1 缓存不被"污染"。 -
.cs(Cache Streaming):流式缓存。标记这块数据很快就会被丢弃,让 L2 缓存优先回收这些空间。 -
.wt(Write-Through, 直写):数据同时写入缓存和显存。这通常用于强同步场景,开销很大。
作为习惯 C++ 的开发者,在使用 tl.store 时要记住:
-
原子性 :
tl.store本身不是原子的。如果两个 Program (CTA) 往同一个地址写,结果是未定义的(Race Condition)。这时候你需要的是tl.atomic_add。 -
合并写入 (Coalescing) :和
load一样,如果你的pointer里的地址是连续的(比如由arange生成),硬件能实现合并写入,带宽利用率最高。 -
类型匹配 :
value会被自动强制转换成pointer所指向的类型(隐式typecast)。如果你的指针是fp16但你写的是fp32的计算结果,这里会发生精度截断。
启动
在cuda中,我们一般直接给 gridSize 传一个整数。但在 Triton 里,grid 往往是一个"在发射瞬间才计算结果的函数"。
1. meta 是什么?
在 Triton 的语境下,meta 指的是 "元参数字典(Meta-parameters Dictionary)"。
当你调用 add_kernel[grid](..., BLOCK_SIZE=1024) 时,方括号里的 BLOCK_SIZE=1024 不仅仅是传给 Kernel 的常量,它还会被 Triton 收集起来,放进一个叫 meta 的 Python 字典里。
- 此时,
meta的内容实际上就是:{"BLOCK_SIZE": 1024}。
只有显式通过 KEY=VALUE(关键字参数)传递的东西,或者在函数签名中被标记为特殊身份(如 tl.constexpr)的东西,才会被 Triton 塞进这个 meta 字典里。
在 Triton 中,函数签名里的 : tl.constexpr 就像是 C++ 里的 template<int BLOCK_SIZE>。
当你执行 add_kernel[grid](..., BLOCK_SIZE=1024) 时:
-
Triton 检查签名 :它发现
BLOCK_SIZE被标注为了tl.constexpr。 -
强制入库 :所有被标注为
constexpr的参数,无论你是怎么传进去的,都会被视为"元参数(Metadata)"。 -
关键字参数(Kwargs)优先 :Triton 的 Launcher 会优先把所有通过
KEY=VALUE形式传递的参数放进meta字典,供grid函数使用。
在例子里:
python
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
你会发现 n_elements 并没有写成 meta['n_elements'],而是直接写了变量名。
这是 Python 的"闭包(Closure)"特性:
-
n_elements是在add函数作用域里定义的变量。 -
lambda函数可以"看见"并"捕获"它外面的变量。 -
所以
n_elements不需要通过meta传递,它已经在lambda的口袋里了。
但是 BLOCK_SIZE 不行 :因为 BLOCK_SIZE 是在调用 add_kernel 的那一刻才确定的(甚至可能由 Autotuner 动态决定),所以它必须通过 meta 这个"官方信使"来传递。
2. 为什么写成 lambda 而不直接写个数字?
这是为了实现 "动态绑定" 和 "自动调优(Autotune)"。
-
CUDA 方式(静态) :你必须先在外面算好
int gridSize = (N + 1023) / 1024,然后再传给<<<gridSize, ...>>>。 -
Triton 方式(延迟计算) :Triton 允许
grid依赖于代码里的tl.constexpr变量。triton.cdiv(n_elements, meta['BLOCK_SIZE'])的意思是:"请去 meta 字典里查一下当前的BLOCK_SIZE是多少,然后用总元素量除以它,算出我需要多少个 Program 实例。"
这样做的好处: 如果你以后用 Triton 的 @triton.autotune 功能,让它在 BLOCK_SIZE 为 512, 1024, 2048 之间自动选最快的,这个 grid 表达式就不需要改动,它会自动根据不同的 BLOCK_SIZE 算出正确的 Grid 大小。
3. 返回值为什么是元组 (..., )?
这对应了 CUDA 的 dim3(x, y, z):
-
Triton 的 Grid 可以是 1D、2D 或 3D 的。
-
即使是 1D 的,也必须返回一个元组(Python 中单元素元组要加逗号,如
(val, ))。 -
如果我们要写一个处理矩阵的 2D Grid,就会写成:
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N']))
4. 整个发射流程的"分步动作"
当你执行 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) 时,后台发生了这些事:
-
收集元参数 :Triton 看到
BLOCK_SIZE=1024。 -
调用 Lambda :Triton 调用
grid({"BLOCK_SIZE": 1024})。 -
计算网格 :Lambda 返回
(97, )(假设 98432 / 1024 \\approx 97)。 -
寻找/编译内核 :Triton 去找一个"专门为
BLOCK_SIZE=1024编译好的二进制文件"。 -
正式发射:调用 CUDA Driver API,启动 97 个 Blocks。
5. C++ 视角下的等价替换
如果你觉得 lambda 太绕,这段代码在逻辑上完全等价于:
python
# 1. 定义一个普通函数来计算网格
def calculate_grid(meta):
# meta 是由 Triton 传入的一个字典
num_blocks = triton.cdiv(n_elements, meta['BLOCK_SIZE'])
return (num_blocks, )
# 2. 传给启动器
add_kernel[calculate_grid](x, y, output, n_elements, BLOCK_SIZE=1024)