【llm对话系统】大模型 Llama 源码分析之 Flash Attention

1. 写在前面

近年来,基于 Transformer 架构的大型语言模型 (LLM) 在自然语言处理 (NLP) 领域取得了巨大的成功。Transformer 的核心组件是自注意力 (Self-Attention) 机制,它允许模型捕捉输入序列中不同位置之间的关系。然而,标准的自注意力机制的计算复杂度与序列长度的平方成正比,这使得它在处理长序列时效率低下。

为了解决这个问题,Flash Attention 被提出,它是一种高效的注意力算法,通过利用现代 GPU 的特性,显著降低了计算复杂度和内存占用。本文将深入 Llama 源码,分析 Flash Attention 的实现逻辑,并与标准的自注意力机制进行比较。

2. Self-Attention 回顾

Self-Attention 的核心思想是:对于输入序列中的每个 token,都计算它与其他所有 token 之间的相关性,并根据这些相关性对所有 token 的表示进行加权求和,得到该 token 的新的表示。

标准的 Self-Attention 计算过程如下:

  1. 线性变换 : 将输入序列的每个 token 的 embedding 向量 x_i 通过三个线性变换矩阵 W_q, W_k, W_v 映射成三个向量:q_i (query), k_i (key), v_i (value)。

    python 复制代码
    # 假设 embedding_dim = 512, seq_len = 1024
    import torch
    x = torch.randn(1, 1024, 512)  # batch_size=1, seq_len=1024, embedding_dim=512
    W_q = torch.randn(512, 512)
    W_k = torch.randn(512, 512)
    W_v = torch.randn(512, 512)
    
    q = x @ W_q  # (1, 1024, 512)
    k = x @ W_k  # (1, 1024, 512)
    v = x @ W_v  # (1, 1024, 512)
  2. 计算注意力分数 : 计算每个 query q_i 与所有 key k_j 之间的点积,得到注意力分数 s_ij

    python 复制代码
    s = q @ k.transpose(-2, -1)  # (1, 1024, 1024)
  3. 缩放和掩码 : 对注意力分数进行缩放 (除以 sqrt(d_k), d_k 是 key 向量的维度),并应用掩码 (mask) 操作 (例如,在解码器中屏蔽未来 token)。

    python 复制代码
    import math
    d_k = k.shape[-1]
    s = s / math.sqrt(d_k)
    # 假设我们不需要 mask
  4. Softmax : 对缩放后的注意力分数应用 softmax 函数,得到注意力权重 a_ij

    python 复制代码
    a = torch.softmax(s, dim=-1)  # (1, 1024, 1024)
  5. 加权求和 : 使用注意力权重 a_ij 对所有 value 向量 v_j 进行加权求和,得到每个 token 的新的表示 y_i

    python 复制代码
    y = a @ v  # (1, 1024, 512)

问题 : 上述计算过程中,sa 这两个矩阵的大小都是 (seq_len, seq_len),当 seq_len 很大时 (例如 4096),这两个矩阵会占用大量的显存,并且计算 softmax 和矩阵乘法也非常耗时。

3. Flash Attention 原理

Flash Attention 的核心思想是:避免将整个注意力矩阵 sa 存储在 GPU 的高速缓存 (HBM) 中,而是将输入数据分块 (tiling),每次只加载一小部分数据到 SRAM 中进行计算,并将结果写回 HBM。

Flash Attention 主要利用了以下两个技术:

3.1 Tiling (分块)

将 Q, K, V 矩阵分成多个 block,每次只计算一个 block 的注意力。例如,可以将一个 (1024, 512) 的矩阵分成 16 个 (256, 512) 的 block。

3.2 Recomputation (重计算)

在反向传播时,不存储中间的注意力权重 a,而是在需要的时候重新计算。由于计算 a 的开销相对较小,这种方法可以节省大量的显存。

4. Llama 中 Flash Attention 的实现

Llama 使用了 Flash Attention 的改进版本,即 Paged Attention 。其核心思想与 Flash Attention 相同,但在处理长序列时更加高效。这里以llama2源码为例说明,其位于llama/model.py文件中,class Attention(nn.Module) 类下的forward函数中

以下是 Llama 源码中 Flash Attention 的简化版实现 (已去除部分细节):

python 复制代码
def flash_attention(q, k, v, block_size):
    """
    简化版的 Flash Attention 实现.

    Args:
        q: Query 矩阵 (B, H, N, D_head)
        k: Key 矩阵 (B, H, N, D_head)
        v: Value 矩阵 (B, H, N, D_head)
        block_size: 分块大小

    Returns:
        输出矩阵 (B, H, N, D_head)
    """
    B, H, N, D_head = q.shape
    O = torch.zeros_like(q)

    for i in range(0, N, block_size):
      # 加载当前 block 的数据到 SRAM
      qi = q[:, :, i:i + block_size, :]
      
      mi = -float('inf')  # 用于记录当前 block 的最大值
      li = 0.0  # 用于记录当前 block 的 softmax 的分母
      
      for j in range(0, N, block_size):
          # 加载当前 block 的数据到 SRAM
          kj = k[:, :, j:j + block_size, :]
          vj = v[:, :, j:j + block_size, :]

          # 计算注意力分数
          sij = torch.einsum('bhnd,bhmd->bhnm', qi, kj) / math.sqrt(D_head)

          # 更新最大值和 softmax 的分母
          mij = torch.max(sij, dim=-1).values
          li_new = torch.exp(mi - mij).unsqueeze(-1) * li + torch.sum(torch.exp(sij - mij.unsqueeze(-1)), dim=-1)
          
          # 更新输出
          
          O[:, :, i:i + block_size, :] = (li / li_new).unsqueeze(-1) * O[:, :, i:i + block_size, :] + \
                                        torch.einsum('bhnm,bhmd->bhnd', torch.exp(sij - mij.unsqueeze(-1)), vj) / li_new.unsqueeze(-1)

          mi = torch.max(mi, mij)
          li = li_new

    return O

# 示例:假设 block_size = 256
q = torch.randn(1, 8, 1024, 64)  # batch_size=1, heads=8, seq_len=1024, d_head=64
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)
o = flash_attention(q, k, v, block_size=256)
print(o.shape) # torch.Size([1, 8, 1024, 64])

代码解释:

  1. q, k, v 分别表示 query, key, value 矩阵, block_size 表示分块大小。
  2. O 是输出矩阵,初始化为全零。
  3. 外层循环遍历 Q 矩阵的 block。
  4. 内层循环遍历 K, V 矩阵的 block。
  5. sij 计算当前 block 的注意力分数。
  6. mili 分别用于记录当前 block 的最大值和 softmax 的分母,以保证数值稳定性。
  7. O 使用增量更新的方式计算最终的输出结果。

注意: 上述代码只是 Flash Attention 的简化版实现,实际的 Llama 源码中还包括了 mask, dropout, causal mask 等操作,并且使用了更高效的 CUDA kernel 来加速计算。

简化版实现说明

上面的代码实现了一个简化版本的Flash Attention算法。它通过两个嵌套的循环来处理查询(Q)、键(K)和值(V)矩阵,这些矩阵被分成了多个块(block)。这种分块处理的方式旨在减少计算过程中的内存占用,特别是对于那些拥有大量头的注意力机制(如多头注意力机制)来说,可以显著提高计算效率。下面我们来逐步解释这段代码的核心逻辑:

初始化输出张量 O

  • O 被初始化为与查询张量 q 相同形状的全零张量。这个张量将累积每个块计算的结果,最终形成完整的输出。

外循环:遍历Q的块

  • 代码通过 for i in range(0, N, block_size): 循环遍历 Q 矩阵的块。变量 i 表示当前处理的块在序列维度上的起始位置。

内循环:遍历K和V的块

  • 对于Q中的每个块,代码通过 for j in range(0, N, block_size): 循环遍历 K 和 V 矩阵的块。变量 j 表示 K 和 V 当前处理的块在序列维度上的起始位置。

注意力分数的计算

  • sij = torch.einsum('bhnd,bhmd->bhnm', qi, kj) / math.sqrt(D_head) 计算当前 Q 块和 K 块之间的注意力分数。这里使用了爱因斯坦求和标记法(einsum),这是一种简洁表示张量操作的方式。

更新最大值和softmax的分母

  • mij = torch.max(sij, dim=-1).values 计算 sij 在最后一个维度上的最大值。
  • li_new = torch.exp(mi - mij).unsqueeze(-1) * li + torch.sum(torch.exp(sij - mij.unsqueeze(-1)), dim=-1) 更新softmax的分母。这里使用了数值稳定的技巧,通过减去最大值来避免指数运算产生过大的数值。

更新输出

  • O[:, :, i:i + block_size, :] = ... 这行代码是整个算法中最关键的部分。它根据当前计算的注意力分数和值(V)来更新输出张量 O 的相应块。这里通过加权求和的方式,将之前步骤的结果累加到 O 上。

更新 mili

  • mi = torch.max(mi, mij) 更新到目前为止遇到的最大值。
  • li = li_new 更新softmax的分母。

总结

这段代码实现了一种高效的注意力机制,通过分块处理和数值稳定的softmax计算,减少了内存占用并提高了计算效率。尽管代码进行了一定的简化以突出核心逻辑,但它捕捉了Flash Attention算法的关键思想。在实际应用中,还需要考虑如何高效地在硬件上实现这些操作,以及如何处理边界情况和性能优化。

5. Flash Attention 与标准 Self-Attention 的比较

特性 标准 Self-Attention Flash Attention
计算复杂度 O(N^2) O(N) (理论上, 实际取决于分块大小)
内存占用 O(N^2) O(N) (理论上, 实际取决于分块大小)
速度
适用场景 短序列 长序列
实现复杂性 简单 复杂
相关推荐
知识鱼丸1 小时前
自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测
人工智能
憨猪在度假2 小时前
Rk3588芯片介绍(含数据手册)
人工智能
西猫雷婶3 小时前
python学opencv|读取图像(五十二)使用cv.matchTemplate()函数实现最佳图像匹配
人工智能·python·opencv·计算机视觉
2301_793069823 小时前
OpenCV 图像旋转
人工智能·opencv·计算机视觉
纠结哥_Shrek3 小时前
基于最近邻数据进行分类
人工智能·分类·数据挖掘
kakaZhui4 小时前
【llm对话系统】大模型 Llama 源码分析之并行训练方案
人工智能·chatgpt·aigc·llama
Melancholy 啊4 小时前
细说机器学习算法之ROC曲线用于模型评估
人工智能·python·算法·机器学习·数据挖掘
爱研究的小牛5 小时前
Deepseek技术浅析(二):大语言模型
人工智能·机器学习·语言模型·自然语言处理·aigc
编程武士5 小时前
OpenCV 版本不兼容导致的问题
人工智能·opencv·计算机视觉