大模型面试题61:Flash Attention中online softmax(在线softmax)的实现方式


一、小白入门:先搞懂「为什么需要online softmax」

在讲online softmax之前,先明确一个核心问题:传统Attention的softmax到底卡在哪里?

1. 传统Attention的softmax(小白版)

Attention的核心计算可以简化为3步(以单条序列为例):

复制代码
输入:Q(查询)、K(键)、V(值),序列长度=L
步骤1:算相似度分数 → score = Q × K^T / √(维度) (得到L×L的大矩阵)
步骤2:算softmax → softmax(score) = exp(score) / sum(exp(score))
步骤3:加权求和 → output = softmax(score) × V

致命问题 :当序列长度L很大(比如16384),L×L的score矩阵会占满GPU显存(比如FP16精度下,16384×16384=2.68亿个元素,约512MB),再加上exp(score)、sum_exp等中间结果,直接导致显存溢出(OOM)

2. online softmax的核心思想(类比理解)

online softmax的本质是:不一次性计算所有score,而是把K/V切成小「块(tile)」,分批次计算,边算边更新关键值(最大值、exp和),全程不存储完整的score矩阵

用生活类比:

  • 传统softmax:算1000笔账单的总和,先把所有账单都写在纸上(占满整张纸),再求和;
  • online softmax:算1000笔账单,每次只拿10笔算,算完就更新「当前总和」,扔掉这10笔的纸,只记总和,最后得到结果(省纸=省显存)。

3. online softmax的小白版核心逻辑

不用记公式,只需要记住:

  • 核心目标:避免存储完整的L×L score矩阵
  • 核心手段:分块计算+边算边更更新全局最大值(max)和exp和(sum_exp)
  • 核心优势:显存占用从O(L²)降到O(L),彻底解决长序列显存爆炸问题。

二、基础进阶:online softmax的核心实现步骤(懂基础编程即可)

先明确2个关键前提(承上启下):

  • 分块规则:把K和V切成固定大小的「tile」(比如128个token为1块),记为K₁、K₂...Kₙ,V₁、V₂...Vₙ;
  • 核心变量:max_so_far(全局最大score值,防止exp溢出)、sum_exp_so_far(全局exp和)、partial_output(分块计算的加权V和)。

1. 传统softmax vs online softmax(对比表)

维度 传统softmax online softmax(Flash Attention)
显存占用 O(L²)(存储完整score矩阵) O(1)(只存分块数据+3个核心变量)
计算方式 一次性计算所有score 分块计算,边算边更新核心变量
中间结果 存储完整的exp(score) 不存储,算完分块就丢弃

2. online softmax的核心步骤(伪代码+逐行解释)

下面是小白能看懂的online softmax伪代码(简化版,对应Flash Attention的正向计算):

python 复制代码
import math

def online_softmax_attention(Q, K, V, tile_size=128):
    """
    Flash Attention的online softmax核心实现(简化版)
    参数:
        Q: [L_q, d]  查询矩阵(L_q是查询序列长度)
        K: [L_k, d]  键矩阵(L_k是键序列长度,通常L_q=L_k)
        V: [L_k, d]  值矩阵
        tile_size: 分块大小(Flash Attention默认128/256)
    返回:
        output: [L_q, d]  Attention输出
    """
    L_q, d = Q.shape
    L_k = K.shape[0]
    output = torch.zeros_like(Q)  # 最终输出
    
    # 初始化online softmax的核心变量
    max_so_far = torch.full((L_q,), -float("inf"))  # 全局max(初始为负无穷)
    sum_exp_so_far = torch.zeros((L_q,))            # 全局exp和(初始为0)
    partial_output = torch.zeros_like(Q)            # 分块累加的V加权和

    # 分块处理K/V(核心:逐块计算,不存完整score)
    for i in range(0, L_k, tile_size):
        # 1. 取当前块的K和V
        K_tile = K[i:i+tile_size]  # [tile_size, d]
        V_tile = V[i:i+tile_size]  # [tile_size, d]
        
        # 2. 计算当前块的score(只算当前块,不是完整L_q×L_k)
        score_tile = torch.matmul(Q, K_tile.T) / math.sqrt(d)  # [L_q, tile_size]
        
        # 3. 更新全局max(关键:取当前块max和历史max的较大值)
        local_max = torch.max(score_tile, dim=-1).values  # [L_q,]
        new_max = torch.maximum(max_so_far, local_max)    # 全局max更新
        
        # 4. 数值稳定性:缩放历史值(适配新的全局max)
        # 原理:exp(a - new_max) = exp(a - old_max) × exp(old_max - new_max)
        exp_diff = torch.exp(max_so_far - new_max)  # 缩放系数
        sum_exp_so_far = sum_exp_so_far * exp_diff  # 历史sum_exp缩放
        partial_output = partial_output * exp_diff  # 历史加权V和缩放
        
        # 5. 计算当前块的exp(score - new_max)(避免exp溢出)
        exp_score_tile = torch.exp(score_tile - new_max.unsqueeze(-1))  # [L_q, tile_size]
        
        # 6. 更新全局sum_exp和partial_output
        sum_exp_so_far += torch.sum(exp_score_tile, dim=-1)  # 累加exp和
        partial_output += torch.matmul(exp_score_tile, V_tile)  # 累加加权V和
        
        # 7. 更新全局max为新值
        max_so_far = new_max

    # 8. 最终归一化(用全局sum_exp除以partial_output)
    output = partial_output / sum_exp_so_far.unsqueeze(-1)
    return output

3. 关键步骤解释(小白必看)

  • 步骤3(更新全局max):每次只算当前块的max,和历史max比,保留更大的那个------这是为了避免exp(score)溢出(比如score=100,exp(100)会变成无穷大);
  • 步骤4(缩放历史值):因为全局max更新了,之前算的exp值需要适配新max,否则会导致softmax结果错误;
  • 全程无完整score矩阵:所有计算都是基于「当前块」,显存里永远只存一个tile_size大小的score_tile,而不是L×L的大矩阵。

三、深度进阶:online softmax的底层实现+核心难点

Flash Attention的online softmax不是简单的分块计算,而是深度结合GPU硬件架构的优化,核心难点也集中在「硬件适配+数值稳定性+反向传播」。

1. online softmax的底层实现(CUDA核函数级)

Flash Attention的online softmax是用CUDA核函数手写的(不是PyTorch原生操作),核心优化点:

  • 双Pass策略
    • Pass 1:只遍历所有K/V块,计算全局max_so_far和sum_exp_so_far(不计算output);
    • Pass 2:再次遍历所有K/V块,用Pass 1得到的全局max/sum_exp,计算最终的softmax权重和output;
    • 优势:避免Pass 1和Pass 2的冗余计算,最大化GPU并行效率。
  • GPU内存层级优化
    • 把K/V tile加载到GPU的「共享内存(Shared Memory)」(速度比全局内存快100倍+);
    • 避免「共享内存银行冲突」(bank conflict):调整tile大小和数据排布,让不同线程访问不同的内存bank;
    • 利用GPU的warp(线程束)并行计算:每个warp处理一个小tile的score计算,最大化硬件利用率。
  • 异步内存拷贝:计算当前tile的同时,异步加载下一个tile到共享内存,隐藏内存拷贝的延迟。

2. online softmax的核心难点(进阶重点)

Flash Attention的online softmax看似简单,实则是数值计算+硬件优化的双重难点,也是它比传统实现难的核心原因:

难点1:数值稳定性(最大的坑)
  • 问题:当全局max更新时,exp(max_so_far - new_max) 可能因为差值太大(比如max_so_far=-10,new_max=100),导致exp(-110)≈0,出现「数值下溢」(结果变成0,丢失精度);
  • 解决:
    • 限制max的更新幅度:避免单次max更新过大;
    • 用log空间计算:把sum_exp_so_far存在log空间,减少下溢风险;
    • 精度补偿:对极小值做舍入修正,防止梯度消失。
难点2:GPU硬件适配(性能的关键)
  • 问题:online softmax的性能完全依赖GPU内存层级的利用效率,稍有不慎就会比传统实现还慢;
  • 解决:
    • 适配不同GPU架构(A100/H100):调整tile大小(比如A100用128,H100用256);
    • 线程块(thread block)划分:让每个thread block处理的tile大小匹配GPU的SM(流多处理器)数量;
    • 避免全局内存访问:尽可能把数据留在共享内存,减少慢的全局内存读写。
难点3:反向传播的online计算
  • 问题:正向是online计算(不存完整score矩阵),反向求梯度时,也不能存储正向的所有中间结果,必须「在线算梯度」;
  • 解决:
    • 推导反向的online公式:梯度计算也分块进行,边算边更新梯度的中间值;
    • 复用正向的max/sum_exp:反向时不需要重新计算全局max,直接用正向存储的结果;
    • 平衡精度和速度:反向的online计算容易丢失精度,需要精细调整分块策略。
难点4:分块边界处理
  • 问题:当序列长度不是tile_size的整数倍(比如L=16385,tile_size=128),最后一个块的大小不足tile_size,会导致线程束空闲(GPU线程是按32/64个一组执行的);
  • 解决:
    • 补零对齐:对最后一个块补零到tile_size,计算后忽略补零部分;
    • 动态tile大小:根据剩余长度调整最后一个块的tile_size,匹配GPU warp大小。

总结

核心关键点回顾

  1. online softmax的核心:把K/V分块,边算边更新全局max和sum_exp,不存储完整的L×L score矩阵,显存占用从O(L²)降到O(L);
  2. 实现的核心逻辑:双Pass策略(先算全局max/sum_exp,再算output)+ 数值稳定性缩放(适配更新后的max);
  3. 核心难点:数值稳定性(防止exp溢出/下溢)、GPU内存层级优化(共享内存利用)、反向传播的online计算、分块边界的硬件适配。

小白到进阶的核心认知

  • 小白级:online softmax=分块算softmax,省显存;
  • 进阶级:online softmax=分块计算+数值稳定+GPU硬件深度优化;
  • 专家级:online softmax的性能上限取决于「数值精度」和「GPU内存访问效率」的平衡。
相关推荐
哥布林学者2 小时前
吴恩达深度学习课程五:自然语言处理 第一周:循环神经网络 (七)双向 RNN 与深层 RNN
深度学习·ai
阿部多瑞 ABU2 小时前
`chenmo` —— 可编程元叙事引擎 V2.3+
linux·人工智能·python·ai写作
极海拾贝2 小时前
GeoScene解决方案中心正式上线!
大数据·人工智能·深度学习·arcgis·信息可视化·语言模型·解决方案
知乎的哥廷根数学学派3 小时前
基于生成对抗U-Net混合架构的隧道衬砌缺陷地质雷达数据智能反演与成像方法(以模拟信号为例,Pytorch)
开发语言·人工智能·pytorch·python·深度学习·机器学习
小和尚同志3 小时前
又来学习提示词啦~13.9k star 的系统提示词集合
人工智能·aigc
昨夜见军贴06163 小时前
IACheck × AI审核重构检测方式:破解工业检测报告频繁返工的根本难题
人工智能·重构
知乎的哥廷根数学学派3 小时前
基于自适应多尺度小波核编码与注意力增强的脉冲神经网络机械故障诊断(Pytorch)
人工智能·pytorch·python·深度学习·神经网络·机器学习
好奇龙猫4 小时前
【AI学习-comfyUI学习-三十二节-FLXU原生态反推+controlnet depth(UNion)工作流-各个部分学习】
人工智能·学习