Introduction
For an introduction to this project, see the earlier article:
https://mp.weixin.qq.com/s/4eom3UmHl4pYjeAJc4L6CA

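1. Constant add
Add a constant to a tensor. The block size B0 equals the tensor length N0.
A single block is enough, so the kernel does not need to read the block id. Build an array of B0 pointers starting at `x_ptr` and load the tensor they point to. Add the scalar 10.0, which is broadcast automatically to the shape of x, and finally write the result to the tensor at `z_ptr`.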
```py
# Imports shared by all the puzzle snippets in this post.
import torch
import triton
import triton.language as tl
from jaxtyping import Float32

r"""
## Puzzle 1: Constant Add
Add a constant to a vector. Uses one program id axis.
Block size `B0` is always the same as vector `x` with length `N0`.
.. math::
    z_i = 10 + x_i \text{ for } i = 1\ldots N_0
"""


def add_spec(x: Float32[32,]) -> Float32[32,]:
    "This is the spec that you should implement. Uses typing to define sizes."
    return x + 10.0


@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    # We name the offsets of the pointers as "off_"
    off_x = tl.arange(0, B0)
    x = tl.load(x_ptr + off_x)
    # Finish me!
    x += 10.0
    tl.store(z_ptr + off_x, x)
    return
```
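To make the pointer arithmetic concrete, here is a minimal NumPy sketch of what the single program does. It is only an emulation under assumptions: `memory` is a hypothetical flat array standing in for device memory, and `x_base` and `z_base` are made-up integer offsets playing the role of `x_ptr` and `z_ptr`. None of these names are Triton API.

```python
import numpy as np

N0 = B0 = 8                                   # block size equals vector length
memory = np.zeros(2 * N0, dtype=np.float32)   # hypothetical flat "device memory"
x_base, z_base = 0, N0                        # stand-ins for x_ptr and z_ptr
memory[x_base:x_base + N0] = np.arange(N0, dtype=np.float32)

off_x = np.arange(B0)                # plays the role of tl.arange(0, B0)
x = memory[x_base + off_x]           # gather, like tl.load(x_ptr + off_x)
memory[z_base + off_x] = x + 10.0    # scalar broadcasts, then scatter like tl.store

z = memory[z_base:z_base + N0]
```

Adding an integer base to an array of offsets gives an array of addresses, which is the same idea as `x_ptr + off_x` in the kernel.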
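2. Constant add, blocked
The difference from puzzle 1 is that B0 is smaller than N0, so each block handles only part of the tensor.
First get the block id with `pid = tl.program_id(0)`, then use it to compute the start of the current block and build an array of B0 pointers beginning there.
The last block may run past N0, so set a mask that keeps only offsets below N0. The rest is the same as puzzle 1.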
```py
r"""
## Puzzle 2: Constant Add Block
Add a constant to a vector. Uses one program block axis (no `for` loops yet).
Block size `B0` is now smaller than the shape vector `x` which is `N0`.
.. math::
    z_i = 10 + x_i \text{ for } i = 1\ldots N_0
"""


def add2_spec(x: Float32[200,]) -> Float32[200,]:
    return x + 10.0


@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid = tl.program_id(0)
    off_x = tl.arange(0, B0) + pid * B0
    mask = off_x < N0
    x = tl.load(x_ptr + off_x, mask)
    # Finish me!
    x += 10.0
    tl.store(z_ptr + off_x, x, mask)
    return
```
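The blocking and masking logic can be checked on the CPU with a small NumPy emulation. This is a sketch, not Triton code: the hypothetical helper `add_blocked` replaces the parallel launch with a plain loop over program ids, and it assumes the grid has ceil(N0 / B0) programs.

```python
import numpy as np

def add_blocked(x: np.ndarray, B0: int) -> np.ndarray:
    """Emulate the blocked kernel: one loop iteration per program id."""
    N0 = x.shape[0]
    z = np.zeros_like(x)
    num_blocks = (N0 + B0 - 1) // B0      # assumed grid size: ceil(N0 / B0)
    for pid in range(num_blocks):         # stands in for tl.program_id(0)
        off_x = np.arange(B0) + pid * B0  # offsets covered by this block
        mask = off_x < N0                 # False for lanes past the end
        valid = off_x[mask]               # masked lanes touch no memory
        z[valid] = x[valid] + 10.0
    return z

x = np.arange(200, dtype=np.float32)
z = add_blocked(x, B0=32)
```

With N0 = 200 and B0 = 32 there are 7 blocks, and the last one covers offsets 192 to 223, of which only 192 to 199 pass the mask.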
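3. Outer vector add
An outer add combines two vectors into a matrix: each vector is expanded according to the broadcasting rules and the results are added. For example, adding a tensor of shape (x, 1) to one of shape (1, y) gives a result of shape (x, y).
Here both block sizes equal the full vector lengths, so no blocking is needed. Loading the two tensors works as before: build one pointer array for each, of length B0 and B1 respectively.
Then expand the two vectors along different dimensions and add them, letting broadcasting do the expansion: `z = x[None, :] + y[:, None]`. `None` inserts a new dimension of length 1, and the colon keeps the existing dimension unchanged. At this point the outer add result exists, but only in a temporary tensor. It still has to be written to the tensor at `z_ptr`.
That requires a two-dimensional offset array, `off_z = off_y[:, None] * B0 + off_x[None, :]`, again built with broadcasting. Note that y indexes the rows and x indexes the columns. This follows the convention of the Cartesian plane, where x runs horizontally and y runs vertically, rather than the usual two-dimensional array convention.
First add a dimension to the y offsets, giving
[[0],
[1],
[2]]
Then multiply by the length of the x dimension, that is, the row length B0. Taking B0 = 3 as an example:
[[0],
[3],
[6]]
Next add a dimension to the x offsets as well, giving
[[0, 1, 2]]
Adding the two, broadcasting produces the full offset array:
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]
Finally, adding `z_ptr` broadcasts the base pointer to every position of this array, which gives the array of destination addresses used in `tl.store(z_ptr + off_z, z)`.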
```py
r"""
## Puzzle 3: Outer Vector Add
Add two vectors.
Uses one program block axis. Block size `B0` is always the same as vector `x` length `N0`.
Block size `B1` is always the same as vector `y` length `N1`.
.. math::
    z_{j, i} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1
"""


def add_vec_spec(x: Float32[32,], y: Float32[32,]) -> Float32[32, 32]:
    return x[None, :] + y[:, None]


@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Finish me!
    off_x = tl.arange(0, B0)
    off_y = tl.arange(0, B1)
    x = tl.load(x_ptr + off_x)
    y = tl.load(y_ptr + off_y)
    z = x[None, :] + y[:, None]
    # off_z = tl.arange(0, B0 * B1).reshape(B1, B0)
    off_z = off_y[:, None] * B0 + off_x[None, :]
    tl.store(z_ptr + off_z, z)
    return
```
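The offset construction used in the kernel above can be reproduced in NumPy, whose `None` indexing and broadcasting behave the same way for this purpose. The sketch below assumes B0 = B1 = 3 and uses a hypothetical flat buffer `flat` in place of the memory behind `z_ptr`.

```python
import numpy as np

B0, B1 = 3, 3
off_x = np.arange(B0)
off_y = np.arange(B1)

# Column of row starts, shape (B1, 1), plus row of column indices, shape (1, B0).
off_z = off_y[:, None] * B0 + off_x[None, :]

# The same broadcast applied to the values gives the outer add.
x = np.array([1.0, 2.0, 3.0])
y = np.array([10.0, 20.0, 30.0])
z = x[None, :] + y[:, None]

# Scatter through the offsets into a flat buffer, like tl.store(z_ptr + off_z, z).
flat = np.zeros(B0 * B1)
flat[off_z] = z
```

Reading `flat` back as a (B1, B0) row-major matrix recovers `z`, which confirms that `off_z` enumerates the row-major positions 0 to 8.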
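4. Outer vector add, blocked
Building on puzzle 3, B0 is now smaller than N0, so blocking is needed. Follow the same change as from puzzle 1 to puzzle 2: read the block id, use it to find the start of the region this block handles, and add that start to the offsets, `off_x = tl.arange(0, B0) + block_id_x * B0`.
Because the last block may run past N0, the load needs a `< N0` mask: `x = tl.load(x_ptr + off_x, off_x < N0)`. The store needs a mask too, and since the write is two-dimensional it needs a mask in both dimensions: `(off_x[None, :] < N0) & (off_y[:, None] < N1)`.
The row stride in `off_z` also changes from B0 to N0, because each block now writes into the full N1 by N0 output. The rest of the logic is the same as puzzle 3.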
```py
r"""
## Puzzle 4: Outer Vector Add Block
Add a row vector to a column vector.
Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.
.. math::
    z_{j, i} = x_i + y_j\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1
"""


def add_vec_block_spec(x: Float32[100,], y: Float32[90,]) -> Float32[90, 100]:
    return x[None, :] + y[:, None]


@triton.jit
def add_vec_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = tl.arange(0, B0) + block_id_x * B0
    off_y = tl.arange(0, B1) + block_id_y * B1
    x = tl.load(x_ptr + off_x, off_x < N0)
    y = tl.load(y_ptr + off_y, off_y < N1)
    z = x[None, :] + y[:, None]
    off_z = off_y[:, None] * N0 + off_x[None, :]
    tl.store(z_ptr + off_z, z, (off_x[None, :] < N0) & (off_y[:, None] < N1))
    # Finish me!
    return
```
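As a CPU cross-check of the two-dimensional blocking, here is a NumPy emulation that loops over both program id axes. It is a sketch under assumptions: `outer_add_blocked` is a hypothetical helper, the grid is assumed to be ceil(N0 / B0) by ceil(N1 / B1), and masked lanes are given a dummy value of 0 on load.

```python
import numpy as np

def outer_add_blocked(x, y, B0, B1):
    """Emulate the 2D-blocked outer add with one loop iteration per program."""
    N0, N1 = x.shape[0], y.shape[0]
    z_flat = np.zeros(N1 * N0, dtype=x.dtype)       # row-major (N1, N0) output
    for pid_y in range((N1 + B1 - 1) // B1):        # stands in for tl.program_id(1)
        for pid_x in range((N0 + B0 - 1) // B0):    # stands in for tl.program_id(0)
            off_x = np.arange(B0) + pid_x * B0
            off_y = np.arange(B1) + pid_y * B1
            mask_x, mask_y = off_x < N0, off_y < N1
            # Clamp indices so NumPy never reads out of bounds, then zero masked lanes.
            xv = np.where(mask_x, x[np.minimum(off_x, N0 - 1)], 0)
            yv = np.where(mask_y, y[np.minimum(off_y, N1 - 1)], 0)
            z = xv[None, :] + yv[:, None]
            off_z = off_y[:, None] * N0 + off_x[None, :]   # row stride is N0, not B0
            mask_z = mask_x[None, :] & mask_y[:, None]
            z_flat[off_z[mask_z]] = z[mask_z]              # masked scatter
    return z_flat.reshape(N1, N0)

x = np.arange(100, dtype=np.float32)
y = np.arange(90, dtype=np.float32) * 0.5
z = outer_add_blocked(x, y, B0=32, B1=32)
```

The x part of the store mask matters even when the flat offset is still inside the buffer: a column offset at or beyond N0 would otherwise land in the next row and overwrite valid data.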
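5. Outer vector product with fused ReLU
Change the addition in puzzle 4 to a multiplication and apply ReLU element-wise to the result. ReLU is a simple operation, so it can be fused with the outer product without storing the intermediate result.
Starting from puzzle 4, replace the addition with a multiplication. ReLU keeps values greater than zero unchanged and sets everything else to 0, which amounts to an element-wise if-else. As in most GPU programming DSLs, an if-else statement cannot branch per element of a tensor, so use `z = tl.where(z > 0, z, 0.0)` instead. The first argument is a boolean expression. Where it is true the result takes the second argument, otherwise the third. The second and third arguments may be scalars, which are broadcast to the tensor shape.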
```py
r"""
## Puzzle 5: Fused Outer Multiplication
Multiply a row vector to a column vector and take a relu.
Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.
.. math::
    z_{j, i} = \text{relu}(x_i \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1
"""


def mul_relu_block_spec(x: Float32[100,], y: Float32[90,]) -> Float32[90, 100]:
    return torch.relu(x[None, :] * y[:, None])


@triton.jit
def mul_relu_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = tl.arange(0, B0) + block_id_x * B0
    off_y = tl.arange(0, B1) + block_id_y * B1
    x = tl.load(x_ptr + off_x, off_x < N0)
    y = tl.load(y_ptr + off_y, off_y < N1)
    z = x[None, :] * y[:, None]
    z = tl.where(z > 0, z, 0.0)
    off_z = off_y[:, None] * N0 + off_x[None, :]
    tl.store(z_ptr + off_z, z, (off_x[None, :] < N0) & (off_y[:, None] < N1))
    # Finish me!
    return
```
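The element-wise select can be tried out in NumPy, where `np.where` takes its arguments in the same order as the `tl.where` call in the kernel. This is only an illustration on a small hand-picked array, not Triton code.

```python
import numpy as np

z = np.array([[-2.0, 0.0, 3.0],
              [1.5, -0.5, 4.0]])

# condition, value where true, value where false (the scalar 0.0 is broadcast)
relu_where = np.where(z > 0, z, 0.0)

# An equivalent formulation of ReLU, for comparison.
relu_max = np.maximum(z, 0.0)
```

Both forms give the same result. The select form is the more general one, since the two branches can be arbitrary values, as the backward pass in the next puzzle shows.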
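6. Matrix-vector multiply with fused ReLU, backward pass
Building on puzzle 5, replace the vector-vector product with a matrix multiplied by a vector, which triggers broadcasting in a similar way.
Given the upstream gradient of the loss, dz, we want the gradient of the loss with respect to x. The forward expression is $\text{relu}(x_{j,i} \times y_j)$.
Its partial derivative with respect to $x_{j,i}$ is
$\text{relu}'(x_{j,i} \times y_j) \times y_j$
and the derivative of $\text{relu}(x)$ is 1 when x is greater than 0 and 0 otherwise.
So we need to
- compute $x_{j,i} \times y_j$
- check its sign to get the derivative of relu
- multiply by y
- multiply by the upstream gradient dz
which gives dx, the gradient with respect to x.
The code is organized a bit more systematically this time: all ranges are computed in one stage, then all masks in the next. Because the kernel loads blocks of the two-dimensional tensors x and dz, it needs a two-dimensional mask and a two-dimensional offset, `mask_ji = (mask_i[None, :] & mask_j[:, None])`.
For the matrix-vector product, the vector y has to be broadcast: `z = x * y[:, None]`.
The relu derivative needs an if-else: `z = tl.where(z > 0, 1.0, 0.0)`.
Finally multiply by the upstream gradient and by y, `z = z * dz * y[:, None]`, where y is broadcast in the same way.
The result is written with the prepared two-dimensional mask and offset: `tl.store(dx_ptr + off_ji, z, mask_ji)`.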
```py
r"""
## Puzzle 6: Fused Outer Multiplication - Backwards
Backwards of a function that multiplies a matrix with a row vector and take a relu.
Uses two program blocks. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`. Chain rule backward `dz`
is of shape `N1` by `N0`
.. math::
    f(x, y) = \text{relu}(x_{j, i} \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1
.. math::
    dx_{j, i} = f_x'(x, y)_{j, i} \times dz_{j, i}
"""


def mul_relu_block_back_spec(
    x: Float32[90, 100], y: Float32[90,], dz: Float32[90, 100]
) -> Float32[90, 100]:
    x = x.clone().detach().requires_grad_(True)
    y = y.clone().detach().requires_grad_(True)
    z = torch.relu(x * y[:, None])
    z.backward(dz)
    dx = x.grad
    return dx


@triton.jit
def mul_relu_block_back_kernel(
    x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    block_id_j = tl.program_id(1)
    off_i = tl.arange(0, B0) + block_id_i * B0
    off_j = tl.arange(0, B1) + block_id_j * B1
    off_ji = off_j[:, None] * N0 + off_i[None, :]
    mask_i = off_i < N0
    mask_j = off_j < N1
    mask_ji = (mask_i[None, :] & mask_j[:, None])
    x = tl.load(x_ptr + off_ji, mask_ji)
    dz = tl.load(dz_ptr + off_ji, mask_ji)
    y = tl.load(y_ptr + off_j, mask_j)
    z = x * y[:, None]
    z = tl.where(z > 0, 1.0, 0.0)
    z = z * dz * y[:, None]
    tl.store(dx_ptr + off_ji, z, mask_ji)
    # Finish me!
    return
```
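The gradient formula used by the kernel can be sanity-checked without a GPU by comparing it against central finite differences in NumPy. This is a sketch on a small hand-picked input. The helpers `relu_mul_backward` and `loss` are hypothetical names, and the input is chosen so that no product `x * y` is near zero, where ReLU is not differentiable.

```python
import numpy as np

def relu_mul_backward(x, y, dz):
    """dx = dz * 1[x * y > 0] * y, with y broadcast along the columns."""
    z = x * y[:, None]
    return np.where(z > 0, 1.0, 0.0) * dz * y[:, None]

def loss(x, y, dz):
    # Scalar whose gradient with respect to x is the dx defined above.
    return np.sum(np.maximum(x * y[:, None], 0.0) * dz)

x = np.arange(-6, 6, dtype=np.float64).reshape(3, 4) + 0.5  # values -5.5 .. 5.5, none zero
y = np.array([1.0, -2.0, 0.5])
dz = np.arange(12, dtype=np.float64).reshape(3, 4) / 10.0

dx = relu_mul_backward(x, y, dz)

# Central finite differences, perturbing one element of x at a time.
eps = 1e-4
dx_num = np.zeros_like(x)
for j in range(x.shape[0]):
    for i in range(x.shape[1]):
        step = np.zeros_like(x)
        step[j, i] = eps
        dx_num[j, i] = (loss(x + step, y, dz) - loss(x - step, y, dz)) / (2 * eps)
```

The analytic and numerical gradients agree to within floating-point error. Rows where `x * y` is negative get a zero gradient, which is the effect of the `tl.where(z > 0, 1.0, 0.0)` step.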