本文基于一个最小 PyTorch 示例,手写实现 FlashAttention
的核心计算流程,并详细解释其数值稳定性和分块计算原理。
1. 标准 Attention 回顾
标准 Attention 的计算公式:
Attention(Q,K,V)=softmax(QKT)V Attention(Q,K,V) = softmax(QK^T)V Attention(Q,K,V)=softmax(QKT)V
python
import torch
query = torch.randn(1, 12, 10)
key = torch.randn(1, 12, 10)
value = torch.randn(1, 12, 10)
logits = torch.einsum('bqd,bkd->bqk', query, key)
probs = torch.nn.functional.softmax(logits, dim=-1)
softmax_output = torch.einsum('bqk,bkd->bqd', probs, value)
2. FlashAttention 核心思想
FlashAttention 的核心目标:
避免显式存储整个 attention matrix(QK^T)
关键手段:
- 分块计算(block-wise)
- 在线 Softmax(online softmax)
3. 数值稳定 Softmax
softmax(xj)=exj−m∑kexk−m,m=max(x) softmax(x_j) = \frac{e^{x_j - m}}{\sum_k e^{x_k - m}}, \quad m = max(x) softmax(xj)=∑kexk−mexj−m,m=max(x)
4. 核心递推
mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij}) mi=max(mi−1,mij)
li=li−1emi−1−mi+∑exij−mi l_i = l_{i-1} e^{m_{i-1} - m_i} + \sum e^{x_{ij} - m_i} li=li−1emi−1−mi+∑exij−mi
oi=oi−1emi−1−mi+∑(exij−miVj) o_i = o_{i-1} e^{m_{i-1} - m_i} + \sum (e^{x_{ij} - m_i} V_j) oi=oi−1emi−1−mi+∑(exij−miVj)
🔍 关键细节深入理解
很多人在理解这里时容易卡住:为什么需要对历史的 oi−1o_{i-1}oi−1 做
rescale?
我们一步一步拆解:
1️⃣ oi−1o_{i-1}oi−1 并不是"最终正确的值"
在第 i−1i-1i−1 次循环时:
- 我们用的是 局部最大值 mi−1m_{i-1}mi−1
- 所以 softmax 实际是:
exi−1∑exi−1=exi−1−mi−1∑exi−1−mi−1 \frac{e^{x_{i-1}}}{\sum e^{x_{i-1}}} = \frac{e^{x_{i-1} - m_{i-1}}}{\sum e^{x_{i-1} - m_{i-1}}} ∑exi−1exi−1=∑exi−1−mi−1exi−1−mi−1
👉 注意:这里的归一化是 基于局部 block 的尺度
2️⃣ 当进入第 iii 个 block 时发生了什么?
我们得到了新的最大值:
mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij}) mi=max(mi−1,mij)
👉 这个 mim_imi 更接近 全局最大值
3️⃣ 问题的本质
此时出现一个不一致:
项目 使用的 max
oi−1o_{i-1}oi−1 mi−1m_{i-1}mi−1
当前 block mim_imi
👉 如果直接相加,会导致:
不同尺度的指数项被混合(数值错误)
4️⃣ 解决方法:统一尺度(rescale)
我们需要把旧的 oi−1o_{i-1}oi−1 从:
ex−mi−1 e^{x - m_{i-1}} ex−mi−1
转换到:
ex−mi e^{x - m_i} ex−mi
变换方式:
ex−mi−1=ex−mi⋅emi−mi−1 e^{x - m_{i-1}} = e^{x - m_i} \cdot e^{m_i - m_{i-1}} ex−mi−1=ex−mi⋅emi−mi−1
👉 因此:
oi−1→oi−1⋅emi−1−mi o_{i-1} \rightarrow o_{i-1} \cdot e^{m_{i-1} - m_i} oi−1→oi−1⋅emi−1−mi
5️⃣ 对应代码
o_i = o_i_1 * torch.exp(m_i_1 - m_i)[..., None] + torch.einsum('bqk,bkd->bqd', exp_term, v_i)
含义是:
- 第一项:旧结果 rescale 到新尺度
- 第二项:当前 block 的贡献
6️⃣ 一个直观理解
可以把整个过程理解为:
我们在不断"修正历史",让所有累积值都统一到"当前最稳定的坐标系(最大值)"下
随着循环进行:
- mim_imi 会逐步逼近 全局最大值
- 所有历史贡献都会被重新缩放到这个统一尺度
7️⃣ 最终结果
当所有 block 处理完:
- mim_imi = 全局最大值
- oi/lio_i / l_ioi/li = 完整 softmax 结果
5. PyTorch实现
python
flash_softmax_outputs = []
q_chunks = 4
q_chunk_size = query.shape[1] // q_chunks
k_chunks = 3
k_chunk_size = key.shape[1] // k_chunks
for i in range(q_chunks):
q_i = query[:, i*q_chunk_size:(i+1)*q_chunk_size]
m_i_1 = torch.full((q_i.shape[0], q_i.shape[1]), -float('inf'))
l_i_1 = torch.zeros_like(m_i_1)
o_i_1 = torch.zeros((q_i.shape[0], q_i.shape[1], value.shape[-1]))
for j in range(k_chunks):
k_i = key[:, j * k_chunk_size: (j+1) * k_chunk_size] # (B, K_block, D)
v_i = value[:, j * k_chunk_size: (j+1) * k_chunk_size] # (B, K_block, Dv)
logits_i = torch.einsum('nqd,nkd->nqk', q_i, k_i) # (B, Q_block, K_block)
# ---- 更新 m ----
m_ij = torch.max(logits_i, dim=-1)[0] # (B, Q_block)
m_i = torch.maximum(m_i_1, m_ij)
# 计算Softmax分子e^(x_i - m_i)
exp_term = torch.exp(logits_i - m_i[..., None]) # (B, Q_block, K_block)
# 更新Softmax分母
# rescale * 旧的softmax分母 + 新的softmax分母
l_i = l_i_1 * torch.exp(m_i_1 - m_i) + exp_term.sum(dim=-1)
# ---- 更新 O(关键!)----
# rescale * 旧的logit * v + 新的logit * v
o_i = o_i_1 * torch.exp(m_i_1 - m_i)[..., None] + torch.einsum('nqk,nkd->nqd', exp_term, v_i)
# ---- 状态更新 ----
m_i_1 = m_i
l_i_1 = l_i
o_i_1 = o_i
# ---- 最后除以Softmax分母----
output = o_i / l_i[..., None]
flash_softmax_outputs.append(output)
flash_softmax_outputs = torch.cat(flash_softmax_outputs, dim=1)
6. 正确性验证
python
torch.allclose(softmax_output, flash_softmax_outputs)
7. 总结
FlashAttention 本质:
- 分块计算
- 在线 softmax
- 动态重标定(rescale)
复杂度从 O(N^2) 降到 O(N)