flash-attention连环问答--softmax 、safe softmax 、online softmax

1 Self Attention是怎么计算的,写出对应的公式

其中 Q,K,V 都是 N * dk 的 2D 矩阵,N 为序列长度,dk 为头的维度。

2 softmax 公式的计算有什么问题,在工程实现的时候怎么做的?

因为包含了幂指数计算,所以它有一个明显的问题:数值溢出。

对于大模型常用的半精度 fp16 来说,最大值也才 65536,所以当 xi 大于 11 的时候,e 的 12 次方等于 162754.7914 ,大于 65536。所以实际工程实现相对于原生的 softmax,它先要减去一个 max 的值,确保计算过程中不会导致数值溢出,
如下图:由于 xi-m ≤0 所以不会出现溢出,这种实现方案也叫 safe-softmax。

所以总结一下,对于这个问题,我们要沿着面试官的心理,首先答出标准 softmax 公式导致的问题:数值溢出。再回答实际工程的解决方案:safe-softmax,并结合相应的公式进行回答。
此外对于标准的 softmax 计算,需要 3 步,计算最大值 m,计算分母,最后再依此计算分子。

如果不做任何优化的话,它至少要和 GPU 进行 6 次通信(3 次写入,3 次写出)。
数值不稳定性:如果 x很大,ex会变得非常大,可能导致溢出。

计算效率:需要两次遍历数据------一次找最大值(为了数值稳定),一次计算概率。

内存需求:对于大数据集,需要存储所有中间结果。
所以这里我们自然就引出了下一个问题:那你能够降低 softmax 的 GPU 访存复杂度吗?如果可以,怎么做?

3 你能够降低 softmax 的 GPU 访存复杂度吗?如果可以,怎么做?

这个问题希望你回答什么呢?其实就两点,

第一,你知不知道 softmax 可以通过流式计算降低 GPU 访存复杂度。

第二,能否阐述一下流式计算的核心思想。
首先我们明确一点,就是 softmax 是可以做成流式计算的,18 年 NVIDIA 发表的一篇论文,就提出了 online-softmax 算法,下图展示了其核心计算过程。
Online Softmax 的目标是:

  • 单次遍历:只看一遍数据,边看边算。
  • 流式处理:数据一块一块来,随时更新结果。
  • 内存高效:不用把所有数据都存下来。
    Online Softmax 的核心思想

Online Softmax 的核心在于增量更新。我们不一次性处理所有数据,而是每次来一个新数据点,就更新两个关键统计量:

  • 当前最大值m:记录目前见过的最大输入值。
  • 分母的累加和d:Softmax 分母是所有 e x i 的和,我们动态维护它。
    当新数据到来时,我们只需要用已有的统计量和新数据点,更新这两个值,就能保证结果正确。这就像在流水线上加工零件,每来一个零件就更新一下生产线上的状态,不用等所有零件都到齐。

所以最终我们可以借助 GPU 的 share memory 来存储中间结果,将上面的两步只用一个 kernel 实现,这样就只需要与 global memory 通信两次,一次写入数据,一次读取结果。

不过这里要注意,就是由于第二步的计算仍然需要依赖第一步计算的分母 dN,所以还是需要两步,换句话说,不能做成 one pass。

4 既然 softmax 不能做到 one-pass,为什么 Flash Attention 可以,解释一下背后的核心思想?

首先 Flash Attention 能做到 one-pass 计算,其核心思想是 Flash Attention 让 Attention 的所有计算都符合加法结合律,这样就可以充分利用 GPU 的并行优势,这是希望我们答出的第一个点。
虽然单独的 softmax 运算不能做到 one-pass,但是 self-Attention 中的 softmax 求完之后,它的每一项的值会与 V 中向量相乘,然后累加。这里的累加很关键,有了这个累加的操作,所有的计算又符合结合律了,这就是 FlashAttention 并行加速的的理论核心思想。如果没有这个累加,比如单纯的计算 softmax,反而没有办法并行。
所以用类似 Online Softmax 的方法,就可以将 Attention 所有的操作,都放到一个 for 循环里(一个 Kernel 就可以实现)。我们可以推导一下输出的关系:

从上面式子可以看到,oi 只依赖 oi-1, mi, mi-1,所以可以实现递归计算。

更进一步,分析 Flash Attention 计算过程可以发现,Flash Attention 其实并没有减少 Attention 的计算量,也不影响精度,但是却比标准的 Attention 运算快 2-4 倍的运行速度,同时减少了 5~20 倍的内存使用量。

5 详细解释一下 Flash Attention 中的 tiling 策略?

问这个问题,首先是想考察你,

知不知道什么是 tiling,为什么要使用它?

以及使用之后有什么作用?

其次在 Flash Attention 中的 tiling 策略是如何做的,能否说一下它的整个流程以及具体的效果?
tiling 说白了就是对矩阵分块,

分块策略的主要动机,是通过将大矩阵分解为更小的块,以此来减少内存访问的开销,同时提高计算效率。分块策略允许我们在处理大矩阵时,只加载和处理一部分数据,而不是一次性加载整个矩阵,这样可以减少内存带宽的压力。
而具体到 Flash Attention 中,就是将 Q,K,V 分成更多个小块,其中 K,V 在外循环,Q 在内循环。
在计算注意力分数的时候,通常需要进行 softmax 操作。为了避免一次性计算整个 softmax,Flash Attention 会采用局部归一化策略。

对于每个块,我们只计算这个块内部的 softmax,并在累加结果的时候进行适当的归一化。

所以通过逐块计算,减少了全局内存的访问次数,这样就降低了内存带宽的压力。

这种策略特别适用于处理长序列的注意力机制,能够显著加速计算过程。

6 FlashAttention 对 MQA 和 GQA 是怎么处理的?

想考察的,首先是你是否知道 MQA 和 GQA,如果你都不知道这两个概念,这道题目也就无从答起。我们来看这张图。

之前我们讲过,MQA 只保留了一个 KV Head,多个 Query Heads 共享相同的 KV Head。

而 GQA 与 MQA 不同,它采取了折中的做法,GQA 把 Query Heads 进行分组,每组 Query Heads 对应一个 KV Head。

7 那在 Flash Attention 中对 MQA/GQA 是如何处理呢?

这里要听到的一个关键词,就是 Indexing 思想。

对于 MQA 和 GQA,FlashAttention 采用了 Indexing 的方式,而不是直接复制多份 KV Head 的内容到显存然后再进行计算。

Indexing 的思想,就是通过传入 KV Head 索引到 GPU Kernel 中,然后根据内存地址,直接从内存中读取 KV。

相关推荐
skywalk81634 小时前
在SCNet使用异构海光DCU 部署文心21B大模型报错HIP out of memory(未调通)
人工智能
ASKED_20194 小时前
深度强化学习之123-概念梳理
人工智能
攻城狮7号4 小时前
OpenAI 的 Sora 2来了:一场创意革命与失控的狂欢
人工智能·大模型·openai·ai视频·sora 2
胖头鱼的鱼缸(尹海文)5 小时前
数据库管理-第376期 Oracle AI DB 23.26新特性一览(20251016)
数据库·人工智能·oracle
瑞禧生物ruixibio5 小时前
4-ARM-PEG-Pyrene(2)/Biotin(2),多功能化聚乙二醇修饰荧光标记生物分子的设计与应用探索
arm开发·人工智能
大千AI助手5 小时前
Huber损失函数:稳健回归的智慧之选
人工智能·数据挖掘·回归·损失函数·mse·mae·huber损失函数
墨利昂5 小时前
10.17RNN情感分析实验:加载预训练词向量模块整理
人工智能·rnn·深度学习
【建模先锋】5 小时前
一区直接写!CEEMDAN分解 + Informer-LSTM +XGBoost组合预测模型
人工智能·lstm·ceemdan·预测模型·风速预测·时间序列预测模型