Flash Attention反向梯度优化显存

1. 背景

前面我有文章介绍子Flash Attention 针对长序列的正向优化,而其反向算子(Backward Pass)的优化由于涉及到复杂的梯度重计算和显存权衡,往往比正向过程更具挑战性。

在反向算子中,核心目标是计算梯度 dQ,dK, dV。首先看公式, 为了描述简单,其计算流程也简单,系数、mask等先省略。

|--------------------------------------------|------------------------------------------------------------------------------|
| 正向计算公式 | 反向计算公式 |
| | , |
| | , |
| | |

2. Flash Attention分块

类似Flash Attention前向的tiling分块优化,可以优化显存访问量,但反向过程会增加一次S和P的重计算,详解流程图可以看我前面的一篇正向文章。先看看伪代码:

python 复制代码
def flash_attention_backward(Q, K, V, dO, m, l, Bc, Br):
    N, d = Q.shape
    dQ = dK = dV = zeros(N, d)

    # 外层循环:遍历K、V的块(列块)
    for j in range(0, N, Bc):
        Kj = K[j:j+Bc, :]   # [Bc, d]
        Vj = V[j:j+Bc, :]   # [Bc, d]

        # 内层循环:遍历Q、dO的块(行块)
        for i in range(0, N, Br):
            Qi = Q[i:i+Br, :]       # [Br, d]
            dOi = dO[i:i+Br, :]     # [Br, d]
            m_i = m[i:i+Br]         # [Br]
            l_i = l[i:i+Br]         # [Br]

            # ===== 1. 重计算前向 =====
            Sij = Qi @ Kj.T                    # [Br, Bc]
            # 稳定Softmax
            Pij = exp(Sij - m_i.unsqueeze(1)) / l_i.unsqueeze(1)   # [Br, Bc]

            # ===== 2. 计算 dP = dO · V^T =====
            dPij = dOi @ Vj.T                  # [Br, Bc]  (mm1)

            # ===== 3. Softmax反向,计算 dS = P ⊙ (dP - rowsum(P⊙dP)) =====
            rowsum = (Pij * dPij).sum(dim=1, keepdim=True)   # [Br, 1]
            dSij = Pij * (dPij - rowsum)                     # [Br, Bc]

            # ===== 4. 计算梯度贡献并累加 =====
            dQ[i:i+Br, :] += dSij @ Kj          # [Br, d] (mm2的一部分)
            dK[j:j+Bc, :] += Qi.T @ dSij          # [Bc, d] (mm3的一部分)
            dV[j:j+Bc, :] += Pij.T @ dOi          # [Bc, d] (mm4的一部分)

    return dQ, dK, dV

和正向FlashAttention类似,尽量利用好片上SRAM,这样MAC(显存访问量)从传统的优化到,显存占用从优化到, N是序列长度,d是tokens向量的维度。

3. 数学恒等式的简化

其中计算dS时有个技巧,按行先存下Di,这样可以避免重复冗余计算。

利用链式法则,损失对注意力得分S的梯度可以化简为:

,

为了方便工程实现,我们通常记 , 于是公式变为:

4.负载均衡

在反向传播中,由于 Casual Mask(因果掩码)的存在,矩阵是一个下三角矩阵。

问题:如果简单划分 Workload,处理上三角部分的线程块会无事可做,导致 GPU 利用率低下。

优化:动态调整分块任务的分发,确保每个 Streaming Multiprocessor (SM) 处理的有效计算量大致相等。

5. Flash Attention-2 与 v3 的进阶优化
  • FA2 改进:在 FA2 中,反向过程进一步减少了非矩阵乘法(non-matmul)的计算。通过改变循环顺序,使得对 dQ, dK, dV 的更新更加高效,减少了原子加(Atomic Add)的冲突。

  • FA3 (Hopper GPU 优化)

    • Warp-specialization:将 Warp 划分为"生产者"和"消费者",分别负责数据搬运和计算。

    • FP8 精度:针对 H100 等硬件,利用新的 Tensor Core 支持 FP8 训练,大幅提升吞吐量。

6. 总结对比
特性 传统 Attention 反向 Flash Attention 反向
显存占用 (存储整个P矩阵) (仅存储 LSE)
显存IO访问量
IO 瓶颈 高 (频繁读写 HBM) 低 (在 SRAM 中完成计算)
计算量 较低 较高 (通过重计算换取空间)
主要性能限制 内存带宽 算力 (Compute Bound)
相关推荐
dog2501 分钟前
信号权重和流分类的对数规律
人工智能·分类·数据挖掘
道一云黑板报3 分钟前
告别提示词工程:为什么“循环工程”才是 AI 编程的未来?
人工智能·驱动开发·软件工程·ai编程
实在智能RPA3 分钟前
大模型驱动航班规划实战:2026年企业级Agent重塑航空业调度逻辑
人工智能·ai
叫我:松哥4 分钟前
基于Python的共享单车租赁数据分析与预测系统,技术栈flask+boostrap+随机森林+XGBoost
人工智能·python·深度学习·算法·随机森林·数据分析·flask
米小虾15 分钟前
2026年6月AI大模型全景报告:GPT-5.6、Claude Opus 4.8、Gemini 3.5,中美AI三足鼎立谁主沉浮?
人工智能
米小虾17 分钟前
AI Agent从Demo到生产:2026年主流Agent开发框架全景对比与实战选型指南
人工智能·agent
Sam092724 分钟前
Agent 如何节省 Token 成本:从 Prompt 到工程监控的系统化优化指南
人工智能·ai
拓朗工控28 分钟前
边缘计算对工控机性能要求有多高?
人工智能·边缘计算·工控机·工业电脑
2501_9065651230 分钟前
AI辅助开发工具链2026版
人工智能
冬奇Lab32 分钟前
Agent 系列(20):Harness 实战——从单文件到生产级模块包
人工智能·agent