【FlashAttention】 FA2与FA1算法区别辨析

看了几篇关于FlashAttention2的文章,对于其中移除冗余的CUDA操作这个算法优化进行了一个综合梳理。

https://zhuanlan.zhihu.com/p/1993815603383902344

https://zhuanlan.zhihu.com/p/668888063

https://zhuanlan.zhihu.com/p/665170554

注意,第10行在部分文章中错写成了diag的逆,应该根据这篇文章的伪代码为准(推测是之前存在笔误,改了之后又重新上传了)。

这里FlashAttention2与FlashAttention1看起来有很大差别,推导如下;

  1. 首先比较重要的一点是,在FA2里,关于m, P的计算都没有mijm_{ij}mij, pijp_{ij}pij的概念,而是直接计算mim_imi和minewm_i^{new}minew,pip_ipi和pinewp_i^{new}pinew。因此此处的mijm_i^jmij就是FA1中的mijm_{ij}mij - minewm_i^{new}minew。另外此处的P也就是FA1中的emij−minew∗Pe^{m_{ij} - m_i^{new}} * Pemij−minew∗P。
  2. 另外第二个点,就是在中间的迭代中不计算L,只在最后一个迭代计算。
相关推荐
封奚泽优3 天前
使用mmdetection项目进行训练记录
pytorch·python·cuda·mmdetection·mmcv
fpcc4 天前
并行编程实战——CUDA编程的其它Warp函数
c++·cuda
Autumn72995 天前
【系统重装】PYTHON 入门——速通版
开发语言·python·conda·cuda
fpcc6 天前
并行编程实战——CUDA编程的Warp Vote
c++·cuda
fpcc6 天前
并行编程实战——CUDA编程的Warp Shuffle
c++·cuda
风流倜傥唐伯虎8 天前
N卡深度学习环境配置
人工智能·深度学习·cuda
fpcc9 天前
并行编程实战——CUDA编程的Enhancing Memory Allocation
c++·cuda
fpcc9 天前
AI和大模型之一介绍
人工智能·cuda
闪电橘子9 天前
Pycharm运行程序报错 Process finished with exit code -1066598273 (0xC06D007F)
ide·python·pycharm·cuda
fpcc10 天前
并行编程实战——CUDA编程的内存建议
c++·cuda