1. 写在前面
近年来,基于 Transformer 架构的大型语言模型 (LLM) 在自然语言处理 (NLP) 领域取得了巨大的成功。Transformer 的核心组件是自注意力 (Self-Attention) 机制,它允许模型捕捉输入序列中不同位置之间的关系。然而,标准的自注意力机制的计算复杂度与序列长度的平方成正比,这使得它在处理长序列时效率低下。
为了解决这个问题,Flash Attention 被提出,它是一种高效的注意力算法,通过利用现代 GPU 的特性,显著降低了计算复杂度和内存占用。本文将深入 Llama 源码,分析 Flash Attention 的实现逻辑,并与标准的自注意力机制进行比较。
2. Self-Attention 回顾
Self-Attention 的核心思想是:对于输入序列中的每个 token,都计算它与其他所有 token 之间的相关性,并根据这些相关性对所有 token 的表示进行加权求和,得到该 token 的新的表示。
标准的 Self-Attention 计算过程如下:
-
线性变换 : 将输入序列的每个 token 的 embedding 向量
x_i
通过三个线性变换矩阵W_q
,W_k
,W_v
映射成三个向量:q_i
(query),k_i
(key),v_i
(value)。python# 假设 embedding_dim = 512, seq_len = 1024 import torch x = torch.randn(1, 1024, 512) # batch_size=1, seq_len=1024, embedding_dim=512 W_q = torch.randn(512, 512) W_k = torch.randn(512, 512) W_v = torch.randn(512, 512) q = x @ W_q # (1, 1024, 512) k = x @ W_k # (1, 1024, 512) v = x @ W_v # (1, 1024, 512)
-
计算注意力分数 : 计算每个 query
q_i
与所有 keyk_j
之间的点积,得到注意力分数s_ij
。pythons = q @ k.transpose(-2, -1) # (1, 1024, 1024)
-
缩放和掩码 : 对注意力分数进行缩放 (除以
sqrt(d_k)
,d_k
是 key 向量的维度),并应用掩码 (mask) 操作 (例如,在解码器中屏蔽未来 token)。pythonimport math d_k = k.shape[-1] s = s / math.sqrt(d_k) # 假设我们不需要 mask
-
Softmax : 对缩放后的注意力分数应用 softmax 函数,得到注意力权重
a_ij
。pythona = torch.softmax(s, dim=-1) # (1, 1024, 1024)
-
加权求和 : 使用注意力权重
a_ij
对所有 value 向量v_j
进行加权求和,得到每个 token 的新的表示y_i
。pythony = a @ v # (1, 1024, 512)
问题 : 上述计算过程中,s
和 a
这两个矩阵的大小都是 (seq_len, seq_len)
,当 seq_len
很大时 (例如 4096),这两个矩阵会占用大量的显存,并且计算 softmax 和矩阵乘法也非常耗时。
3. Flash Attention 原理
Flash Attention 的核心思想是:避免将整个注意力矩阵 s
和 a
存储在 GPU 的高速缓存 (HBM) 中,而是将输入数据分块 (tiling),每次只加载一小部分数据到 SRAM 中进行计算,并将结果写回 HBM。
Flash Attention 主要利用了以下两个技术:
3.1 Tiling (分块)
将 Q, K, V 矩阵分成多个 block,每次只计算一个 block 的注意力。例如,可以将一个 (1024, 512)
的矩阵分成 16 个 (256, 512)
的 block。
3.2 Recomputation (重计算)
在反向传播时,不存储中间的注意力权重 a
,而是在需要的时候重新计算。由于计算 a
的开销相对较小,这种方法可以节省大量的显存。
4. Llama 中 Flash Attention 的实现
Llama 使用了 Flash Attention 的改进版本,即 Paged Attention 。其核心思想与 Flash Attention 相同,但在处理长序列时更加高效。这里以llama2源码为例说明,其位于llama/model.py
文件中,class Attention(nn.Module)
类下的forward
函数中
以下是 Llama 源码中 Flash Attention 的简化版实现 (已去除部分细节):
python
def flash_attention(q, k, v, block_size):
"""
简化版的 Flash Attention 实现.
Args:
q: Query 矩阵 (B, H, N, D_head)
k: Key 矩阵 (B, H, N, D_head)
v: Value 矩阵 (B, H, N, D_head)
block_size: 分块大小
Returns:
输出矩阵 (B, H, N, D_head)
"""
B, H, N, D_head = q.shape
O = torch.zeros_like(q)
for i in range(0, N, block_size):
# 加载当前 block 的数据到 SRAM
qi = q[:, :, i:i + block_size, :]
mi = -float('inf') # 用于记录当前 block 的最大值
li = 0.0 # 用于记录当前 block 的 softmax 的分母
for j in range(0, N, block_size):
# 加载当前 block 的数据到 SRAM
kj = k[:, :, j:j + block_size, :]
vj = v[:, :, j:j + block_size, :]
# 计算注意力分数
sij = torch.einsum('bhnd,bhmd->bhnm', qi, kj) / math.sqrt(D_head)
# 更新最大值和 softmax 的分母
mij = torch.max(sij, dim=-1).values
li_new = torch.exp(mi - mij).unsqueeze(-1) * li + torch.sum(torch.exp(sij - mij.unsqueeze(-1)), dim=-1)
# 更新输出
O[:, :, i:i + block_size, :] = (li / li_new).unsqueeze(-1) * O[:, :, i:i + block_size, :] + \
torch.einsum('bhnm,bhmd->bhnd', torch.exp(sij - mij.unsqueeze(-1)), vj) / li_new.unsqueeze(-1)
mi = torch.max(mi, mij)
li = li_new
return O
# 示例:假设 block_size = 256
q = torch.randn(1, 8, 1024, 64) # batch_size=1, heads=8, seq_len=1024, d_head=64
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)
o = flash_attention(q, k, v, block_size=256)
print(o.shape) # torch.Size([1, 8, 1024, 64])
代码解释:
q
,k
,v
分别表示 query, key, value 矩阵,block_size
表示分块大小。O
是输出矩阵,初始化为全零。- 外层循环遍历 Q 矩阵的 block。
- 内层循环遍历 K, V 矩阵的 block。
sij
计算当前 block 的注意力分数。mi
和li
分别用于记录当前 block 的最大值和 softmax 的分母,以保证数值稳定性。O
使用增量更新的方式计算最终的输出结果。
注意: 上述代码只是 Flash Attention 的简化版实现,实际的 Llama 源码中还包括了 mask, dropout, causal mask 等操作,并且使用了更高效的 CUDA kernel 来加速计算。
简化版实现说明
上面的代码实现了一个简化版本的Flash Attention算法。它通过两个嵌套的循环来处理查询(Q)、键(K)和值(V)矩阵,这些矩阵被分成了多个块(block)。这种分块处理的方式旨在减少计算过程中的内存占用,特别是对于那些拥有大量头的注意力机制(如多头注意力机制)来说,可以显著提高计算效率。下面我们来逐步解释这段代码的核心逻辑:
初始化输出张量 O
O
被初始化为与查询张量q
相同形状的全零张量。这个张量将累积每个块计算的结果,最终形成完整的输出。
外循环:遍历Q的块
- 代码通过
for i in range(0, N, block_size):
循环遍历 Q 矩阵的块。变量i
表示当前处理的块在序列维度上的起始位置。
内循环:遍历K和V的块
- 对于Q中的每个块,代码通过
for j in range(0, N, block_size):
循环遍历 K 和 V 矩阵的块。变量j
表示 K 和 V 当前处理的块在序列维度上的起始位置。
注意力分数的计算
sij = torch.einsum('bhnd,bhmd->bhnm', qi, kj) / math.sqrt(D_head)
计算当前 Q 块和 K 块之间的注意力分数。这里使用了爱因斯坦求和标记法(einsum),这是一种简洁表示张量操作的方式。
更新最大值和softmax的分母
mij = torch.max(sij, dim=-1).values
计算sij
在最后一个维度上的最大值。li_new = torch.exp(mi - mij).unsqueeze(-1) * li + torch.sum(torch.exp(sij - mij.unsqueeze(-1)), dim=-1)
更新softmax的分母。这里使用了数值稳定的技巧,通过减去最大值来避免指数运算产生过大的数值。
更新输出
O[:, :, i:i + block_size, :] = ...
这行代码是整个算法中最关键的部分。它根据当前计算的注意力分数和值(V)来更新输出张量O
的相应块。这里通过加权求和的方式,将之前步骤的结果累加到O
上。
更新 mi
和 li
mi = torch.max(mi, mij)
更新到目前为止遇到的最大值。li = li_new
更新softmax的分母。
总结
这段代码实现了一种高效的注意力机制,通过分块处理和数值稳定的softmax计算,减少了内存占用并提高了计算效率。尽管代码进行了一定的简化以突出核心逻辑,但它捕捉了Flash Attention算法的关键思想。在实际应用中,还需要考虑如何高效地在硬件上实现这些操作,以及如何处理边界情况和性能优化。
5. Flash Attention 与标准 Self-Attention 的比较
特性 | 标准 Self-Attention | Flash Attention |
---|---|---|
计算复杂度 | O(N^2) | O(N) (理论上, 实际取决于分块大小) |
内存占用 | O(N^2) | O(N) (理论上, 实际取决于分块大小) |
速度 | 慢 | 快 |
适用场景 | 短序列 | 长序列 |
实现复杂性 | 简单 | 复杂 |