Triton
Getting started with vector add
```py
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               z_ptr,  # *Pointer* to output vector.
               N,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program processes.
               ):
    # There are multiple 'programs' processing different data.
    # We identify which program we are here:
    pid = tl.program_id(axis=0)
    # Offsets is a list of the elements this program instance will act on,
    # e.g. if BLOCK_SIZE is 32 these would be
    # [0:32], [32:64], [64:96], etc., using `pid` to find the starting index.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds access.
    mask = offsets < N
    # Load x and y, using the mask.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    z = x + y
    # Write z back to HBM.
    tl.store(z_ptr + offsets, z, mask=mask)
```
As you can see, each `pid` is responsible for one BLOCK_SIZE-sized chunk of the vector. Many programs are launched at once; each one computes its own offsets, does the addition on its chunk, and stores the result back to HBM.
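The indexing scheme can be emulated in plain NumPy (a sketch for intuition only; the names `add_emulated` and the sequential loop are mine, whereas on the GPU all programs run concurrently):

```python
import numpy as np

def add_emulated(x, y, BLOCK_SIZE=32):
    # Emulate the kernel on the host: each "program" (pid) handles one
    # BLOCK_SIZE-sized slice, and the mask drops out-of-bounds lanes
    # in the final, possibly partial, block.
    N = len(x)
    z = np.empty_like(x)
    num_programs = -(-N // BLOCK_SIZE)  # ceiling division, like triton.cdiv
    for pid in range(num_programs):
        block_start = pid * BLOCK_SIZE
        offsets = block_start + np.arange(BLOCK_SIZE)  # tl.arange(0, BLOCK_SIZE)
        mask = offsets < N                             # guard the tail block
        idx = offsets[mask]
        z[idx] = x[idx] + y[idx]
    return z
```

Note that with N = 100 and BLOCK_SIZE = 32, four programs run and the last one only writes 4 of its 32 lanes.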
Next, we launch it from the host:
```py
def add(x: torch.Tensor, y: torch.Tensor):
    # Preallocate the output.
    z = torch.empty_like(x)
    N = z.numel()
    # grid can be a static tuple, or a callable that returns a tuple;
    # here it evaluates to (triton.cdiv(N, 1024),), i.e. ceil(N / 1024) programs.
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, z, N, BLOCK_SIZE=1024)
    return z
```
Although you pass in `torch.Tensor` objects, the function is decorated with `@triton.jit`, so each tensor argument is implicitly converted to a pointer to its first element, matching the kernel's signature.
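The grid size above comes from ceiling division. `triton.cdiv(n, d)` behaves like this pure-Python helper (a sketch of mine, included to show why a partial tail block still gets its own program):

```python
def cdiv(n, d):
    # Ceiling division: the smallest number of d-sized blocks covering n elements.
    return (n + d - 1) // d

# With N = 98432 and BLOCK_SIZE = 1024 the grid is (97,): program 96 covers
# elements 98304..98431, and the mask disables its remaining 896 lanes.
```

This is why the mask in the kernel is essential: without it, the last program would read and write past the end of the vectors.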