【Tilelang入门】Tilelang Puzzles 07

Tillang Puzzles

一个开源仓库https://github.com/tile-ai/tilelang-puzzles/tree/main

给出用tilelang实现经典算子的例子,附带讲解。分为10个puzzle,每个问题都有待补全文件,和参考实现,以及文字讲解。

采用循序渐进的思路,难度逐渐递增,01-05熟悉语法,06-09实现经典算子,10为挑战复杂实战算子

07 Scalar FlashAttention

标准注意力机制的定义

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dk QKT)V

其中QKV都是多个batch的,也就是还有个维度B。

这里为了简化:

  • B=1
  • 所有的矩阵乘法变成逐元素相乘
  • 移除缩放项dk\sqrt d_kd k

定义

py 复制代码
for i in range(B):
    # 步骤1: 计算 Q*K 并找最大值
    MAX = -inf
    for j in range(S):
        QK[i, j] = Q[i, j] * K[i, j]
        MAX = max(QK[i, j], MAX)

    # 步骤2: 计算 exp 和求和
    SUM = 0
    for j in range(S):
        P[i, j] = exp(QK[i, j] - MAX)
        SUM += P[i, j]

    # 步骤3: 归一化并乘以 V
    for j in range(S):
        O[i, j] = P[i, j] / SUM * V[i, j]

baseline

py 复制代码
def ref_scalar_flash_attn(
	Q: torch.Tensor, 
	K: torch.Tensor, 
	V: torch.Tensor
	):
    return torch.softmax(Q * K, dim=1).mul_(V)

套用online softmax的实现

可以发现移除之后,跟上一节的online softmax已经没多大区别了,区别只有输入A变成了QK,最后结果还需要乘一个V。

直接套上一节的框架即可

  • QK = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype),保存块内的幂的临时结果,这个优化在前面的online softmax就有,因为只需要知道sum/LSE,不需要保存幂的结果,最后缩放时完全可以现场计算。
  • 这个优化一方面是因为访存延迟远大于计算延迟,现场计算用时更短;另一方面也是LLM的一个痛点,面对长文本,也就是S维度很大,QKTQK^TQKT矩阵的内存是O(S2)O(S^2)O(S2)的,原始的暴力attention会把这个矩阵写回显存,很容易出现显存不足。而online softmax不需要把整个矩阵的乘法结果保存到显存,只需要在计算是共享内存上申请一些数据块,每个线程块只需要一个小块共享内存,完全是够的,全局内存只需要存输入数据QKV,一般讨论时不考虑batch,设batch=1,那么占用的全局内存就是O(S)O(S)O(S)的,变为线性了。
  • 这里是简化的情况,实际的一个序列,假设batch=1,也不是O(S)的,或者说O(L),这里L=S都指的是输入序列长度,或者说token数。由于每个token都是一个词向量,维度为d,实际输入是O(L×d)O(L×d)O(L×d)的,但是在大模型设计里d是常数,所以可以看成O(L)O(L)O(L)。然后实际进行的矩阵乘法QKTQK^TQKT,也不是现在这样的B×L的shape,而是L,d×d,L=L,LL,d×d,L=L,LL,d×d,L=L,L,这就是一般所说的O(L2)O(L^2)O(L2)的来源
py 复制代码
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def tl_scalar_flash_attn(Q, K, V, BLOCK_B: int, BLOCK_S: int):
    log2_e = 1.44269504
    B, S = T.const("B, S")
    dtype = T.float32
    Q: T.Tensor((B, S), dtype)
    K: T.Tensor((B, S), dtype)
    V: T.Tensor((B, S), dtype)
    O = T.empty((B, S), dtype)

    # TODO: Implement this function
    with T.Kernel(B // BLOCK_B, threads=256) as bx:
        Q_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
        K_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
        V_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
        O_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)

        cur_exp = T.alloc_fragment([BLOCK_B, BLOCK_S], dtype)
        cur_max = T.alloc_fragment([BLOCK_B], dtype)
        cur_sum_exp = T.alloc_fragment([BLOCK_B], dtype)
        QK = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)

        lse=T.alloc_fragment([BLOCK_B], dtype)  # log-sum-exp

        T.fill(lse, -T.infinity(dtype))

        for by in T.Serial(S // BLOCK_S):
            T.copy(Q[bx * BLOCK_B, by * BLOCK_S], Q_local)
            T.copy(K[bx * BLOCK_B, by * BLOCK_S], K_local)

            for i, j in T.Parallel(BLOCK_B, BLOCK_S):
                QK[i, j] = Q_local[i, j] * K_local[i, j]

            T.reduce_max(QK, cur_max, dim=1, clear=True)

            for i, j in T.Parallel(BLOCK_B, BLOCK_S):
                cur_exp[i, j] = T.exp2((QK[i, j] - cur_max[i]) * log2_e)

            T.reduce_sum(cur_exp, cur_sum_exp, dim=1, clear=True)

            for i in T.Parallel(BLOCK_B):
                lse[i] = T.log2(T.exp2(lse[i] - cur_max[i] * log2_e) + cur_sum_exp[i]) + cur_max[i] * log2_e


        for by in T.Serial(S // BLOCK_S):
            T.copy(Q[bx * BLOCK_B, by * BLOCK_S], Q_local)
            T.copy(K[bx * BLOCK_B, by * BLOCK_S], K_local)
            T.copy(V[bx * BLOCK_B, by * BLOCK_S], V_local)

            for i, j in T.Parallel(BLOCK_B, BLOCK_S):
                O_local[i, j] = T.exp2(Q_local[i, j] * K_local[i, j] * log2_e - lse[i]) * V_local[i, j]

            T.copy(O_local, O[bx * BLOCK_B, by * BLOCK_S])

    return O
相关推荐
一条大祥脚3 小时前
【Tilelang入门】Tilelang Puzzles 08
gemm·tilelang·tilelang-puzzle
一条大祥脚1 天前
【Tilelang入门】Tilelang Puzzles 01-02
tilelang·tilelang-puzzle
西西弗Sisyphus4 天前
构建中文版的 nanoGPT - 断点续训(resume from checkpoint)
transformer·attention·注意力·self-attention·nanogpt
西西弗Sisyphus4 天前
构建中文版的 nanoGPT - 中文版 nanoGPT 的分词(tokenization)
transformer·attention·注意力·self-attention·nanogpt
西西弗Sisyphus19 天前
从零实现Transformer:第 4 部分 - Residual Connection的两种实现 Pre-LN 和 Post-LN
transformer·attention·unsqueeze·self-attention·残差·residual·squeeze
西西弗Sisyphus19 天前
从零实现Transformer:第 9 部分 - 推理(Inference )
transformer·attention·注意力机制·注意力·decoder·self-attention
机器学习之心19 天前
多工况车速数据集训练BiLSTM-Attention用于车速预测,输出未来多个时间步车速,MATLAB代码
matlab·attention·bilstm·车速预测
机器学习之心24 天前
CNN-xLSTM-Attention 回归模型:从原理到 SHAP 可解释性全解析
回归·cnn·attention·cnn-xlstm
庞轩px1 个月前
Transformer的核心思想——Attention机制直观理解
人工智能·rnn·深度学习·transformer·attention·q-k-v