【FlashAttention】 FA2与FA1算法区别辨析

看了几篇关于FlashAttention2的文章,对于其中移除冗余的CUDA操作这个算法优化进行了一个综合梳理。

https://zhuanlan.zhihu.com/p/1993815603383902344

https://zhuanlan.zhihu.com/p/668888063

https://zhuanlan.zhihu.com/p/665170554

注意,第10行在部分文章中错写成了diag的逆,应该根据这篇文章的伪代码为准(推测是之前存在笔误,改了之后又重新上传了)。

这里FlashAttention2与FlashAttention1看起来有很大差别,推导如下;

  1. 首先比较重要的一点是,在FA2里,关于m, P的计算都没有mijm_{ij}mij, pijp_{ij}pij的概念,而是直接计算mim_imi和minewm_i^{new}minew,pip_ipi和pinewp_i^{new}pinew。因此此处的mijm_i^jmij就是FA1中的mijm_{ij}mij - minewm_i^{new}minew。另外此处的P也就是FA1中的emij−minew∗Pe^{m_{ij} - m_i^{new}} * Pemij−minew∗P。
  2. 另外第二个点,就是在中间的迭代中不计算L,只在最后一个迭代计算。
相关推荐
ouliten10 小时前
cuda编程笔记(36)-- 应用Tensor Core加速矩阵乘法
笔记·cuda
人工智能训练6 天前
【极速部署】Ubuntu24.04+CUDA13.0 玩转 VLLM 0.15.0:预编译 Wheel 包 GPU 版安装全攻略
运维·前端·人工智能·python·ai编程·cuda·vllm
X-Vision6 天前
Visual Studio 2022中配置cuda环境
visual studio·cuda
安全二次方security²7 天前
CUDA C++编程指南(7.31&32&33&34)——C++语言扩展之性能分析计数器函数和断言、陷阱、断点函数
c++·人工智能·nvidia·cuda·断点·断言·性能分析计数器函数
安全二次方security²8 天前
CUDA C++编程指南(7.25)——C++语言扩展之DPX
c++·人工智能·nvidia·cuda·dpx·cuda c++编程指南
不教书的塞涅卡10 天前
SSH远程接入PyTorch-CUDA-v2.9镜像,随时随地训练大模型
pytorch·ssh·cuda
安全二次方security²13 天前
CUDA C++编程指南(7.19&20)——C++语言扩展之Warp投票函数和Warp匹配函数
c++·人工智能·nvidia·cuda·投票函数·匹配函数·vote
安全二次方security²14 天前
CUDA C++编程指南(7.15&16)——C++语言扩展之内存空间谓词和转化函数
c++·人工智能·nvidia·cuda·内存空间谓词函数·内存空间转化函数·address space
安全二次方security²14 天前
CUDA C++编程指南(7.5&6)——C++语言扩展之内存栅栏函数和同步函数
c++·人工智能·nvidia·cuda·内存栅栏函数·同步函数·syncthreads
安全二次方security²15 天前
CUDA C++编程指南(7.2)——C++语言扩展之变量内存空间指定符
c++·人工智能·nvidia·cuda·内存空间指定符·__shared__·__device__