Large Language Model系列之三:大模型并行训练(Parallel Training of Large Language Models)

Large Language Model系列之三:大模型并行训练(Parallel Training of Large Language Models)

1 各类并行算法

参考资料:

1 大模型并行训练

2 ZeRO(Zero Redundancy Optimizer)零冗余优化

ZeRO(Zero Redundancy Optimizer)是由微软研究院开发的一种内存优化技术,专门设计用于优化大规模深度学习模型的训练过程。ZeRO的核心原理是通过减少内存冗余来提高训练效率,使得可以在有限的硬件资源上训练更大的模型。

以常用的Adam优化器为例,

GPU显存存储内容主要分为两大块:Model StatesResidual States :
Model States指和模型本身息息相关的,必须存储的内容,具体包括:

  • optimizer states:Adam优化算法中的m(梯度的一阶矩)和v(梯度的二阶矩)
  • gradients:模型梯度(g)
  • parameters:模型参数 Θ \Theta Θ

Residual States指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:

  • activation:激活值。在流水线并行中我们曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
  • temporary buffers: 临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
  • unusable fragment memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。
2-1 优化模型状态内存

模型状态占用了主要的机器内存,针对该问题提出的数据并行方法ZeRO-DP在实现数据并行高效计算的同时,拥有模型并行的内存节省优势。如下图所示,ZeRO-DP主要有三个优化阶段,分别对应了模型状态中优化器状态、梯度、以及模型参数的切分,也就是通常所说的ZeRO-1/2/3。

  • 优化器状态分区(Optimizer State Partitioning) P o s P_{os} Pos :将optimizer states分成若干份,每块GPU上各自维护一份。在这个阶段每块GPU还是完整的存储一份参数,一个batch的数据被划分成n份,每块GPU上用一份数据计算出一个完整的梯度值,然后计算出这n个GPU上的一个梯度均值,参数的更新都用这个梯度均值,这里注意:参数的更新是由optimizer states和梯度值共同所决定的,由于我们在这个阶段已经对optimizer states进行了分割,分别存储在了不同的GPU上,所以这里的参数只能更新一部分。分区优化器状态到各个计算卡中,在享有与普通数据并行相同通信量的情况下,可降低4倍的内存占用。
  • 添加梯度分区(Gradient Partitioning) P o s + g P_{os+g} Pos+g :在这一步中除了将optimizer states分成若干份,梯度也分成若干份。在这个阶段每块GPU还是完整的存储一份参数,一个batch的数据被划分成n份,这里注意:每块GPU上用一份数据计算出完整的梯度,然后每个GPU汇总自己维护的那部分梯度值,把不是自己维护的那部分梯度值从显存移除。用部分梯度值,部分optimizer states更新全参数中对应的那部分参数,同样再相互通信获得完整的更新后的参数。这一步骤参数和梯度值都需要通信交互。在 P o s P_{os} Pos的基础上,进一步将模型梯度切分到各个计算卡中,在享有与普通数据并行相同通信量的情况下,拥有8倍的内存降低能力。
  • 添加参数分区(Parameter Partitioning) P o s + g + p P_{os+g+p} Pos+g+p :这一步除了将optimizer states、梯度分成若干份,参数也要分区。每块GPU上只保存部分参数,前向反向传播时需要用到完整的参数的话相互通信获取全参数,用完立马从显存移除。梯度计算时也是计算出完整的梯度,然后每个GPU汇总自己维护的那部分梯度值,把不是自己维护的那部分梯度值从显存移除。用部分梯度值,部分optimizer states更新自己维护的那部分参数。在 P o s + g + p P_{os+g+p} Pos+g+p 的基础上,将模型参数也切分到各个计算卡中,内存降低能力与并行数量 N d N_{d} Nd成线性比例,通信量大约有50%的增长。

典型的以时间换空间的优化思想,为了节省显存的空间,增加了通信的时间消耗。当三阶段的ZeRO-DP优化全部启动以后,使用混合精度和Adam优化器的千亿模型(总共占用约16T内存)可以成功基于1024卡上使用常规的32G显卡训练(每卡占用约16G内存)

2-2 优化剩余状态内存

除了ZeRO-DP外,作者还设计了ZeRO-R来解决剩余状态带来的内存瓶颈问题。

剩余状态的冗余主要集中在两个方面:

  • 临时缓冲区:在模型训练过程中,临时缓冲区会累积大量的中间数据,这些数据在不再需要时若未能及时清理,便会造成内存资源的浪费。
  • 内存碎片:由于PyTorch等深度学习框架在变量生命周期管理中的特性,频繁地分配和释放内存会导致内存碎片化,这不仅减少了可用内存空间,还增加了内存分配的时间成本。

针对上述问题,作者创新性地提出了ZeRO-R方法,该方法通过以下策略来优化内存使用效率:

  1. 激活值分区检查点(Partitioned Activation Checkpointing)
  • 问题描述:在模型并行训练中,为了支持跨设备的数据交换,activation数据需要被复制,这导致了冗余。
  • 解决方案:ZeRO通过激活值分区技术来减少这种冗余。具体地,它在正向传播完成后立即对activation进行切分,并在需要时(如在反向传播中)通过all-gather操作重新组合这些分片,从而避免了不必要的数据复制。
  1. 固定大小缓冲区(Constant Size Buffers)
  • 问题描述:某些运算的效率与输入数据的大小密切相关,大型all-gather等操作在处理大批量数据时更为高效。然而,随着模型并行复杂度的增加,对内存的需求也急剧上升。
  • 解决方案:ZeRO引入固定大小的缓冲区策略,以优化这类操作的内存使用。这一策略类似于网络通信中的窗口大小调整,通过预先分配固定大小的内存块来减少内存管理的开销,同时提高数据处理的效率。
  1. 内存碎片整理(Memory Defragmentation)
  • 问题描述:PyTorch等框架在变量生命周期管理中,由于频繁的内存分配和释放,导致内存碎片化问题严重,影响了内存的有效利用率和分配效率。
  • 解决方案:ZeRO通过为activation checkpoint和gradients等关键数据预先分配连续的内存块,有效避免了内存碎片的产生。这种策略不仅提高了内存的利用率,还减少了内存分配时的搜索时间,从而提升了整体训练性能。

ZeRO++:降低4倍网络通信,显著提高大模型及类ChatGPT模型训练效率,核心思路如下:

ZeRO-1/2/3 以及 ZeRO++的汇总:

参考资料

1 ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

2 大模型数据并行训练之DeepSpeed-ZeRO(零冗余优化)

3 ZeRO(零冗余优化器)

3 FlashAttention

Flash Attention 是一种针对 Transformer 模型中 Attention 机制的高效实现方式,旨在减少高带宽内存(如 HBM)的访问次数,同时利用 SRAM(静态随机存取存储器,通常指 GPU 上的 L1/L2 缓存)的高带宽特性来加速计算。

标准Attention算法执行过程如下:

在标准Attention算法的执行过程中,首先涉及从高带宽内存(HBM)中读取两个关键矩阵Q和K,它们各自具有N x d的维度。随后,这两个矩阵通过点积运算生成相似度得分矩阵S,其维度为N x N。这一步骤的HBM访问次数主要由读取Q和K矩阵以及写入S矩阵组成,总体上是O(Nd + N^2)次访问。

接下来,为了计算注意力权重P(同样为N x N矩阵),需要对相似度得分矩阵S进行softmax操作。softmax操作需要访问S矩阵中的所有元素以计算归一化权重,因此这一步也产生了O(N^2)次的HBM访问。

最后,在生成输出向量O(N x d矩阵)时,算法会将注意力权重P与值向量V(N x d矩阵)进行加权求和。此步骤的HBM访问主要集中在读取P和V矩阵以及写入O矩阵,共计O(Nd)次访问。

综上所述,标准Attention算法在整个执行过程中,其HBM访问的总次数达到了O(Nd + N^2)的复杂度。当处理的数据规模N非常大时,这种高频次的HBM访问会成为性能瓶颈,显著增加计算成本和时间消耗。因此,针对大规模数据,优化Attention算法以减少HBM访问次数成为了一个重要的研究方向。

FlashAttention的优化策略:

在Attention计算中,由于存在三个独立的核(kernel),每个核在处理时都要从HBM读取数据,并在计算后将结果写回HBM。通过将这三个核合并为一个,可以减少对HBM的访问次数。

在计算过程中,应优先利用SRAM进行计算,以减少对HBM的访问。尽管SRAM带宽较高,但其存储容量有限。采用分而治之的策略,通过Tiling将数据适配到SRAM容量。然而,当序列长度较大时,SRAM的限制可能导致序列被分割,这可能会干扰标准Softmax操作。

FlashAttention的优化策略如下:

Tiling(平铺):

采用"分治"策略,将大的注意力矩阵(如NxN的softmax/scores矩阵)分割成多个小得多的子矩阵。这些子矩阵的大小被精心设计,以确保它们能够完全存储在SRAM中,从而在计算过程中减少对HBM的依赖。

Block Softmax(分块Softmax):

然而,Attention机制中的softmax操作要求所有列(或行)的分数都必须参与归一化计算,这意味着子矩阵之间并非完全独立。为了解决这个问题,Flash Attention引入了分块SoftMax算法。这一算法在保持全局归一化的同时,对每个子矩阵独立进行softmax计算。通过一些巧妙的数学变换(如log-sum-exp技巧),能够确保分块SoftMax的结果与全局SoftMax高度一致,从而保证了Flash Attention的正确性。

Recomputation(重算):

为了进一步优化内存使用,Flash Attention还采用了Recomputation(重算)技术。这是一种在计算反向传播时减少内存占用的策略,通过避免存储所有正向传播的中间结果,并在需要时重新计算它们来节省内存。虽然这会增加一些计算成本,但相比于节省的内存和减少的HBM访问次数而言,这一代价通常是值得的。特别是在处理大规模数据集时,Recomputation技术能够显著提升训练效率和可扩展性。

FlashAttention的计算过程:

  1. 数据平铺(Tiling):
    将输入序列Q、K、V分割成较小的块,每块大小适合在快速访问的SRAM中处理。例如,如果序列长度为N,可以将其分割成t个大小为N/t的块。
  2. 分块计算相似度(Score Calculation):
    对于每个Q的块,计算与K的所有块的点积,得到局部相似度分数。这不是标准的自注意力计算,因为只计算了部分K对的相似度。
  3. 局部Softmax:
    对每个局部相似度分数块应用Softmax,得到局部注意力权重。这些权重是针对每个块内部的,可能不会反映整个序列的全局关系。
  4. 加权求和(Weighted Sum):
    使用局部注意力权重加权求和对应的V块,得到每个Q块的输出。
  5. 重算(Recomputation):
    在反向传播中,不存储所有中间状态。当需要计算梯度时,重新计算正向传播中的中间状态,从而减少内存占用。

示例说明:

假设有一个Transformer模型,输入序列长度为N=1024,特征维度为d=512。使用FlashAttention进行优化:

  1. Tiling:
    将Q、K、V矩阵分割成16个大小为64x512的块。
  2. 分块计算相似度:
    计算每个Q块与所有K块的点积,得到64x64的局部相似度矩阵。
  3. 局部Softmax:
    对每个64x64的局部相似度矩阵应用Softmax,得到局部注意力权重。
  4. 加权求和:
    使用局部注意力权重加权求和对应的V块,得到64x512的输出块。
  5. 拼接输出:
    将所有输出块拼接起来,形成最终的输出序列。
  6. 重算:
    在反向传播中,当需要计算某个Q块的梯度时,重新计算与该块相关的所有中间状态。

核心优势:

减少HBM访问:通过在SRAM中进行计算,FlashAttention减少了对HBM的访问次数。

内存效率:通过分块处理,FlashAttention适应了有限的内存资源,特别是对于大型序列。

灵活性:FlashAttention可以适应不同的硬件配置,通过调整块大小来优化性能。

FlashAttention通过这些策略在保持Transformer模型性能的同时,提高了模型的计算效率和内存效率。然而,这种优化可能需要特定的硬件支持,并且可能需要对模型架构进行调整以充分利用其优势。

参考资料

1 Fast and Memory-Efficient Exact Attention with IO-Awareness

2 通俗易懂聊flashAttention的加速原理

相关推荐
世优科技虚拟人3 分钟前
AI、VR与空间计算:教育和文旅领域的数字转型力量
人工智能·vr·空间计算
cloud studio AI应用9 分钟前
腾讯云 AI 代码助手:产品研发过程的思考和方法论
人工智能·云计算·腾讯云
禁默21 分钟前
第六届机器人、智能控制与人工智能国际学术会议(RICAI 2024)
人工智能·机器人·智能控制
Robot25128 分钟前
浅谈,华为切入具身智能赛道
人工智能
只怕自己不够好33 分钟前
OpenCV 图像运算全解析:加法、位运算(与、异或)在图像处理中的奇妙应用
图像处理·人工智能·opencv
果冻人工智能2 小时前
2025 年将颠覆商业的 8 大 AI 应用场景
人工智能·ai员工
代码不行的搬运工2 小时前
神经网络12-Time-Series Transformer (TST)模型
人工智能·神经网络·transformer
石小石Orz2 小时前
Three.js + AI:AI 算法生成 3D 萤火虫飞舞效果~
javascript·人工智能·算法
孤独且没人爱的纸鹤2 小时前
【深度学习】:从人工神经网络的基础原理到循环神经网络的先进技术,跨越智能算法的关键发展阶段及其未来趋势,探索技术进步与应用挑战
人工智能·python·深度学习·机器学习·ai
阿_旭2 小时前
TensorFlow构建CNN卷积神经网络模型的基本步骤:数据处理、模型构建、模型训练
人工智能·深度学习·cnn·tensorflow