飞桨首创 FlashMask :加速大模型灵活注意力掩码计算,长序列训练的利器

在 Transformer 类大模型训练任务中,注意力掩码(Attention Mask)一方面带来了大量的冗余计算,另一方面因其 O ( N 2 ) O(N^2) O(N2)巨大的存储占用导致难以实现长序列场景的高效训练(其中 N N N为序列长度)。虽然业界已有 FlashAttention 等针对特定注意力掩码的计算加速方法,但其支持的注意力掩码模式有限,难以满足大模型训练任务对灵活注意力掩码的需求。为了解决上述问题,飞桨独创 FlashMask 技术,提出了列式稀疏的注意力掩码表示方法,支持灵活多样的注意力掩码模式,使得存储复杂度从 O ( N 2 ) O(N^2) O(N2)降低至 O ( N ) O(N) O(N),并在此基础上实现了高效的算子 Kernel,极致加速大模型训练效率,尤其是长序列场景下的训练效率。

我们在NVIDIA A100 (80G) GPU上对 FlashMask 在大语言模型微调和对齐训练中的表现进行了评估,包括 SFT、LoRA、DPO 和 RM。与现有的 FlashAttention 稠密掩码方法相比,FlashMask 在端到端训练速度上实现了显著提升,速度提高幅度在 1.65 倍到 3.22 倍之间。此外,我们还评估了其 Kernel 层次上的性能。FlashMask 在理论最大浮点运算次数上达到了37.8%到62.3%,在 Kernel 每秒浮点运算次数(TFLOPs/s)方面,其性能超过 FlexAttention,提升幅度为 12.1% 到 60.7%。

arXiv 论文: https://arxiv.org/abs/2410.01359

PaddlePaddle 官方文档: https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/nn/functional/flashmask_attention_en.html

PaddleNLP 开源集成:https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/docs/flashmask.md

星河社区快速体验:https://aistudio.baidu.com/projectdetail/8459413

一、大语言模型的挑战

随着人工智能技术的迅猛发展,以 Transformer 为代表的大模型在自然语言处理、计算机视觉和多模态应用中展现出了非凡的能力。在这些大模型中,注意力(Attention)机制是一个关键环节。为了在大模型训练任务中确定哪些 Query-Key token 之间需要进行有效的 Attention 计算,业界通常使用注意力掩码(Attention Mask)。然而,目前的注意力掩码通常采用二维稠密矩阵表示,这导致了一些问题。一方面,这种表示方法引入了大量冗余计算,因为许多无效 token 的 Attention 仍需计算;另一方面,这种掩码的空间复杂度为 O ( N 2 ) O(N^2) O(N2)(其中 N N N为序列长度),在长序列的训练场景中会造成巨大的存储压力,难以进行高效训练。为了解决这些问题,业界已经提出了一些方案,如 Memory Efficient Attention (MEA) [1] 和 FlashAttention [2]。然而,这些方案支持的注意力掩码类型较为有限。正如图 1 所示,FlashAttention 只能支持如纯因果掩码(Causal)、滑动窗口掩码(Sliding Window)、因果文档掩码(Causal Document Mask)和文档掩码(Document Mask)等几种固定形式的掩码。然而,实际训练任务中使用的注意力掩码形式往往丰富多变,当前技术难以满足大模型不同训练任务对注意力掩码灵活性的要求。

图1 常见的注意力掩码类型

二、FlashMask 的创新:列式稀疏掩码表示方法与高效计算

1. 关键洞察

FlashMask 的核心发现是,在大模型常见的注意力掩码模式中,Query-Key token 的掩码模式具有一定的连续性。具体而言,对于每一个 Key token,无效注意力计算的 Query token 是相邻排列的。也就是说,在图 1 中二维掩码矩阵中,Query token 作用在每一列的 Key token 的灰色部分沿列方向连续分布。基于这一洞察,FlashMask 巧妙地将二维稠密掩码矩阵转换为一维的行索引区间,从而实现更为紧凑的表示形式,并显著降低了存储需求。我们可以公式化表示为:

M j = [ s t a r t j , e n d j ) , ∀ j ∈ { 1 , ... , N } M_{j} = [start_j, end_j), \quad \forall j \in \{1, \ldots, N\} Mj=[startj,endj),∀j∈{1,...,N}

其中 N N N为 Key 的序列长度, M j M_j Mj为二维的稠密掩码矩阵的第 j j j列, [ s t a r t j , e n d j ) [start_j, end_j) [startj,endj)为连续的行索引区间,表示 s t a r t j start_j startj到 e n d j − 1 end_{j} - 1 endj−1的连续 Query token 是被 mask 掉,置为无效 Attention 计算。

2. 注意力掩码的列式稀疏掩码表示方法

为了高效处理因果和双向注意力场景中的复杂掩码模式,FlashMask 提出了一种新颖的列式稀疏表示方法。以对角线为区分,它使用四个一维向量来表示掩码:

  • 下三角起始行索引(Lower Triangular Start,简称 L T S LTS LTS)
  • 下三角结束行索引(Lower Triangular End,简称 L T E LTE LTE)
  • 上三角起始行索引(Upper Triangular Start,简称 U T S UTS UTS)
  • 上三角结束行索引(Upper Triangular End,简称 U T E UTE UTE)
    其中下三角被 mask 掉的行索引区间使用 [ 𝐿 𝑇 𝑆 , 𝐿 𝑇 𝐸 ) [𝐿𝑇𝑆, 𝐿𝑇𝐸) [LTS,LTE)表示,上三角被 mask 掉的行索引区间使用 [ 𝑈 𝑇 𝑆 , 𝑈 𝑇 𝐸 ) [𝑈𝑇𝑆, 𝑈𝑇𝐸) [UTS,UTE) 表示。 如图2所示,我们展示了16个Query token 和16个Key token 做 Attention 计算时较为复杂的二维稠密因果注意力的掩码矩阵,灰色单元格是 mask 区域。

    图2 较为复杂的二维稠密因果注意力的掩码矩阵示意图

可以通过 [ 𝐿 𝑇 𝑆 , 𝐿 𝑇 𝐸 ) [𝐿𝑇𝑆, 𝐿𝑇𝐸) [LTS,LTE)两个向量进行表达,如下所示:

以第 0 列为例,开始 mask 的行为 13,结束 mask 的行为 15(开区间),表示位置为 13 和 14 的 Query token 不与位置为 0 的 Key token 做有效 Attention 计算。

更多的例子参考图 3,FlashMask 使用列式稀疏掩码表示方法,表达了图 1 中所有的注意力掩码模式。其中 - 的空缺表示在不同的场景下有不同的默认值, L T S LTS LTS和 U T S UTS UTS中的默认值是 0,表示 mask 区域默认从第 0 行开始, L T E LTE LTE和 U T E UTE UTE中的默认值是 Query 的序列长度,表示 mask 区域默认结束于最后一行。

图3 使用 FlashMask 的列式稀疏掩码表示方法表示图1的注意力掩码模式

3. 扩展 FlashAttention 支持复杂掩码

FlashMask 将列式掩码表示方法集成到 FlashAttention-2 算法中,增强了其对注意力掩码的支持能力。在 FlashAttention Kernel 的分块计算基础上,FlashMask 利用上述的 L T S LTS LTS、 L T E LTE LTE、 U T S UTS UTS、 U T E UTE UTE 等掩码向量,来判断当前分块的掩码类型:

  • 完全掩码块:此类块的所有元素均被掩码,计算时可直接跳过。
  • 部分掩码块:此类块仅部分元素被掩码,因此需要对该块进行逐元素的掩码处理。
  • 未掩码块:此类块中的所有元素均未被掩码,可以简化计算过程,无需额外的掩码操作。
    通过这种分类处理,FlashMask 显著提升了计算效率:完全掩码块的计算被直接跳过,未掩码块的计算得到简化,仅对部分掩码块执行必要的掩码操作,如图4所示。
    图4 FlashMask 计算过程示意图

算法1详细描述了 FlashMask 扩展 FlashAttention-2 的前向计算过程,其中浅蓝色阴影部分表示 FlashMask 新增的计算步骤 [3]。

算法1 FlashMask 的前向计算伪代码

4. 效率提升与精度保证

FlashMask 充分利用了注意力掩码中的稀疏性,通过跳过完全掩码块的计算,减少了计算开销,同时不改变算法的精度。与使用稠密掩码矩阵的注意力计算保持比特级别的数值等效性,确保了精度无损。

三、FlashMask 的优势:速度与存储的双重提升

1. 端到端训练吞吐量提升

在 Llama-2 7B、13B、70B 等模型规模下,针对 SFT、LoRA、DPO、RM 四种下游训练场景和不同序列长度的实验表明,FlashMask 在各个模型规模和序列长度下均实现了端到端的加速和存储效率的提升。相比现有的基于稠密掩码矩阵的计算方法,FlashMask 实现了 1.65 倍至 3.22 倍的吞吐量提升,并支持更长的序列长度。

图5 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,3 个 Llama2 模型规模,在不同序列长度下的端到端训练吞吐量

图6 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,3 个 Llama2 模型规模,不同序列长度下的端到端训练峰值显存消耗

表2 在 Llama2 7B 模型上 FlashMask 对比 FlashAttention (Causal=True) 的显存消耗,单位(GB)

2. 端到端训练收敛验证

在 Llama 3.1 模型上的实验验证了 FlashMask 对收敛精度没有影响。作为一种精确的算法,通过控制计算过程的随机性(如 FlashAttention 反向 Query 梯度计算使用 atomicAdd 操作),FlashMask 可以与使用稠密掩码的 FlashAttention 在比特级别精确对齐。

图7 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,Llama3.1 8B 模型端到端训练 Loss 对比

3. 稀疏度与 Kernel 计算时延的线性关系

FlashMask 利用注意力掩码的块稀疏性,跳过完全掩码块的计算,将计算复杂度降低到 O ( ( 1 − ρ ) T r T c ) O((1 - ρ)T_rT_c) O((1−ρ)TrTc),其中 ρ ρ ρ 表示块稀疏性。为了验证这一关系,FlashMask 进行了多组实验,测试了三种不同的掩码类型(因果文档掩码、共享问题掩码和文档掩码),并使用不同稀疏度的数据。实验结果(如图 5 所示)表明,Kernel 执行延迟与稀疏性之间呈线性关系,意味着随着稀疏性的增加,FlashMask 的计算速度进一步提升。

图8 不同块稀疏度下的 Kernel 计算时延

4. Kernel 性能对比

关注到近期 PyTorch 推出了 FlexAttention [4](使用编译器技术支持 Attention Mask),FlashMask 与之在 Kernel 级别进行了对比。在各种常见的注意力掩码模式下,FlashMask 展现了更高的计算效率。在 TFLOPs/s 指标上,FlashMask 比 FlexAttention 高出 12.1% 至 60.7%,在 A100 GPU 上实现了 37.8% 至 62.3% 的理论峰值计算性能。

图9 在 A100-SXM 80G GPU 上的 Kernel 前向和反向速度对比。FlexAttention 使用 PyTorch 2.6.0.dev20240920+cu124

四、FlashMask 的应用:赋能大语言模型

FlashMask 的创新和优势为 Transformer 类大模型的注意力机制训练加速开辟了新的可能,可广泛应用于各种任务,并支持超长序列高效训练。

1. 可广泛应用于大语言模型的下游训练加速

FlashMask 可以应用于大语言模型的下游任务训练,例如 SFT、LoRA、DPO、RM 等。特别是在 DPO 和 RM 的训练中,其数据由问题和回答对组成,训练时多个答案可以共享一个问题,从而大幅减少对问题 token 的冗余计算。

2. 支持单向/双向混合注意力掩码模式训练

FlashMask 支持多种注意力模式,包括因果掩码(单向注意力)和文档掩码(双向注意力),因此能够灵活地应用于需要混合注意力的场景。例如:

  • 全局 + 滑动窗口掩码:这种掩码结合了全局注意力和滑动窗口注意力,既能捕捉全局上下文信息,又能关注局部细节。FlashMask 能高效处理这种混合掩码,提升模型性能。
  • 前缀语言模型:在生成文本时,前缀部分需要关注所有的 token,而其他部分使用因果掩码(如 T5 模型的预训练)。FlashMask 可以同时支持这两种注意力模式,提高前缀语言模型的训练和推理效率。

3. 支持多模态图文数据的混合多分辨率训练

在多模态数据处理中,不同模态的数据可能具有不同的分辨率。虽然文中未明确提及 FlashMask 在多模态和多分辨率训练中的应用,但 FlashMask 可以通过不同的注意力模式和掩码策略,有效处理这些具有不同分辨率的数据。针对长序列处理能力的优化,使得FlashMask能够帮助模型更好地学习不同模态数据之间的关联。例如,在图文匹配任务中,FlashMask 可以帮助模型更有效地对齐图像和文本中的关键信息。

FlashMask 的开源代码已在 PaddlePaddle 和 PaddleNLP 平台发布,支持超过千亿参数的模型以及超过 128K tokens 的上下文长度。我们相信,FlashMask 将成为推动大语言模型发展的重要力量,为算法研究人员提供更广阔的注意力掩码创新与研究空间。

结语

飞桨首创 FlashMask 技术,通过创新的列式稀疏掩码表示方法和高效的Kernel实现方式,解决了传统注意力掩码方法冗余计算和存储占用过高等难题,助力大模型训练加速,尤其是基于长文场景的训练加速。未来,飞桨将持续研发高效的大模型训练加速技术,不断推进大模型技术更新迭代。

参考文献

[1] Self-attention Does Not Need O(n^2) Memory. https://arxiv.org/abs/2112.05682

[2] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691

[3] FlashMask: Efficient and Rich Mask Extension of FlashAttention. https://arxiv.org/abs/2410.01359

[4] FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention. https://pytorch.org/blog/flexattention/

相关推荐
古希腊掌管学习的神11 分钟前
[机器学习]XGBoost(3)——确定树的结构
人工智能·机器学习
ZHOU_WUYI39 分钟前
4.metagpt中的软件公司智能体 (ProjectManager 角色)
人工智能·metagpt
靴子学长1 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
AI_NEW_COME2 小时前
知识库管理系统可扩展性深度测评
人工智能
海棠AI实验室3 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
hunteritself3 小时前
AI Weekly『12月16-22日』:OpenAI公布o3,谷歌发布首个推理模型,GitHub Copilot免费版上线!
人工智能·gpt·chatgpt·github·openai·copilot
IT古董3 小时前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
centurysee3 小时前
【最佳实践】Anthropic:Agentic系统实践案例
人工智能
mahuifa3 小时前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai
四口鲸鱼爱吃盐3 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类