10.一文学会GPU与cuda原理,并从其原理来理解FlashAttention

CPU与cuda原理,并从其原理来理解flashattention

GPU 和 CUDA结构

cuda的每个block对应gpu中的每个sm(stream multiprocessor ), cuda的一个block可以分为多个warp,一个warp内有32个线程(在不同的数据上执行相同指令)

GPU MEMORY结构

让GPU运行更快的几种方法

技巧0 .条件语句在一个warp中有害,尽量在一个warp中别使用(跟内存无关)
技巧1. 低精度计算(low precision computation)
技巧2.内核融合(kernel fused)

例子 计算 sinx平方 + cosx 平方

没算子融合之前

算子融合

技巧3.重计算(recomputation)(计算换memory)
技巧4. memory coalescing(内存合并) ,突发模式(burst mode) 这是硬件层面的优化

当读内存中的一个值的时候,你读到的不仅只是你要的那个值,实际上读到的是跟这个值相连的一个burst section内存。gpu觉得后面可能也需要这些值。

一个warp(32线程)可以一次获取4个字节(一个burst section)而不是一个字节一个字节的读,吞吐量提升4倍。

解释:一个 Warp 的 32 个线程,在同一时刻访问的地址是否连续。---这里列是连续的

在 row-major 矩阵中,真正按列访问(m00, m10, m20)是非合并的

所谓"列是 coalesced",指的是 warp 的线程索引对应列下标,使同一时刻访问的是一行中连续的元素

在 CUDA 里,"连续不连续"不是看一个线程怎么走,而是看 一个 warp 在同一拍读的地址是不是连续

在 row-major 矩阵中,让 warp 的线程对应列索引,才能触发 coalesced access。

如果warp中线程索引按行进行,每个线程读的都是不同burst section 内存不连续

如果warp中线程索引按列进行,每个线程读的都是同一个burst section中的不同元素,内存连续

技巧5 分块( Tiling )

为了最小化内存占用,将事物组合到一起。

问题如下:

我们能否避免过多的全局内存读写?

Tiling解决方案:从全局内存 加载一部分数据(尽可能多)到 共享内存中(加载到时候可以按列索引 进行内存连续),共享内存进行运算。

tile的复杂之一 设置tile_size

设置tile size应该考虑几个因素如下

1 连续内存访问 memory coalescing

2 不得超过共享内存大小

3 矩阵的大小 / tile_size 尽可能能整除 (不会陷入到sm未充分利用的情况)

tile的复杂之二 memory alignment(内存对齐)

问题:加上在algned layout的每个行多加一个元素,成了unaliged layout,继续回导致内存访问量double。

解决:对此进行padding 避免内存访问double

总结-分析曲线图

当k=1,k=2时 读tile 没有很好的办法用 burst section对齐,故内存访问占据更多时间,会造成计算量跟不上读的速度。
为什么矩阵维度从1792 升到1793 ,计算性能下降很多?

不能整除tile的行列数,且tile个数超过了a100的sms数

从各个部分去分析,应该怎么做

根据所学理解flashattention

一句话:用了分块(tiling)和recomputation,用算子融合(fusion)进行了在线softmax计算。

标准attention计算流程如下:

flashattention tiliing流程如下:

flashattention softmax流程如下:

flashattention 代码演示

py 复制代码
import numpy as np

def flash_attention_single_query_debug(Q, K, V):
    m = -np.inf   # running max:当前看到的最大 Q·K
    d = 0.0       # running sum:∑ exp(score - m)
    o = np.zeros(V.shape[1])  # running output numerator:∑ exp(score - m) * V

    print("=" * 100)
    print(f"Query Q = {Q}")
    print("解释:我们要计算 Softmax(QK^T) @ V")
    print("       即 O = (∑ exp(Q·K_j) * V_j) / (∑ exp(Q·K_j))")
    print("=" * 100)

    for j in range(K.shape[0]):
        score = np.dot(Q, K[j])  # Q · K_j

        # 新的最大值(数值稳定关键)
        m_new = max(m, score)

        # 如果 max 变大,旧结果要整体缩放
        scale = np.exp(m - m_new)

        # 当前项在新 max 下的指数值
        exp_score = np.exp(score - m_new)

        # 更新 softmax 分母
        d_new = d * scale + exp_score

        # 更新 softmax 分子(带 V)
        o_new = o * scale + exp_score * V[j]

        print(f"[Step {j}] 处理第 {j} 个 Key-Value")
        print(f"K[{j}] = {K[j]}")
        print(f"V[{j}] = {V[j]}")
        print(f"score = Q · K[{j}] = {score:.6f}")

        print(f"\n【数值稳定处理】")
        print(f"之前最大值 m_prev = {m:.6f}")
        print(f"当前 score        = {score:.6f}")
        print(f"新的最大值 m_new = max(m_prev, score) = {m_new:.6f}")

        print(f"\n【指数修正】")
        print(f"scale = exp(m_prev - m_new)")
        print(f"      = exp({m:.6f} - {m_new:.6f})")
        print(f"      = {scale:.6e}")

        print(f"exp(score - m_new)")
        print(f"= exp({score:.6f} - {m_new:.6f})")
        print(f"= {exp_score:.6e}")

        print(f"\n【分母更新(Softmax sum)】")
        print(f"d_new = d_prev * scale + exp(score - m_new)")
        print(f"      = {d:.6e} * {scale:.6e} + {exp_score:.6e}")
        print(f"      = {d_new:.6e}")

        print(f"\n【分子更新(Softmax * V)】")
        print(f"o_new = o_prev * scale + exp(score - m_new) * V[{j}]")
        print(f"      = {o} * {scale:.6e} + {exp_score:.6e} * {V[j]}")
        print(f"      = {o_new}")

        print("-" * 100)

        m, d, o = m_new, d_new, o_new

    O = o / d
    print("\n✅ Final Output")
    print("O = o / d")
    print(f"  = {o} / {d:.6e}")
    print(f"  = {O}")
    print("=" * 100)
    return O


# 测试用例1
Q = np.array([1.0])
K = np.array([[5.0], [1.0], [0.0]])
V = np.array([[10.0], [1.0], [1.0]])
flash_attention_single_query_debug(Q, K, V)
相关推荐
RockHopper20252 小时前
工业AMR场景融合设计原理5——约束体系的价值
人工智能·系统架构·智能制造·具身智能·amr·工业amr
AI工具测评大师2 小时前
怎么有效降低英文文本的GPTZero AI检测率?3步有效降低AI率方法与工
人工智能·深度学习·自然语言处理·ai写作·ai自动写作
轻微的风格艾丝凡2 小时前
圆周率(π)2-10进制转换及随机性量化分析技术文档
人工智能·算法
测试专家2 小时前
反射内存卡在航空电子中的应用
网络·人工智能
GAOJ_K2 小时前
弧形导轨在安装时的关键方式
人工智能·科技·机器人·自动化·制造
醒雷工程师2 小时前
AI人工智能发展方向和对能源过度依赖的解决设想
人工智能·能源
yumgpkpm2 小时前
AI校服识别算法的成本+规划
人工智能·算法
d0ublεU0x002 小时前
task03深入大模型架构
人工智能
linmoo19862 小时前
Langchain4j 系列之二十七 - Ollama集成Deepseek
人工智能·langchain·ollama·deepseek·langchain4j