从零手写 FlashAttention(PyTorch实现 + 原理推导)

本文基于一个最小 PyTorch 示例,手写实现 FlashAttention

的核心计算流程,并详细解释其数值稳定性和分块计算原理。


1. 标准 Attention 回顾

标准 Attention 的计算公式:

Attention(Q,K,V)=softmax(QKT)V Attention(Q,K,V) = softmax(QK^T)V Attention(Q,K,V)=softmax(QKT)V

python 复制代码
import torch

query = torch.randn(1, 12, 10)
key = torch.randn(1, 12, 10)
value = torch.randn(1, 12, 10)

logits = torch.einsum('bqd,bkd->bqk', query, key)
probs = torch.nn.functional.softmax(logits, dim=-1)
softmax_output = torch.einsum('bqk,bkd->bqd', probs, value)

2. FlashAttention 核心思想

FlashAttention 的核心目标:

避免显式存储整个 attention matrix(QK^T)

关键手段:

  • 分块计算(block-wise)
  • 在线 Softmax(online softmax)

3. 数值稳定 Softmax

softmax(xj)=exj−m∑kexk−m,m=max(x) softmax(x_j) = \frac{e^{x_j - m}}{\sum_k e^{x_k - m}}, \quad m = max(x) softmax(xj)=∑kexk−mexj−m,m=max(x)


4. 核心递推

mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij}) mi=max(mi−1,mij)

li=li−1emi−1−mi+∑exij−mi l_i = l_{i-1} e^{m_{i-1} - m_i} + \sum e^{x_{ij} - m_i} li=li−1emi−1−mi+∑exij−mi

oi=oi−1emi−1−mi+∑(exij−miVj) o_i = o_{i-1} e^{m_{i-1} - m_i} + \sum (e^{x_{ij} - m_i} V_j) oi=oi−1emi−1−mi+∑(exij−miVj)


🔍 关键细节深入理解

很多人在理解这里时容易卡住:为什么需要对历史的 oi−1o_{i-1}oi−1 做
rescale?

我们一步一步拆解:

1️⃣ oi−1o_{i-1}oi−1 并不是"最终正确的值"

在第 i−1i-1i−1 次循环时:

  • 我们用的是 局部最大值 mi−1m_{i-1}mi−1
  • 所以 softmax 实际是:

exi−1∑exi−1=exi−1−mi−1∑exi−1−mi−1 \frac{e^{x_{i-1}}}{\sum e^{x_{i-1}}} = \frac{e^{x_{i-1} - m_{i-1}}}{\sum e^{x_{i-1} - m_{i-1}}} ∑exi−1exi−1=∑exi−1−mi−1exi−1−mi−1

👉 注意:这里的归一化是 基于局部 block 的尺度


2️⃣ 当进入第 iii 个 block 时发生了什么?

我们得到了新的最大值:

mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij}) mi=max(mi−1,mij)

👉 这个 mim_imi 更接近 全局最大值


3️⃣ 问题的本质

此时出现一个不一致:

项目 使用的 max


oi−1o_{i-1}oi−1 mi−1m_{i-1}mi−1

当前 block mim_imi

👉 如果直接相加,会导致:

不同尺度的指数项被混合(数值错误)


4️⃣ 解决方法:统一尺度(rescale)

我们需要把旧的 oi−1o_{i-1}oi−1 从:

ex−mi−1 e^{x - m_{i-1}} ex−mi−1

转换到:

ex−mi e^{x - m_i} ex−mi

变换方式:

ex−mi−1=ex−mi⋅emi−mi−1 e^{x - m_{i-1}} = e^{x - m_i} \cdot e^{m_i - m_{i-1}} ex−mi−1=ex−mi⋅emi−mi−1

👉 因此:

oi−1→oi−1⋅emi−1−mi o_{i-1} \rightarrow o_{i-1} \cdot e^{m_{i-1} - m_i} oi−1→oi−1⋅emi−1−mi


5️⃣ 对应代码

o_i = o_i_1 * torch.exp(m_i_1 - m_i)[..., None] + torch.einsum('bqk,bkd->bqd', exp_term, v_i)

含义是:

  • 第一项:旧结果 rescale 到新尺度
  • 第二项:当前 block 的贡献

6️⃣ 一个直观理解

可以把整个过程理解为:

我们在不断"修正历史",让所有累积值都统一到"当前最稳定的坐标系(最大值)"下

随着循环进行:

  • mim_imi 会逐步逼近 全局最大值
  • 所有历史贡献都会被重新缩放到这个统一尺度

7️⃣ 最终结果

当所有 block 处理完:

  • mim_imi = 全局最大值
  • oi/lio_i / l_ioi/li = 完整 softmax 结果

5. PyTorch实现

python 复制代码
flash_softmax_outputs = []

q_chunks = 4
q_chunk_size = query.shape[1] // q_chunks

k_chunks = 3
k_chunk_size = key.shape[1] // k_chunks

for i in range(q_chunks):
    q_i = query[:, i*q_chunk_size:(i+1)*q_chunk_size]

    m_i_1 = torch.full((q_i.shape[0], q_i.shape[1]), -float('inf'))
    l_i_1 = torch.zeros_like(m_i_1)
    o_i_1 = torch.zeros((q_i.shape[0], q_i.shape[1], value.shape[-1]))

    for j in range(k_chunks):
        k_i = key[:, j * k_chunk_size: (j+1) * k_chunk_size]     # (B, K_block, D)
        v_i = value[:, j * k_chunk_size: (j+1) * k_chunk_size]   # (B, K_block, Dv)
        logits_i = torch.einsum('nqd,nkd->nqk', q_i, k_i)  # (B, Q_block, K_block)
        # ---- 更新 m ----
        m_ij = torch.max(logits_i, dim=-1)[0]              # (B, Q_block)
        m_i = torch.maximum(m_i_1, m_ij)
        # 计算Softmax分子e^(x_i - m_i)
        exp_term = torch.exp(logits_i - m_i[..., None])    # (B, Q_block, K_block)
        # 更新Softmax分母
        # rescale * 旧的softmax分母 + 新的softmax分母
        l_i = l_i_1 * torch.exp(m_i_1 - m_i) + exp_term.sum(dim=-1)
        # ---- 更新 O(关键!)----
        # rescale * 旧的logit * v + 新的logit * v
        o_i = o_i_1 * torch.exp(m_i_1 - m_i)[..., None] + torch.einsum('nqk,nkd->nqd', exp_term, v_i)
        # ---- 状态更新 ----
        m_i_1 = m_i
        l_i_1 = l_i
        o_i_1 = o_i

    # ---- 最后除以Softmax分母----
    output = o_i / l_i[..., None]
    flash_softmax_outputs.append(output)


flash_softmax_outputs = torch.cat(flash_softmax_outputs, dim=1)

6. 正确性验证

python 复制代码
torch.allclose(softmax_output, flash_softmax_outputs)

7. 总结

FlashAttention 本质:

  • 分块计算
  • 在线 softmax
  • 动态重标定(rescale)

复杂度从 O(N^2) 降到 O(N)

相关推荐
renke3364几秒前
写给前端的 CANN-torchtitan-npu:昇腾PyTorch Titan适配到底是啥?
前端·人工智能·pytorch·cann
云烟成雨TD1 分钟前
Spring AI Alibaba 1.x 系列【56】SAA Admin 平台功能介绍
java·人工智能·spring
一勺菠萝丶2 分钟前
常见 AI 模型类型整理:大语言模型、聊天模型、推理模型、Embedding 模型到底有什么区别?
人工智能·语言模型·embedding
多年小白2 分钟前
今日A股 拉
大数据·人工智能·深度学习·microsoft·ai
wujian83113 分钟前
怎么把Kimi里的表格完整复制到wps内
人工智能·ai·wps·豆包·deepseek·ai导出鸭
Joy T4 分钟前
【碳金融】欧盟CBAM逻辑与“磐石·禹衡”系统的技术对冲分析
人工智能·重构·cbam·碳排放·碳核算·磐石
2401_868534784 分钟前
论快速应用开发方法及应用
大数据·python
字节高级特工5 分钟前
C++11(一) 革新:右值引用与移动语义
java·开发语言·c++·人工智能·后端
DO_Community5 分钟前
Token聚合平台 vs 传统云 vs AI原生云,AI推理应用怎么选?
人工智能·agent·token·ai-native·deepseek
郝学胜-神的一滴6 分钟前
系统设计 012:从用户系统出发,吃透缓存、数据库与高并发设计
java·数据库·python·缓存·php·软件构建