【FlashAttention】 FA2与FA1算法区别辨析

看了几篇关于FlashAttention2的文章,对于其中移除冗余的CUDA操作这个算法优化进行了一个综合梳理。

https://zhuanlan.zhihu.com/p/1993815603383902344

https://zhuanlan.zhihu.com/p/668888063

https://zhuanlan.zhihu.com/p/665170554

注意,第10行在部分文章中错写成了diag的逆,应该根据这篇文章的伪代码为准(推测是之前存在笔误,改了之后又重新上传了)。

这里FlashAttention2与FlashAttention1看起来有很大差别,推导如下;

  1. 首先比较重要的一点是,在FA2里,关于m, P的计算都没有mijm_{ij}mij, pijp_{ij}pij的概念,而是直接计算mim_imi和minewm_i^{new}minew,pip_ipi和pinewp_i^{new}pinew。因此此处的mijm_i^jmij就是FA1中的mijm_{ij}mij - minewm_i^{new}minew。另外此处的P也就是FA1中的emij−minew∗Pe^{m_{ij} - m_i^{new}} * Pemij−minew∗P。
  2. 另外第二个点,就是在中间的迭代中不计算L,只在最后一个迭代计算。
相关推荐
basketball6165 天前
AI Infra 硬件体系与编程模型:17. CUDA编程基础:底层驱动 API 调用
人工智能·microsoft·nvidia·cuda
fpcc6 天前
并行编程实战——CUDA编程的pipelines
c++·cuda
basketball6168 天前
AI Infra 硬件体系与编程模型:14. CUDA编程基础:事件与精确性能测量
人工智能·nvidia·cuda
kyle~8 天前
推理部署---CUDA 执行模型(SM、Block、Warp 与 SIMT)
人工智能·nvidia·cuda
June`8 天前
如何组织一个并行程序
开发语言·cuda
basketball6168 天前
AI Infra 硬件体系与编程模型:15. CUDA编程基础:混合精度计算
人工智能·nvidia·cuda
June`8 天前
CUDA执行模型深入刨析
c++·人工智能·cuda
June`9 天前
CUDA程序效率如何计算以及工具如何使用
算法·cuda
插件开发9 天前
CUDA11-VS2015安装-工具链测试-Helloworld程序
c++·gpu·cuda
虎妞050010 天前
PyTorch 2.0 生产级部署与性能优化指南
pytorch·深度学习·ai·模型部署·cuda