从零手写 FlashAttention(PyTorch实现 + 原理推导)

本文基于一个最小 PyTorch 示例,手写实现 FlashAttention

的核心计算流程,并详细解释其数值稳定性和分块计算原理。


1. 标准 Attention 回顾

标准 Attention 的计算公式:

Attention(Q,K,V)=softmax(QKT)V Attention(Q,K,V) = softmax(QK^T)V Attention(Q,K,V)=softmax(QKT)V

python 复制代码
import torch

query = torch.randn(1, 12, 10)
key = torch.randn(1, 12, 10)
value = torch.randn(1, 12, 10)

logits = torch.einsum('bqd,bkd->bqk', query, key)
probs = torch.nn.functional.softmax(logits, dim=-1)
softmax_output = torch.einsum('bqk,bkd->bqd', probs, value)

2. FlashAttention 核心思想

FlashAttention 的核心目标:

避免显式存储整个 attention matrix(QK^T)

关键手段:

  • 分块计算(block-wise)
  • 在线 Softmax(online softmax)

3. 数值稳定 Softmax

softmax(xj)=exj−m∑kexk−m,m=max(x) softmax(x_j) = \frac{e^{x_j - m}}{\sum_k e^{x_k - m}}, \quad m = max(x) softmax(xj)=∑kexk−mexj−m,m=max(x)


4. 核心递推

mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij}) mi=max(mi−1,mij)

li=li−1emi−1−mi+∑exij−mi l_i = l_{i-1} e^{m_{i-1} - m_i} + \sum e^{x_{ij} - m_i} li=li−1emi−1−mi+∑exij−mi

oi=oi−1emi−1−mi+∑(exij−miVj) o_i = o_{i-1} e^{m_{i-1} - m_i} + \sum (e^{x_{ij} - m_i} V_j) oi=oi−1emi−1−mi+∑(exij−miVj)


🔍 关键细节深入理解

很多人在理解这里时容易卡住:为什么需要对历史的 oi−1o_{i-1}oi−1 做
rescale?

我们一步一步拆解:

1️⃣ oi−1o_{i-1}oi−1 并不是"最终正确的值"

在第 i−1i-1i−1 次循环时:

  • 我们用的是 局部最大值 mi−1m_{i-1}mi−1
  • 所以 softmax 实际是:

exi−1∑exi−1=exi−1−mi−1∑exi−1−mi−1 \frac{e^{x_{i-1}}}{\sum e^{x_{i-1}}} = \frac{e^{x_{i-1} - m_{i-1}}}{\sum e^{x_{i-1} - m_{i-1}}} ∑exi−1exi−1=∑exi−1−mi−1exi−1−mi−1

👉 注意:这里的归一化是 基于局部 block 的尺度


2️⃣ 当进入第 iii 个 block 时发生了什么?

我们得到了新的最大值:

mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij}) mi=max(mi−1,mij)

👉 这个 mim_imi 更接近 全局最大值


3️⃣ 问题的本质

此时出现一个不一致:

项目 使用的 max


oi−1o_{i-1}oi−1 mi−1m_{i-1}mi−1

当前 block mim_imi

👉 如果直接相加,会导致:

不同尺度的指数项被混合(数值错误)


4️⃣ 解决方法:统一尺度(rescale)

我们需要把旧的 oi−1o_{i-1}oi−1 从:

ex−mi−1 e^{x - m_{i-1}} ex−mi−1

转换到:

ex−mi e^{x - m_i} ex−mi

变换方式:

ex−mi−1=ex−mi⋅emi−mi−1 e^{x - m_{i-1}} = e^{x - m_i} \cdot e^{m_i - m_{i-1}} ex−mi−1=ex−mi⋅emi−mi−1

👉 因此:

oi−1→oi−1⋅emi−1−mi o_{i-1} \rightarrow o_{i-1} \cdot e^{m_{i-1} - m_i} oi−1→oi−1⋅emi−1−mi


5️⃣ 对应代码

o_i = o_i_1 * torch.exp(m_i_1 - m_i)[..., None] + torch.einsum('bqk,bkd->bqd', exp_term, v_i)

含义是:

  • 第一项:旧结果 rescale 到新尺度
  • 第二项:当前 block 的贡献

6️⃣ 一个直观理解

可以把整个过程理解为:

我们在不断"修正历史",让所有累积值都统一到"当前最稳定的坐标系(最大值)"下

随着循环进行:

  • mim_imi 会逐步逼近 全局最大值
  • 所有历史贡献都会被重新缩放到这个统一尺度

7️⃣ 最终结果

当所有 block 处理完:

  • mim_imi = 全局最大值
  • oi/lio_i / l_ioi/li = 完整 softmax 结果

5. PyTorch实现

python 复制代码
flash_softmax_outputs = []

q_chunks = 4
q_chunk_size = query.shape[1] // q_chunks

k_chunks = 3
k_chunk_size = key.shape[1] // k_chunks

for i in range(q_chunks):
    q_i = query[:, i*q_chunk_size:(i+1)*q_chunk_size]

    m_i_1 = torch.full((q_i.shape[0], q_i.shape[1]), -float('inf'))
    l_i_1 = torch.zeros_like(m_i_1)
    o_i_1 = torch.zeros((q_i.shape[0], q_i.shape[1], value.shape[-1]))

    for j in range(k_chunks):
        k_i = key[:, j * k_chunk_size: (j+1) * k_chunk_size]     # (B, K_block, D)
        v_i = value[:, j * k_chunk_size: (j+1) * k_chunk_size]   # (B, K_block, Dv)
        logits_i = torch.einsum('nqd,nkd->nqk', q_i, k_i)  # (B, Q_block, K_block)
        # ---- 更新 m ----
        m_ij = torch.max(logits_i, dim=-1)[0]              # (B, Q_block)
        m_i = torch.maximum(m_i_1, m_ij)
        # 计算Softmax分子e^(x_i - m_i)
        exp_term = torch.exp(logits_i - m_i[..., None])    # (B, Q_block, K_block)
        # 更新Softmax分母
        # rescale * 旧的softmax分母 + 新的softmax分母
        l_i = l_i_1 * torch.exp(m_i_1 - m_i) + exp_term.sum(dim=-1)
        # ---- 更新 O(关键!)----
        # rescale * 旧的logit * v + 新的logit * v
        o_i = o_i_1 * torch.exp(m_i_1 - m_i)[..., None] + torch.einsum('nqk,nkd->nqd', exp_term, v_i)
        # ---- 状态更新 ----
        m_i_1 = m_i
        l_i_1 = l_i
        o_i_1 = o_i

    # ---- 最后除以Softmax分母----
    output = o_i / l_i[..., None]
    flash_softmax_outputs.append(output)


flash_softmax_outputs = torch.cat(flash_softmax_outputs, dim=1)

6. 正确性验证

python 复制代码
torch.allclose(softmax_output, flash_softmax_outputs)

7. 总结

FlashAttention 本质:

  • 分块计算
  • 在线 softmax
  • 动态重标定(rescale)

复杂度从 O(N^2) 降到 O(N)

相关推荐
字节跳动数据库1 小时前
数据孤岛难打通、权限怕失控?DBW 助“小龙虾”落地最后一公里
人工智能
俊哥V1 小时前
AI一周事件 · 2026-04-22 至 2026-04-28
人工智能·ai
用户8356290780512 小时前
用 Python 轻松在 Excel 工作表中应用条件格式
后端·python
red1giant_star2 小时前
Python根据文件后缀统计文件大小、找出文件位置(仿Everything)
后端·python
Black蜡笔小新2 小时前
AI大模型训练工作站/私有化本地化AI模型训推工作站DLTM为农业生产装上AI“慧眼”
人工智能·ai大模型
小星AI2 小时前
Claude Code Agent SDK 从入门到精通,一步到位
人工智能·agent·cursor
端平入洛2 小时前
梯度是什么:PyTorch 自动求导详解
人工智能·深度学习
时序之心2 小时前
上海交大、东北大学:时序分类与感知领域的两项前沿突破
人工智能·分类·时间序列
雷欧力2 小时前
如何使用 Claude API?3 种接入方案实测,附完整代码(2026)
python·claude