Triton Puzzles (1-6)

Introduction

For an introduction to the project itself, see my earlier article:

https://mp.weixin.qq.com/s/4eom3UmHl4pYjeAJc4L6CA

1. Constant Add

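Add a constant to a tensor. The block size B0 equals the tensor length N0.

A single block is enough, so the kernel does not need to read its block id. Build a block of B0 pointers starting at x_ptr (x_ptr + tl.arange(0, B0)), load the corresponding elements of x, add the scalar 10.0 (which is broadcast automatically to the shape of x), and finally write the result to the tensor pointed to by z_ptr.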

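import torch
import triton
import triton.language as tl
# Float32 (used in the *_spec type hints) is assumed to be provided by the Triton-Puzzles notebook setup.

r"""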
## Puzzle 1: Constant Add

Add a constant to a vector. Uses one program id axis. 
Block size `B0` is always the same as vector `x` with length `N0`.

.. math::
    z_i = 10 + x_i \text{ for } i = 1\ldots N_0
"""


def add_spec(x: Float32[32,]) -> Float32[32,]:
    "This is the spec that you should implement. Uses typing to define sizes."
    return x + 10.0


@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    # We name the offsets of the pointers as "off_"
    off_x = tl.arange(0, B0)
    x = tl.load(x_ptr + off_x)
    # Finish me!

    x += 10.0
    tl.store(z_ptr + off_x, x)
    return
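To make the single-block logic concrete, here is a plain-Python sketch (no GPU, no Triton) that emulates what the one program of add_kernel does: the list off_x plays the role of tl.arange(0, B0), and list indexing stands in for tl.load and tl.store. The name emulate_add_kernel is hypothetical and only for illustration.

```python
def emulate_add_kernel(x, B0):
    """CPU emulation of add_kernel run as a single program (B0 == len(x))."""
    assert B0 == len(x), "Puzzle 1 assumes one block covers the whole vector"
    z = [0.0] * len(x)                 # output buffer behind z_ptr
    off_x = list(range(B0))            # tl.arange(0, B0)
    xs = [x[i] for i in off_x]         # tl.load(x_ptr + off_x)
    xs = [v + 10.0 for v in xs]        # x += 10.0, the scalar is broadcast
    for i, v in zip(off_x, xs):        # tl.store(z_ptr + off_x, x)
        z[i] = v
    return z


x = [float(i) for i in range(32)]
z = emulate_add_kernel(x, B0=32)
print(z[:3])  # [10.0, 11.0, 12.0]
```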

2. Constant Add, Blocked

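The difference from Puzzle 1 is that B0 is now smaller than N0, so each block handles only part of the tensor.

So first get the block id with pid = tl.program_id(0), use it to offset to the start of the current block (pid * B0), and build an offset array of length B0 from that start.

Because the last block may extend past N0, apply the mask off_x < N0 to both the load and the store. Everything else is the same as in Puzzle 1.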

r"""
## Puzzle 2: Constant Add Block

Add a constant to a vector. Uses one program block axis (no `for` loops yet). 
Block size `B0` is now smaller than the shape vector `x` which is `N0`.

.. math::
    z_i = 10 + x_i \text{ for } i = 1\ldots N_0
"""


def add2_spec(x: Float32[200,]) -> Float32[200,]:
    return x + 10.0


@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid = tl.program_id(0)
    off_x = tl.arange(0, B0) + pid * B0
    mask = off_x < N0
    x = tl.load(x_ptr + off_x, mask)
    # Finish me!
    x += 10.0
    tl.store(z_ptr + off_x, x, mask)
    return
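A plain-Python sketch of the blocked version, assuming the kernel is launched on a 1D grid of ceil(N0 / B0) programs (in Triton host code this is commonly written as grid = (triton.cdiv(N0, B0),)). The loop over pid stands in for the programs that run in parallel on the GPU, and the name emulate_add_mask2 is hypothetical.

```python
import math


def emulate_add_mask2(x, B0):
    """CPU emulation of add_mask2_kernel on a 1D grid of ceil(N0 / B0) programs."""
    N0 = len(x)
    z = [0.0] * N0
    num_programs = math.ceil(N0 / B0)        # assumed launch grid size
    for pid in range(num_programs):          # each iteration = one tl.program_id(0)
        off_x = [pid * B0 + k for k in range(B0)]
        mask = [o < N0 for o in off_x]       # lanes past the end are disabled
        for o, m in zip(off_x, mask):
            if m:                            # masked load, add, masked store
                z[o] = x[o] + 10.0
    return z, num_programs


x = [float(i) for i in range(200)]
z, n = emulate_add_mask2(x, B0=32)
print(n)  # 7
```

With N0 = 200 and B0 = 32 there are 7 programs; the last one covers offsets 192 to 223, and only the first 8 of those pass the mask.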

3. Outer Vector Add

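An outer add combines two vectors into a matrix. Equivalently, the two vectors are expanded along different dimensions by broadcasting and then added: adding an (m, 1) tensor to a (1, n) tensor gives an (m, n) result.

Here the block size of each vector equals its total length, so no tiling is needed. Loading the two tensors works just like loading one: build one offset array per vector, of lengths B0 and B1.

Then expand the two vectors along different dimensions and add them, letting broadcasting do the expansion: z = x[None, :] + y[:, None]. Here None inserts a new dimension of length 1 and the colon keeps the existing dimension, so x becomes (1, B0), y becomes (B1, 1), and z is (B1, B0). At this point we have the outer-add result, but only in a temporary block; it still has to be written to the tensor behind z_ptr.

For that we need a 2D offset array, off_z = off_y[:, None] * B0 + off_x[None, :], again built with broadcasting. Note that y indexes the rows and x the columns: think of a plane coordinate system rather than the usual computer (row, column) convention, with x running horizontally and y running vertically.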

First add a trailing dimension to off_y (off_y[:, None]), which turns it into a column:

[[0],
 [1],
 [2]]

Then multiply it by the length of the x dimension, i.e. the row length B0 (B0 = 3 in this example):

[[0],
 [3],
 [6]]

Then give off_x a new leading dimension (off_x[None, :]), which turns it into a single row:

[[0, 1, 2]]

Adding the two, broadcasting expands both to 3 x 3 and gives the offset array:

[[0, 1, 2],
 [3, 4, 5],
 [6, 7, 8]]

Finally, adding z_ptr broadcasts the base pointer to every position of this array, which gives the array of destination addresses: tl.store(z_ptr + off_z, z).
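The broadcasting steps above can be reproduced in plain Python. The helper broadcast_add is hypothetical and implements only a simplified form of NumPy/Triton broadcasting for 2D nested lists (a size-1 dimension is stretched to match the other operand; incompatible shapes are not checked).

```python
def broadcast_add(a, b):
    """Add two 2D nested lists, stretching any size-1 dimension (simplified broadcasting)."""
    rows = max(len(a), len(b))
    cols = max(len(a[0]), len(b[0]))

    def get(m, r, c):
        # A size-1 dimension always reads index 0, i.e. it is repeated.
        return m[r if len(m) > 1 else 0][c if len(m[0]) > 1 else 0]

    return [[get(a, r, c) + get(b, r, c) for c in range(cols)] for r in range(rows)]


B0, B1 = 3, 3
col = [[j * B0] for j in range(B1)]   # off_y[:, None] * B0, shape (B1, 1)
row = [list(range(B0))]               # off_x[None, :], shape (1, B0)
off_z = broadcast_add(col, row)
print(off_z)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

When B0 and B1 differ the row stride is still B0, the number of columns: for example broadcast_add([[0], [4]], [[0, 1, 2, 3]]) gives [[0, 1, 2, 3], [4, 5, 6, 7]].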

r"""
## Puzzle 3: Outer Vector Add

Add two vectors.

Uses one program block axis. Block size `B0` is always the same as vector `x` length `N0`.
Block size `B1` is always the same as vector `y` length `N1`.

.. math::
    z_{j, i} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1
"""


def add_vec_spec(x: Float32[32,], y: Float32[32,]) -> Float32[32, 32]:
    return x[None, :] + y[:, None]


@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Finish me!
    off_x = tl.arange(0, B0)
    off_y = tl.arange(0, B1)
    x = tl.load(x_ptr + off_x)
    y = tl.load(y_ptr + off_y)
    z = x[None, :] + y[:, None]

    # off_z= tl.arange(0, B0 * B1).reshape(B1, B0)
    off_z = off_y[:, None] * B0 + off_x[None, :]
    tl.store(z_ptr + off_z, z)
    return
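Here is a plain-Python sketch of the whole single-program kernel writing into a flat buffer, which is what z_ptr points at. It shows that the (B1, B0) result ends up stored row by row. The last lines check that numbering 0 to B0 * B1 - 1 and cutting it into rows of length B0 gives the same offsets, which is what the commented-out reshape line in the kernel relies on (assuming the reshape keeps row-major element order). The name emulate_add_vec_kernel is hypothetical.

```python
def emulate_add_vec_kernel(x, y):
    """CPU emulation of add_vec_kernel: one program, B0 == len(x), B1 == len(y)."""
    B0, B1 = len(x), len(y)
    z_flat = [0.0] * (B0 * B1)          # z_ptr points at a flat buffer
    for j in range(B1):                 # rows are indexed by y
        for i in range(B0):             # columns are indexed by x
            off = j * B0 + i            # off_y[:, None] * B0 + off_x[None, :]
            z_flat[off] = x[i] + y[j]   # z = x[None, :] + y[:, None]
    return z_flat


x = [1.0, 2.0, 3.0]
y = [10.0, 20.0]
z_flat = emulate_add_vec_kernel(x, y)
# View the flat buffer as a (B1, B0) matrix, one row per element of y.
z = [z_flat[j * len(x):(j + 1) * len(x)] for j in range(len(y))]
print(z)  # [[11.0, 12.0, 13.0], [21.0, 22.0, 23.0]]

# Row-major numbering cut into rows of length B0 equals the broadcast offsets.
B0, B1 = 3, 2
reshaped = [list(range(B0 * B1))[j * B0:(j + 1) * B0] for j in range(B1)]
assert reshaped == [[j * B0 + i for i in range(B0)] for j in range(B1)]
```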

4. Outer Vector Add, Blocked

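Building on Puzzle 3, B0 no longer equals N0, so the work has to be split into blocks. Following the same change that took Puzzle 1 to Puzzle 2, the key step is to read the block id, use it to find the start of the block this program is responsible for, and add that start offset: off_x = tl.arange(0, B0) + block_id_x * B0 (and likewise for y, using tl.program_id(1)).

Because the last block may extend past N0, the load needs the mask off_x < N0: x = tl.load(x_ptr + off_x, off_x < N0). The store needs a mask too, and since it writes a 2D block it must combine the masks of both dimensions: (off_x[None, :] < N0) & (off_y[:, None] < N1).

One more change: the row stride in off_z is now the full row length N0 rather than B0, because each program writes only a B1 x B0 tile of the N1 x N0 output: off_z = off_y[:, None] * N0 + off_x[None, :].

The rest of the logic is the same as in Puzzle 3.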

r"""
## Puzzle 4: Outer Vector Add Block

Add a row vector to a column vector.

Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

.. math::
    z_{j, i} = x_i + y_j\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1
"""


def add_vec_block_spec(x: Float32[100,], y: Float32[90,]) -> Float32[90, 100]:
    return x[None, :] + y[:, None]


@triton.jit
def add_vec_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)

    off_x = tl.arange(0, B0) + block_id_x * B0
    off_y = tl.arange(0, B1) + block_id_y * B1

    x = tl.load(x_ptr + off_x, off_x < N0)
    y = tl.load(y_ptr + off_y, off_y < N1)
    z = x[None, :] + y[:, None]

    off_z = off_y[:, None] * N0 + off_x[None, :]
    tl.store(z_ptr + off_z, z, (off_x[None, :] < N0) & (off_y[:, None] < N1))
    # Finish me!
    return
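The 2D blocked kernel can be emulated the same way, assuming a launch grid of ceil(N0 / B0) x ceil(N1 / B1) programs. The sizes below (N0 = 10, N1 = 9, B0 = B1 = 4) are arbitrary small values chosen so that neither dimension divides evenly, and emulate_add_vec_block is a hypothetical name. Unwritten slots stay None, so the check at the end also confirms that every output element is written exactly where it belongs.

```python
import math


def emulate_add_vec_block(x, y, B0, B1):
    """CPU emulation of add_vec_block_kernel on a 2D grid of programs."""
    N0, N1 = len(x), len(y)
    z_flat = [None] * (N0 * N1)                  # flat N1 x N0 output buffer
    for bx in range(math.ceil(N0 / B0)):         # tl.program_id(0)
        for by in range(math.ceil(N1 / B1)):     # tl.program_id(1)
            off_x = [bx * B0 + k for k in range(B0)]
            off_y = [by * B1 + k for k in range(B1)]
            for j in off_y:
                for i in off_x:
                    if i < N0 and j < N1:        # (off_x < N0) & (off_y < N1)
                        # The row stride is N0 (the full row length), not B0.
                        z_flat[j * N0 + i] = x[i] + y[j]
    return z_flat


x = [float(i) for i in range(10)]    # N0 = 10
y = [100.0 * j for j in range(9)]    # N1 = 9
z_flat = emulate_add_vec_block(x, y, B0=4, B1=4)
assert all(z_flat[j * 10 + i] == x[i] + y[j] for j in range(9) for i in range(10))
```

If the row stride were B0 instead of N0, distinct elements would collide once N0 > B0: with B0 = 4, (j = 0, i = 4) and (j = 1, i = 0) would both map to offset 4.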

5. Fused Outer Multiplication with ReLU

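Change the addition of Puzzle 4 into multiplication and apply ReLU to the result element-wise. ReLU is a cheap operation, so it can be fused into the outer-product kernel without saving the intermediate result.

So, starting from Puzzle 4, replace the addition with a multiplication. The key point is that ReLU keeps values greater than 0 and turns everything else into 0, which is an if-else per element. As in most GPU programming DSLs, you cannot branch per element with a Python if statement on a tensor; instead write z = tl.where(z > 0, z, 0.0). The first argument is a boolean expression; where it is true the result takes the second argument, otherwise the third. The second and third arguments may be scalars, which are broadcast to the tensor's shape.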
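The select semantics of tl.where described above can be mimicked in plain Python. The helper where is hypothetical (it is not Triton's implementation): it picks element-wise between two 2D nested lists and treats a plain number as a broadcast scalar.

```python
def where(cond, a, b):
    """Element-wise select on 2D nested lists; a or b may be scalars (broadcast)."""
    def pick(v, r, c):
        return v if isinstance(v, (int, float)) else v[r][c]

    return [[pick(a, r, c) if cond[r][c] else pick(b, r, c)
             for c in range(len(cond[0]))]
            for r in range(len(cond))]


z = [[-1.5, 2.0], [0.0, -3.0]]
cond = [[v > 0 for v in row] for row in z]   # z > 0
print(where(cond, z, 0.0))  # [[0.0, 2.0], [0.0, 0.0]]
```

In Triton itself, tl.maximum(z, 0.0) should give the same ReLU result for finite inputs; tl.where is the more general tool whenever the two branches are not simply a max.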

r"""
## Puzzle 5: Fused Outer Multiplication

Multiply a row vector to a column vector and take a relu.

Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

.. math::
    z_{j, i} = \text{relu}(x_i \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1
"""


def mul_relu_block_spec(x: Float32[100,], y: Float32[90,]) -> Float32[90, 100]:
    return torch.relu(x[None, :] * y[:, None])


@triton.jit
def mul_relu_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)

    off_x = tl.arange(0, B0) + block_id_x * B0
    off_y = tl.arange(0, B1) + block_id_y * B1

    x = tl.load(x_ptr + off_x, off_x < N0)
    y = tl.load(y_ptr + off_y, off_y < N1)
    z = x[None, :] * y[:, None]
    z = tl.where(z > 0, z, 0.0)

    off_z = off_y[:, None] * N0 + off_x[None, :]
    tl.store(z_ptr + off_z, z, (off_x[None, :] < N0) & (off_y[:, None] < N1))
    # Finish me!
    return

6. Fused Outer Multiplication with ReLU, Backward

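Building on Puzzle 5, the vector-times-vector product becomes a matrix times a vector: x is now an N1 x N0 matrix, and each row j is multiplied element-wise by y_j, again through broadcasting.

We are also given the upstream gradient dz of the loss with respect to the output, and need the gradient of the loss with respect to x. The forward expression is $\text{relu}(x_{j,i} \times y_j)$.

Its partial derivative with respect to $x_{j,i}$ is

$$\text{relu}'(x_{j,i} \times y_j) \times y_j$$

where the derivative of $\text{relu}(u)$ is 1 if $u > 0$ and 0 otherwise.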

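So we need to

  • compute $x_{j,i} \times y_j$
  • check its sign to get the ReLU derivative
  • multiply by $y_j$
  • multiply by the upstream gradient $dz_{j,i}$

and the result is dx, the gradient with respect to x.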
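The derivative above can be sanity-checked numerically on scalars. This is a small sketch that compares the closed-form $\text{relu}'(x y) \times y \times dz$ against a central finite difference of $dz \times \text{relu}(x y)$; the function names are hypothetical, and the sample points are chosen away from the kink at $x y = 0$, where the derivative is not defined (the kernel, like PyTorch, uses 0 there).

```python
def relu(u):
    """ReLU on a Python float."""
    return u if u > 0 else 0.0


def dx_formula(x, y, dz):
    """relu'(x * y) * y * dz, with relu'(u) = 1 for u > 0 and 0 otherwise."""
    relu_grad = 1.0 if x * y > 0 else 0.0
    return relu_grad * y * dz


def dx_numeric(x, y, dz, h=1e-6):
    """Central finite difference of dz * relu(x * y) with respect to x."""
    return dz * (relu((x + h) * y) - relu((x - h) * y)) / (2 * h)


# Sample points away from the kink at x * y == 0.
samples = [(0.7, 2.0, 0.5), (-0.7, 2.0, 0.5), (0.7, -3.0, 1.5), (-1.2, -0.5, 2.0)]
for x, y, dz in samples:
    assert abs(dx_formula(x, y, dz) - dx_numeric(x, y, dz)) < 1e-6
    print(x, y, dz, dx_formula(x, y, dz))
```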

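This time let's keep the code tidy: compute all the ranges in one place, then all the masks. Because we now read blocks of the 2D tensors x and dz, we need a 2D offset and a 2D mask: mask_ji = (mask_i[None, :] & mask_j[:, None]).

Compute the element-wise product of the matrix and the vector (not a matrix-vector product), broadcasting y along each row: z = x * y[:, None].

The ReLU derivative needs an if-else, again written with tl.where: z = tl.where(z > 0, 1.0, 0.0).

Finally multiply by the upstream gradient dz and by y, which again has to be broadcast: z = z * dz * y[:, None].

Write the result with the 2D offset and 2D mask prepared earlier: tl.store(dx_ptr + off_ji, z, mask_ji).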

r"""
## Puzzle 6: Fused Outer Multiplication - Backwards

Backwards of a function that multiplies a matrix with a row vector and take a relu.

Uses two program blocks. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`. Chain rule backward `dz`
is of shape `N1` by `N0`

.. math::
    f(x, y) = \text{relu}(x_{j, i} \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1

.. math::
    dx_{j, i} = f_x'(x, y)_{j, i} \times dz_{j, i}
"""


def mul_relu_block_back_spec(
    x: Float32[90, 100], y: Float32[90,], dz: Float32[90, 100]
) -> Float32[90, 100]:
    x = x.clone().detach().requires_grad_(True)
    y = y.clone().detach().requires_grad_(True)
    z = torch.relu(x * y[:, None])
    z.backward(dz)
    dx = x.grad
    return dx


@triton.jit
def mul_relu_block_back_kernel(
    x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    block_id_j = tl.program_id(1)

    off_i = tl.arange(0, B0) + block_id_i * B0
    off_j = tl.arange(0, B1) + block_id_j * B1
    off_ji = off_j[:, None] * N0 + off_i[None, :]

    mask_i = off_i < N0
    mask_j = off_j < N1
    mask_ji = (mask_i[None, :] & mask_j[:, None])


    x = tl.load(x_ptr + off_ji, mask_ji)
    dz = tl.load(dz_ptr + off_ji, mask_ji)

    y = tl.load(y_ptr + off_j, mask_j)
    z = x * y[:, None]

    z = tl.where(z > 0, 1.0, 0.0)
    z = z * dz * y[:, None]

    tl.store(dx_ptr + off_ji, z, mask_ji)
    # Finish me!
    return
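As a final check, the blocked backward kernel can be emulated on the CPU and compared element by element with the formula. This sketch assumes x and dz are stored as row-major flat lists of N1 x N0 elements and a launch grid of ceil(N0 / B0) x ceil(N1 / B1) programs; emulate_mul_relu_back is a hypothetical name, and the random inputs are only test data. Note that x and dz are read with the 2D offset off_ji, while y is read with the 1D offset off_j and broadcast along each row.

```python
import math
import random


def emulate_mul_relu_back(x, y, dz, B0, B1):
    """CPU emulation of mul_relu_block_back_kernel; x and dz are flat N1 x N0 lists."""
    N1 = len(y)
    N0 = len(x) // N1
    dx = [None] * (N0 * N1)
    for bi in range(math.ceil(N0 / B0)):           # tl.program_id(0): column blocks
        for bj in range(math.ceil(N1 / B1)):       # tl.program_id(1): row blocks
            for j in range(bj * B1, bj * B1 + B1):
                for i in range(bi * B0, bi * B0 + B0):
                    if not (i < N0 and j < N1):     # mask_ji
                        continue
                    off = j * N0 + i                # off_ji
                    z = x[off] * y[j]               # y broadcast along row j
                    relu_grad = 1.0 if z > 0 else 0.0
                    dx[off] = relu_grad * dz[off] * y[j]
    return dx


random.seed(0)
N0, N1 = 10, 9
x = [random.uniform(-1, 1) for _ in range(N0 * N1)]
y = [random.uniform(-1, 1) for _ in range(N1)]
dz = [random.uniform(-1, 1) for _ in range(N0 * N1)]
dx = emulate_mul_relu_back(x, y, dz, B0=4, B1=4)
```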