Tillang Puzzles

一个开源仓库https://github.com/tile-ai/tilelang-puzzles/tree/main
给出用tilelang实现经典算子的例子,附带讲解。分为10个puzzle,每个问题都有待补全文件,和参考实现,以及文字讲解。
采用循序渐进的思路,难度逐渐递增,01-05熟悉语法,06-09实现经典算子,10为挑战复杂实战算子
07 Scalar FlashAttention
标准注意力机制的定义
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dk QKT)V
其中QKV都是多个batch的,也就是还有个维度B。
这里为了简化:
- B=1
- 所有的矩阵乘法变成逐元素相乘
- 移除缩放项dk\sqrt d_kd k
定义
py
for i in range(B):
# 步骤1: 计算 Q*K 并找最大值
MAX = -inf
for j in range(S):
QK[i, j] = Q[i, j] * K[i, j]
MAX = max(QK[i, j], MAX)
# 步骤2: 计算 exp 和求和
SUM = 0
for j in range(S):
P[i, j] = exp(QK[i, j] - MAX)
SUM += P[i, j]
# 步骤3: 归一化并乘以 V
for j in range(S):
O[i, j] = P[i, j] / SUM * V[i, j]
baseline
py
def ref_scalar_flash_attn(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor
):
return torch.softmax(Q * K, dim=1).mul_(V)
套用online softmax的实现
可以发现移除之后,跟上一节的online softmax已经没多大区别了,区别只有输入A变成了QK,最后结果还需要乘一个V。
直接套上一节的框架即可
QK = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype),保存块内的幂的临时结果,这个优化在前面的online softmax就有,因为只需要知道sum/LSE,不需要保存幂的结果,最后缩放时完全可以现场计算。- 这个优化一方面是因为访存延迟远大于计算延迟,现场计算用时更短;另一方面也是LLM的一个痛点,面对长文本,也就是S维度很大,QKTQK^TQKT矩阵的内存是O(S2)O(S^2)O(S2)的,原始的暴力attention会把这个矩阵写回显存,很容易出现显存不足。而online softmax不需要把整个矩阵的乘法结果保存到显存,只需要在计算是共享内存上申请一些数据块,每个线程块只需要一个小块共享内存,完全是够的,全局内存只需要存输入数据QKV,一般讨论时不考虑batch,设batch=1,那么占用的全局内存就是O(S)O(S)O(S)的,变为线性了。
- 这里是简化的情况,实际的一个序列,假设batch=1,也不是O(S)的,或者说O(L),这里L=S都指的是输入序列长度,或者说token数。由于每个token都是一个词向量,维度为d,实际输入是O(L×d)O(L×d)O(L×d)的,但是在大模型设计里d是常数,所以可以看成O(L)O(L)O(L)。然后实际进行的矩阵乘法QKTQK^TQKT,也不是现在这样的B×L的shape,而是L,d×d,L=L,LL,d×d,L=L,LL,d×d,L=L,L,这就是一般所说的O(L2)O(L^2)O(L2)的来源
py
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},
)
def tl_scalar_flash_attn(Q, K, V, BLOCK_B: int, BLOCK_S: int):
log2_e = 1.44269504
B, S = T.const("B, S")
dtype = T.float32
Q: T.Tensor((B, S), dtype)
K: T.Tensor((B, S), dtype)
V: T.Tensor((B, S), dtype)
O = T.empty((B, S), dtype)
# TODO: Implement this function
with T.Kernel(B // BLOCK_B, threads=256) as bx:
Q_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
K_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
V_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
O_local = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
cur_exp = T.alloc_fragment([BLOCK_B, BLOCK_S], dtype)
cur_max = T.alloc_fragment([BLOCK_B], dtype)
cur_sum_exp = T.alloc_fragment([BLOCK_B], dtype)
QK = T.alloc_fragment((BLOCK_B, BLOCK_S), dtype)
lse=T.alloc_fragment([BLOCK_B], dtype) # log-sum-exp
T.fill(lse, -T.infinity(dtype))
for by in T.Serial(S // BLOCK_S):
T.copy(Q[bx * BLOCK_B, by * BLOCK_S], Q_local)
T.copy(K[bx * BLOCK_B, by * BLOCK_S], K_local)
for i, j in T.Parallel(BLOCK_B, BLOCK_S):
QK[i, j] = Q_local[i, j] * K_local[i, j]
T.reduce_max(QK, cur_max, dim=1, clear=True)
for i, j in T.Parallel(BLOCK_B, BLOCK_S):
cur_exp[i, j] = T.exp2((QK[i, j] - cur_max[i]) * log2_e)
T.reduce_sum(cur_exp, cur_sum_exp, dim=1, clear=True)
for i in T.Parallel(BLOCK_B):
lse[i] = T.log2(T.exp2(lse[i] - cur_max[i] * log2_e) + cur_sum_exp[i]) + cur_max[i] * log2_e
for by in T.Serial(S // BLOCK_S):
T.copy(Q[bx * BLOCK_B, by * BLOCK_S], Q_local)
T.copy(K[bx * BLOCK_B, by * BLOCK_S], K_local)
T.copy(V[bx * BLOCK_B, by * BLOCK_S], V_local)
for i, j in T.Parallel(BLOCK_B, BLOCK_S):
O_local[i, j] = T.exp2(Q_local[i, j] * K_local[i, j] * log2_e - lse[i]) * V_local[i, j]
T.copy(O_local, O[bx * BLOCK_B, by * BLOCK_S])
return O