看了几篇关于FlashAttention2的文章,对于其中移除冗余的CUDA操作这个算法优化进行了一个综合梳理。
https://zhuanlan.zhihu.com/p/1993815603383902344
https://zhuanlan.zhihu.com/p/668888063
https://zhuanlan.zhihu.com/p/665170554


注意,第10行在部分文章中错写成了diag的逆,应该根据这篇文章的伪代码为准(推测是之前存在笔误,改了之后又重新上传了)。
这里FlashAttention2与FlashAttention1看起来有很大差别,推导如下;
- 首先比较重要的一点是,在FA2里,关于m, P的计算都没有mijm_{ij}mij, pijp_{ij}pij的概念,而是直接计算mim_imi和minewm_i^{new}minew,pip_ipi和pinewp_i^{new}pinew。因此此处的mijm_i^jmij就是FA1中的mijm_{ij}mij - minewm_i^{new}minew。另外此处的P也就是FA1中的emij−minew∗Pe^{m_{ij} - m_i^{new}} * Pemij−minew∗P。
- 另外第二个点,就是在中间的迭代中不计算L,只在最后一个迭代计算。