极智AI | 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention

欢迎关注我的公众号 [极智视界],获取我的更多经验分享

大家好,我是极智视界,本文来介绍一下 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention。

邀您加入我的知识星球「极智视界」,星球内有超多好玩的项目实战源码下载,链接:t.zsxq.com/0aiNxERDq

没错没错,就是这个图啦,

所谓一图胜千言,一张好的图对于一个工作的表达很重要,通常能够让人更能直观理解这个工作在做什么。

这里基于这张图,来解读大模型优化技术之 FlashAttention。

先用一句话来总结 FlashAttention 的优化之道:算子融合,矩阵分块分而治之

大家知道,基于 Transformer 架构的大模型,在模型推理优化方面已经将以往 CNN 模型的计算密集型 (Compute-Bound) 优化转移到了访存密集型 (Memory-Bound) 优化上,而且这种访存方面的负担是随着序列长度增加呈指数级增长的,为什么会是指数级增长的呢,因为 Transformer 的时间和存储复杂度是 O(N^2) 的,导致在实际中很难将输入 token 进一步放的比较大。所以在基于大模型的推理优化中,通常对于 Memory、对于 IO 的优化特别重要。这里的 FlashAttention 也是一样,主要就是对于 Memory、对于 IO 方面的优化。

FlashAttention 加速的原理非常简单,就是更多地去利用带宽更高的上层存储单元,减少对低速下层存储单元的访问频率,从而达到加速的目的。既然提到了高速带宽和低速带宽,说明 FlashAttention 是一种强硬件相关的加速技术。看上面图的最左边部分,它是 NVIDIA A100 GPU 的存储层次图,最上层也是带宽最高的存储为 GPU SRAM,大约为 20 MB 左右,具体大小其实可以自己算一下,A100 有 108 个 SM,每个 SM 192 KB,这样就可以算出是 20.25 MB。SRAM 虽然小,但它的带宽则高达 19 TB/s。接着是 HBM 高速带宽,也就是咱们通俗理解的 GPU 显存,A100 采用 HBM2e 的堆叠,显存达到 40 ~ 80 GB,带宽达到 1.5 ~ 2.0 TB/s。这样对比下来,SRAM 的访问速率要比 HBM 快 10 倍的样子,但是其所能容纳的数据量却远远小于 HBM。继续往下看,接下来就是主存了,也就是所谓的 CPU DRAM,它的容量能够达到 1 TB 以上,但带宽只有 12.8 GB/s,而对于这种带宽速率,优化的时候就需要避其远之了。

所以综上,你可以看到,对于推理效率优化,特别在考虑到大模型 "大" 的情况下,怎么样去更好地协同配合利用好 SRAM 和 HBM,就是重中之重了。而这种 SRAM 和 HBM 之间的协同优化,又会依赖于一种始于 Ampere 架构的新特性,那就是可以直接将数据从 HBM 拷贝到 SRAM 的新的异步拷贝指令,这让 SRAM 和 HBM 之间的数据传输更加直接。

对于 FlashAttention 的优化,有了硬件层面的优化概念之后,再来看软件层面是怎么做的。

先来看标准 Attention 的实现伪代码:

Algorithm 0 Standard Attention Implementation
Require: Matrices Q , K , V ∈ R^(𝑁 ×𝑑 ) in HBM.
1: Load Q , K by blocks from HBM, compute S = QK^( T), write S to HBM.
2: Read S from HBM, compute P = softmax(S ), write P to HBM.
3: Load P and V by blocks from HBM, compute O = PV , write O to HBM.
4: Return O.

上面的伪代码很好地展示了传统 Attention 中 Q、K、V 的计算过程,整个过程相对 "朴素",缺点主要有以下几个:

  • (1) 对于中间矩阵 S、P 的大小跟输入 token 的大小相关,空间复杂度为 O(N^2),而且在计算 Attention 的过程中需要反复计算 S、P,势必要反复存取 S、P,势必导致存取频率高。而这一切又都是发生在 HBM 上,HBM 的带宽即使 "全速支撑" 这么高频率的存取,在宏观上也会造成带宽瓶颈。再加上前面的显存瓶颈,就形成了两个常见的瓶颈;
  • (2) 由于计算都是大块大块的,不能很好利用小块而带宽极高的 SRAM,效率低;

针对这两个缺陷,很自然地就有了下面的优化思路,

  • 对于缺点 (1) => 从 Attention 的计算伪代码可以看出,Attention 的实现主要是包括了三个步骤,对应到 CUDA 的实现上就是三个 Kernel,而这三个 Kernel 的输入输出数据都是在 HBM 上,所以就会对 HBM 造成频繁的存取访问。那么是否可以采用算子融合的方式,将这三个 Kernel 融合为一个 Kernel 呢,这样至少不用多次地进行数据存取了 (把中间的 S 和 P 给干掉了),答案肯定是可以的,也就是对应于文中 "一图" 中的最右边部分;==> Fusion Kernel;
  • 对于缺点 (2) => 显存一旦高了,带来的问题可不止造成 HBM 显存容量负担大,除此之外,正是因为 "大块",而 SRAM 又相对小,所以没有办法把大的模型 "直接塞进" 容量小但带宽高的 SRAM 进行加速,这也是限制大模型效率的很重要的原因了。为了解决这个问题,很自然能想到的办法就是切片以分而治之。采用的思想跟深度学习编译器 TVM 里面的优化思想很像,进行 Tiling,也即矩阵分块。具体到 FlashAttention 中,就是对 Softmax 进行 Tiling 优化,以增量的方式执行 Softmx 的缩减。而这,也就是对应了文中 "一图" 中的中间部分了。

所以这么分析来看,FlashAttention 的优化方法主要就是两条:

  • Fusion Kernel;==> 算子融合
  • Softmax Tiling; ==> 矩阵分块;

算子融合太好理解,而矩阵分块,我再回过头仔细瞅了瞅,你可以说不是借鉴了 TVM,你猜我信不信。当然你也可以说不是借鉴了 TVM 而是借鉴了 Halide,那也能说你真棒。对于 TVM 中的 Tiling 的具体介绍,可以参考我的下面两篇对于 TVM 的解读,

好了,以上分享了 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention,希望我的分享能对你的学习有一点帮助。


【公众号传送】

《极智AI | 算子融合、矩阵分块 一图看懂大模型优化技术FlashAttention》

畅享人工智能的科技魅力,让好玩的AI项目不难玩。邀请您加入我的知识星球, 星球内我精心整备了大量好玩的AI项目,皆以工程源码形式开放使用,涵盖人脸、检测、分割、多模态、AIGC、自动驾驶、工业等。一定会对你学习有所帮助,也一定非常好玩,并持续更新更加有趣的项目。 t.zsxq.com/0aiNxERDq

相关推荐
这个男人是小帅28 分钟前
【GAT】 代码详解 (1) 运行方法【pytorch】可运行版本
人工智能·pytorch·python·深度学习·分类
__基本操作__30 分钟前
边缘提取函数 [OPENCV--2]
人工智能·opencv·计算机视觉
Doctor老王34 分钟前
TR3:Pytorch复现Transformer
人工智能·pytorch·transformer
热爱生活的五柒35 分钟前
pytorch中数据和模型都要部署在cuda上面
人工智能·pytorch·深度学习
HyperAI超神经3 小时前
【TVM 教程】使用 Tensorize 来利用硬件内联函数
人工智能·深度学习·自然语言处理·tvm·计算机技术·编程开发·编译框架
扫地的小何尚4 小时前
NVIDIA RTX 系统上使用 llama.cpp 加速 LLM
人工智能·aigc·llama·gpu·nvidia·cuda·英伟达
埃菲尔铁塔_CV算法7 小时前
深度学习神经网络创新点方向
人工智能·深度学习·神经网络
艾思科蓝-何老师【H8053】7 小时前
【ACM出版】第四届信号处理与通信技术国际学术会议(SPCT 2024)
人工智能·信号处理·论文发表·香港中文大学
weixin_452600697 小时前
《青牛科技 GC6125:驱动芯片中的璀璨之星,点亮 IPcamera 和云台控制(替代 BU24025/ROHM)》
人工智能·科技·单片机·嵌入式硬件·新能源充电桩·智能充电枪
学术搬运工7 小时前
【珠海科技学院主办,暨南大学协办 | IEEE出版 | EI检索稳定 】2024年健康大数据与智能医疗国际会议(ICHIH 2024)
大数据·图像处理·人工智能·科技·机器学习·自然语言处理