极智AI | 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention

欢迎关注我的公众号 [极智视界],获取我的更多经验分享

大家好,我是极智视界,本文来介绍一下 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention。

邀您加入我的知识星球「极智视界」,星球内有超多好玩的项目实战源码下载,链接:t.zsxq.com/0aiNxERDq

没错没错,就是这个图啦,

所谓一图胜千言,一张好的图对于一个工作的表达很重要,通常能够让人更能直观理解这个工作在做什么。

这里基于这张图,来解读大模型优化技术之 FlashAttention。

先用一句话来总结 FlashAttention 的优化之道:算子融合,矩阵分块分而治之

大家知道,基于 Transformer 架构的大模型,在模型推理优化方面已经将以往 CNN 模型的计算密集型 (Compute-Bound) 优化转移到了访存密集型 (Memory-Bound) 优化上,而且这种访存方面的负担是随着序列长度增加呈指数级增长的,为什么会是指数级增长的呢,因为 Transformer 的时间和存储复杂度是 O(N^2) 的,导致在实际中很难将输入 token 进一步放的比较大。所以在基于大模型的推理优化中,通常对于 Memory、对于 IO 的优化特别重要。这里的 FlashAttention 也是一样,主要就是对于 Memory、对于 IO 方面的优化。

FlashAttention 加速的原理非常简单,就是更多地去利用带宽更高的上层存储单元,减少对低速下层存储单元的访问频率,从而达到加速的目的。既然提到了高速带宽和低速带宽,说明 FlashAttention 是一种强硬件相关的加速技术。看上面图的最左边部分,它是 NVIDIA A100 GPU 的存储层次图,最上层也是带宽最高的存储为 GPU SRAM,大约为 20 MB 左右,具体大小其实可以自己算一下,A100 有 108 个 SM,每个 SM 192 KB,这样就可以算出是 20.25 MB。SRAM 虽然小,但它的带宽则高达 19 TB/s。接着是 HBM 高速带宽,也就是咱们通俗理解的 GPU 显存,A100 采用 HBM2e 的堆叠,显存达到 40 ~ 80 GB,带宽达到 1.5 ~ 2.0 TB/s。这样对比下来,SRAM 的访问速率要比 HBM 快 10 倍的样子,但是其所能容纳的数据量却远远小于 HBM。继续往下看,接下来就是主存了,也就是所谓的 CPU DRAM,它的容量能够达到 1 TB 以上,但带宽只有 12.8 GB/s,而对于这种带宽速率,优化的时候就需要避其远之了。

所以综上,你可以看到,对于推理效率优化,特别在考虑到大模型 "大" 的情况下,怎么样去更好地协同配合利用好 SRAM 和 HBM,就是重中之重了。而这种 SRAM 和 HBM 之间的协同优化,又会依赖于一种始于 Ampere 架构的新特性,那就是可以直接将数据从 HBM 拷贝到 SRAM 的新的异步拷贝指令,这让 SRAM 和 HBM 之间的数据传输更加直接。

对于 FlashAttention 的优化,有了硬件层面的优化概念之后,再来看软件层面是怎么做的。

先来看标准 Attention 的实现伪代码:

Algorithm 0 Standard Attention Implementation
Require: Matrices Q , K , V ∈ R^(𝑁 ×𝑑 ) in HBM.
1: Load Q , K by blocks from HBM, compute S = QK^( T), write S to HBM.
2: Read S from HBM, compute P = softmax(S ), write P to HBM.
3: Load P and V by blocks from HBM, compute O = PV , write O to HBM.
4: Return O.

上面的伪代码很好地展示了传统 Attention 中 Q、K、V 的计算过程,整个过程相对 "朴素",缺点主要有以下几个:

  • (1) 对于中间矩阵 S、P 的大小跟输入 token 的大小相关,空间复杂度为 O(N^2),而且在计算 Attention 的过程中需要反复计算 S、P,势必要反复存取 S、P,势必导致存取频率高。而这一切又都是发生在 HBM 上,HBM 的带宽即使 "全速支撑" 这么高频率的存取,在宏观上也会造成带宽瓶颈。再加上前面的显存瓶颈,就形成了两个常见的瓶颈;
  • (2) 由于计算都是大块大块的,不能很好利用小块而带宽极高的 SRAM,效率低;

针对这两个缺陷,很自然地就有了下面的优化思路,

  • 对于缺点 (1) => 从 Attention 的计算伪代码可以看出,Attention 的实现主要是包括了三个步骤,对应到 CUDA 的实现上就是三个 Kernel,而这三个 Kernel 的输入输出数据都是在 HBM 上,所以就会对 HBM 造成频繁的存取访问。那么是否可以采用算子融合的方式,将这三个 Kernel 融合为一个 Kernel 呢,这样至少不用多次地进行数据存取了 (把中间的 S 和 P 给干掉了),答案肯定是可以的,也就是对应于文中 "一图" 中的最右边部分;==> Fusion Kernel;
  • 对于缺点 (2) => 显存一旦高了,带来的问题可不止造成 HBM 显存容量负担大,除此之外,正是因为 "大块",而 SRAM 又相对小,所以没有办法把大的模型 "直接塞进" 容量小但带宽高的 SRAM 进行加速,这也是限制大模型效率的很重要的原因了。为了解决这个问题,很自然能想到的办法就是切片以分而治之。采用的思想跟深度学习编译器 TVM 里面的优化思想很像,进行 Tiling,也即矩阵分块。具体到 FlashAttention 中,就是对 Softmax 进行 Tiling 优化,以增量的方式执行 Softmx 的缩减。而这,也就是对应了文中 "一图" 中的中间部分了。

所以这么分析来看,FlashAttention 的优化方法主要就是两条:

  • Fusion Kernel;==> 算子融合
  • Softmax Tiling; ==> 矩阵分块;

算子融合太好理解,而矩阵分块,我再回过头仔细瞅了瞅,你可以说不是借鉴了 TVM,你猜我信不信。当然你也可以说不是借鉴了 TVM 而是借鉴了 Halide,那也能说你真棒。对于 TVM 中的 Tiling 的具体介绍,可以参考我的下面两篇对于 TVM 的解读,

好了,以上分享了 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention,希望我的分享能对你的学习有一点帮助。


【公众号传送】

《极智AI | 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention》

畅享人工智能的科技魅力,让好玩的AI项目不难玩。邀请您加入我的知识星球, 星球内我精心整备了大量好玩的AI项目,皆以工程源码形式开放使用,涵盖人脸、检测、分割、多模态、AIGC、自动驾驶、工业等。一定会对你学习有所帮助,也一定非常好玩,并持续更新更加有趣的项目。 t.zsxq.com/0aiNxERDq

相关推荐
誉鏐2 分钟前
PyTorch复现逻辑回归
人工智能·pytorch·逻辑回归
正脉科工 CAE仿真5 分钟前
基于ANSYS 概率设计和APDL编程的结构可靠性设计分析
人工智能·python·算法
EasyGBS11 分钟前
视频设备轨迹回放平台EasyCVR打造视频智能融合新平台,驱动智慧机场迈向数字新时代
网络·人工智能·安全·音视频
Chaos_Wang_16 分钟前
NLP高频面试题(三十三)——Vision Transformer(ViT)模型架构介绍
人工智能·自然语言处理·transformer
新知图书29 分钟前
OpenCV单窗口显示多图片
人工智能·opencv·计算机视觉
荷包蛋蛋怪31 分钟前
【北京化工大学】 神经网络与深度学习 实验6 MATAR图像分类
人工智能·深度学习·神经网络·opencv·机器学习·计算机视觉·分类
小马哥编程33 分钟前
【软测】AI助力测试用例
人工智能·测试用例
与火星的孩子对话43 分钟前
Unity3D开发AI桌面精灵/宠物系列 【三】 语音识别 ASR 技术、语音转文本多平台 - 支持科大讯飞、百度等 C# 开发
人工智能·unity·c#·游戏引擎·语音识别·宠物
事变天下1 小时前
今是科技发布全新测序仪G-seq1M:以效率与精准引领基因测序新标杆
人工智能·科技
贤小二AI1 小时前
贤小二c#版Yolov5 yolov8 yolov10 yolov11自动标注工具 + 免python环境 GPU一键训练包
人工智能·深度学习·yolo