10.一文学会GPU与cuda原理,并从其原理来理解FlashAttention

CPU与cuda原理,并从其原理来理解flashattention

GPU 和 CUDA结构

cuda的每个block对应gpu中的每个sm(stream multiprocessor ), cuda的一个block可以分为多个warp,一个warp内有32个线程(在不同的数据上执行相同指令)

GPU MEMORY结构

让GPU运行更快的几种方法

技巧0 .条件语句在一个warp中有害,尽量在一个warp中别使用(跟内存无关)
技巧1. 低精度计算(low precision computation)
技巧2.内核融合(kernel fused)

例子 计算 sinx平方 + cosx 平方

没算子融合之前

算子融合

技巧3.重计算(recomputation)(计算换memory)
技巧4. memory coalescing(内存合并) ,突发模式(burst mode) 这是硬件层面的优化

当读内存中的一个值的时候,你读到的不仅只是你要的那个值,实际上读到的是跟这个值相连的一个burst section内存。gpu觉得后面可能也需要这些值。

一个warp(32线程)可以一次获取4个字节(一个burst section)而不是一个字节一个字节的读,吞吐量提升4倍。

解释:一个 Warp 的 32 个线程,在同一时刻访问的地址是否连续。---这里列是连续的

在 row-major 矩阵中,真正按列访问(m00, m10, m20)是非合并的

所谓"列是 coalesced",指的是 warp 的线程索引对应列下标,使同一时刻访问的是一行中连续的元素

在 CUDA 里,"连续不连续"不是看一个线程怎么走,而是看 一个 warp 在同一拍读的地址是不是连续

在 row-major 矩阵中,让 warp 的线程对应列索引,才能触发 coalesced access。

如果warp中线程索引按行进行,每个线程读的都是不同burst section 内存不连续

如果warp中线程索引按列进行,每个线程读的都是同一个burst section中的不同元素,内存连续

技巧5 分块( Tiling )

为了最小化内存占用,将事物组合到一起。

问题如下:

我们能否避免过多的全局内存读写?

Tiling解决方案:从全局内存 加载一部分数据(尽可能多)到 共享内存中(加载到时候可以按列索引 进行内存连续),共享内存进行运算。

tile的复杂之一 设置tile_size

设置tile size应该考虑几个因素如下

1 连续内存访问 memory coalescing

2 不得超过共享内存大小

3 矩阵的大小 / tile_size 尽可能能整除 (不会陷入到sm未充分利用的情况)

tile的复杂之二 memory alignment(内存对齐)

问题:加上在algned layout的每个行多加一个元素,成了unaliged layout,继续回导致内存访问量double。

解决:对此进行padding 避免内存访问double

总结-分析曲线图

当k=1,k=2时 读tile 没有很好的办法用 burst section对齐,故内存访问占据更多时间,会造成计算量跟不上读的速度。
为什么矩阵维度从1792 升到1793 ,计算性能下降很多?

不能整除tile的行列数,且tile个数超过了a100的sms数

从各个部分去分析,应该怎么做

根据所学理解flashattention

一句话:用了分块(tiling)和recomputation,用算子融合(fusion)进行了在线softmax计算。

标准attention计算流程如下:

flashattention tiliing流程如下:

flashattention softmax流程如下:

flashattention 代码演示

py 复制代码
import numpy as np

def flash_attention_single_query_debug(Q, K, V):
    m = -np.inf   # running max:当前看到的最大 Q·K
    d = 0.0       # running sum:∑ exp(score - m)
    o = np.zeros(V.shape[1])  # running output numerator:∑ exp(score - m) * V

    print("=" * 100)
    print(f"Query Q = {Q}")
    print("解释:我们要计算 Softmax(QK^T) @ V")
    print("       即 O = (∑ exp(Q·K_j) * V_j) / (∑ exp(Q·K_j))")
    print("=" * 100)

    for j in range(K.shape[0]):
        score = np.dot(Q, K[j])  # Q · K_j

        # 新的最大值(数值稳定关键)
        m_new = max(m, score)

        # 如果 max 变大,旧结果要整体缩放
        scale = np.exp(m - m_new)

        # 当前项在新 max 下的指数值
        exp_score = np.exp(score - m_new)

        # 更新 softmax 分母
        d_new = d * scale + exp_score

        # 更新 softmax 分子(带 V)
        o_new = o * scale + exp_score * V[j]

        print(f"[Step {j}] 处理第 {j} 个 Key-Value")
        print(f"K[{j}] = {K[j]}")
        print(f"V[{j}] = {V[j]}")
        print(f"score = Q · K[{j}] = {score:.6f}")

        print(f"\n【数值稳定处理】")
        print(f"之前最大值 m_prev = {m:.6f}")
        print(f"当前 score        = {score:.6f}")
        print(f"新的最大值 m_new = max(m_prev, score) = {m_new:.6f}")

        print(f"\n【指数修正】")
        print(f"scale = exp(m_prev - m_new)")
        print(f"      = exp({m:.6f} - {m_new:.6f})")
        print(f"      = {scale:.6e}")

        print(f"exp(score - m_new)")
        print(f"= exp({score:.6f} - {m_new:.6f})")
        print(f"= {exp_score:.6e}")

        print(f"\n【分母更新(Softmax sum)】")
        print(f"d_new = d_prev * scale + exp(score - m_new)")
        print(f"      = {d:.6e} * {scale:.6e} + {exp_score:.6e}")
        print(f"      = {d_new:.6e}")

        print(f"\n【分子更新(Softmax * V)】")
        print(f"o_new = o_prev * scale + exp(score - m_new) * V[{j}]")
        print(f"      = {o} * {scale:.6e} + {exp_score:.6e} * {V[j]}")
        print(f"      = {o_new}")

        print("-" * 100)

        m, d, o = m_new, d_new, o_new

    O = o / d
    print("\n✅ Final Output")
    print("O = o / d")
    print(f"  = {o} / {d:.6e}")
    print(f"  = {O}")
    print("=" * 100)
    return O


# 测试用例1
Q = np.array([1.0])
K = np.array([[5.0], [1.0], [0.0]])
V = np.array([[10.0], [1.0], [1.0]])
flash_attention_single_query_debug(Q, K, V)
相关推荐
YongGit11 分钟前
OpenClaw 本地 AI 助手完全指南:飞书接入 + 远程部署实战
人工智能
程序员鱼皮2 小时前
斯坦福大学竟然开了个 AI 编程课?!我已经学上了
人工智能·ai编程
星浩AI2 小时前
Skill 的核心要素与渐进式加载架构——如何设计一个生产可用的 Skill?
人工智能·agent
树獭非懒2 小时前
告别繁琐多端开发:DivKit 带你玩转 Server-Driven UI!
android·前端·人工智能
阿尔的代码屋3 小时前
[大模型实战 07] 基于 LlamaIndex ReAct 框架手搓全自动博客监控 Agent
人工智能·python
小小小怪兽3 小时前
🔨聊一聊Skills
人工智能·agent
穿过生命散发芬芳3 小时前
OpenClaw:开启OpenCloudOS 操作系统智能运维初体验
人工智能·aigc
老金带你玩AI3 小时前
Claude Code自动记忆来了!配合老金三层记忆系统全开源!加强Plus!
人工智能
Halo咯咯3 小时前
无限免费 OpenClaw:接入本地模型后,你的 AI Agent 就可以 24 小时自动干活(Mac Mini 可用)
人工智能
NAGNIP15 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试