Large Language Model系列之三:大模型并行训练(Parallel Training of Large Language Models)
1 各类并行算法
参考资料:
1 大模型并行训练
2 ZeRO(Zero Redundancy Optimizer)零冗余优化
ZeRO(Zero Redundancy Optimizer)是由微软研究院开发的一种内存优化技术,专门设计用于优化大规模深度学习模型的训练过程。ZeRO的核心原理是通过减少内存冗余来提高训练效率,使得可以在有限的硬件资源上训练更大的模型。
以常用的Adam优化器为例,
GPU显存存储内容主要分为两大块:Model States 和Residual States :
Model States指和模型本身息息相关的,必须存储的内容,具体包括:
- optimizer states:Adam优化算法中的m(梯度的一阶矩)和v(梯度的二阶矩)
- gradients:模型梯度(g)
- parameters:模型参数 Θ \Theta Θ
Residual States指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:
- activation:激活值。在流水线并行中我们曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
- temporary buffers: 临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
- unusable fragment memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。
2-1 优化模型状态内存
模型状态占用了主要的机器内存,针对该问题提出的数据并行方法ZeRO-DP在实现数据并行高效计算的同时,拥有模型并行的内存节省优势。如下图所示,ZeRO-DP主要有三个优化阶段,分别对应了模型状态中优化器状态、梯度、以及模型参数的切分,也就是通常所说的ZeRO-1/2/3。
- 优化器状态分区(Optimizer State Partitioning) P o s P_{os} Pos :将optimizer states分成若干份,每块GPU上各自维护一份。在这个阶段每块GPU还是完整的存储一份参数,一个batch的数据被划分成n份,每块GPU上用一份数据计算出一个完整的梯度值,然后计算出这n个GPU上的一个梯度均值,参数的更新都用这个梯度均值,这里注意:参数的更新是由optimizer states和梯度值共同所决定的,由于我们在这个阶段已经对optimizer states进行了分割,分别存储在了不同的GPU上,所以这里的参数只能更新一部分。分区优化器状态到各个计算卡中,在享有与普通数据并行相同通信量的情况下,可降低4倍的内存占用。
- 添加梯度分区(Gradient Partitioning) P o s + g P_{os+g} Pos+g :在这一步中除了将optimizer states分成若干份,梯度也分成若干份。在这个阶段每块GPU还是完整的存储一份参数,一个batch的数据被划分成n份,这里注意:每块GPU上用一份数据计算出完整的梯度,然后每个GPU汇总自己维护的那部分梯度值,把不是自己维护的那部分梯度值从显存移除。用部分梯度值,部分optimizer states更新全参数中对应的那部分参数,同样再相互通信获得完整的更新后的参数。这一步骤参数和梯度值都需要通信交互。在 P o s P_{os} Pos的基础上,进一步将模型梯度切分到各个计算卡中,在享有与普通数据并行相同通信量的情况下,拥有8倍的内存降低能力。
- 添加参数分区(Parameter Partitioning) P o s + g + p P_{os+g+p} Pos+g+p :这一步除了将optimizer states、梯度分成若干份,参数也要分区。每块GPU上只保存部分参数,前向反向传播时需要用到完整的参数的话相互通信获取全参数,用完立马从显存移除。梯度计算时也是计算出完整的梯度,然后每个GPU汇总自己维护的那部分梯度值,把不是自己维护的那部分梯度值从显存移除。用部分梯度值,部分optimizer states更新自己维护的那部分参数。在 P o s + g + p P_{os+g+p} Pos+g+p 的基础上,将模型参数也切分到各个计算卡中,内存降低能力与并行数量 N d N_{d} Nd成线性比例,通信量大约有50%的增长。
典型的以时间换空间的优化思想,为了节省显存的空间,增加了通信的时间消耗。当三阶段的ZeRO-DP优化全部启动以后,使用混合精度和Adam优化器的千亿模型(总共占用约16T内存)可以成功基于1024卡上使用常规的32G显卡训练(每卡占用约16G内存)
2-2 优化剩余状态内存
除了ZeRO-DP外,作者还设计了ZeRO-R来解决剩余状态带来的内存瓶颈问题。
剩余状态的冗余主要集中在两个方面:
- 临时缓冲区:在模型训练过程中,临时缓冲区会累积大量的中间数据,这些数据在不再需要时若未能及时清理,便会造成内存资源的浪费。
- 内存碎片:由于PyTorch等深度学习框架在变量生命周期管理中的特性,频繁地分配和释放内存会导致内存碎片化,这不仅减少了可用内存空间,还增加了内存分配的时间成本。
针对上述问题,作者创新性地提出了ZeRO-R方法,该方法通过以下策略来优化内存使用效率:
- 激活值分区检查点(Partitioned Activation Checkpointing):
- 问题描述:在模型并行训练中,为了支持跨设备的数据交换,activation数据需要被复制,这导致了冗余。
- 解决方案:ZeRO通过激活值分区技术来减少这种冗余。具体地,它在正向传播完成后立即对activation进行切分,并在需要时(如在反向传播中)通过all-gather操作重新组合这些分片,从而避免了不必要的数据复制。
- 固定大小缓冲区(Constant Size Buffers):
- 问题描述:某些运算的效率与输入数据的大小密切相关,大型all-gather等操作在处理大批量数据时更为高效。然而,随着模型并行复杂度的增加,对内存的需求也急剧上升。
- 解决方案:ZeRO引入固定大小的缓冲区策略,以优化这类操作的内存使用。这一策略类似于网络通信中的窗口大小调整,通过预先分配固定大小的内存块来减少内存管理的开销,同时提高数据处理的效率。
- 内存碎片整理(Memory Defragmentation):
- 问题描述:PyTorch等框架在变量生命周期管理中,由于频繁的内存分配和释放,导致内存碎片化问题严重,影响了内存的有效利用率和分配效率。
- 解决方案:ZeRO通过为activation checkpoint和gradients等关键数据预先分配连续的内存块,有效避免了内存碎片的产生。这种策略不仅提高了内存的利用率,还减少了内存分配时的搜索时间,从而提升了整体训练性能。
ZeRO++:降低4倍网络通信,显著提高大模型及类ChatGPT模型训练效率,核心思路如下:
ZeRO-1/2/3 以及 ZeRO++的汇总:
参考资料
1 ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
2 大模型数据并行训练之DeepSpeed-ZeRO(零冗余优化)
3 FlashAttention
Flash Attention 是一种针对 Transformer 模型中 Attention 机制的高效实现方式,旨在减少高带宽内存(如 HBM)的访问次数,同时利用 SRAM(静态随机存取存储器,通常指 GPU 上的 L1/L2 缓存)的高带宽特性来加速计算。
标准Attention算法执行过程如下:
在标准Attention算法的执行过程中,首先涉及从高带宽内存(HBM)中读取两个关键矩阵Q和K,它们各自具有N x d的维度。随后,这两个矩阵通过点积运算生成相似度得分矩阵S,其维度为N x N。这一步骤的HBM访问次数主要由读取Q和K矩阵以及写入S矩阵组成,总体上是O(Nd + N^2)次访问。
接下来,为了计算注意力权重P(同样为N x N矩阵),需要对相似度得分矩阵S进行softmax操作。softmax操作需要访问S矩阵中的所有元素以计算归一化权重,因此这一步也产生了O(N^2)次的HBM访问。
最后,在生成输出向量O(N x d矩阵)时,算法会将注意力权重P与值向量V(N x d矩阵)进行加权求和。此步骤的HBM访问主要集中在读取P和V矩阵以及写入O矩阵,共计O(Nd)次访问。
综上所述,标准Attention算法在整个执行过程中,其HBM访问的总次数达到了O(Nd + N^2)的复杂度。当处理的数据规模N非常大时,这种高频次的HBM访问会成为性能瓶颈,显著增加计算成本和时间消耗。因此,针对大规模数据,优化Attention算法以减少HBM访问次数成为了一个重要的研究方向。
FlashAttention的优化策略:
在Attention计算中,由于存在三个独立的核(kernel),每个核在处理时都要从HBM读取数据,并在计算后将结果写回HBM。通过将这三个核合并为一个,可以减少对HBM的访问次数。
在计算过程中,应优先利用SRAM进行计算,以减少对HBM的访问。尽管SRAM带宽较高,但其存储容量有限。采用分而治之的策略,通过Tiling将数据适配到SRAM容量。然而,当序列长度较大时,SRAM的限制可能导致序列被分割,这可能会干扰标准Softmax操作。
FlashAttention的优化策略如下:
Tiling(平铺):
采用"分治"策略,将大的注意力矩阵(如NxN的softmax/scores矩阵)分割成多个小得多的子矩阵。这些子矩阵的大小被精心设计,以确保它们能够完全存储在SRAM中,从而在计算过程中减少对HBM的依赖。
Block Softmax(分块Softmax):
然而,Attention机制中的softmax操作要求所有列(或行)的分数都必须参与归一化计算,这意味着子矩阵之间并非完全独立。为了解决这个问题,Flash Attention引入了分块SoftMax算法。这一算法在保持全局归一化的同时,对每个子矩阵独立进行softmax计算。通过一些巧妙的数学变换(如log-sum-exp技巧),能够确保分块SoftMax的结果与全局SoftMax高度一致,从而保证了Flash Attention的正确性。
Recomputation(重算):
为了进一步优化内存使用,Flash Attention还采用了Recomputation(重算)技术。这是一种在计算反向传播时减少内存占用的策略,通过避免存储所有正向传播的中间结果,并在需要时重新计算它们来节省内存。虽然这会增加一些计算成本,但相比于节省的内存和减少的HBM访问次数而言,这一代价通常是值得的。特别是在处理大规模数据集时,Recomputation技术能够显著提升训练效率和可扩展性。
FlashAttention的计算过程:
- 数据平铺(Tiling):
将输入序列Q、K、V分割成较小的块,每块大小适合在快速访问的SRAM中处理。例如,如果序列长度为N,可以将其分割成t个大小为N/t的块。 - 分块计算相似度(Score Calculation):
对于每个Q的块,计算与K的所有块的点积,得到局部相似度分数。这不是标准的自注意力计算,因为只计算了部分K对的相似度。 - 局部Softmax:
对每个局部相似度分数块应用Softmax,得到局部注意力权重。这些权重是针对每个块内部的,可能不会反映整个序列的全局关系。 - 加权求和(Weighted Sum):
使用局部注意力权重加权求和对应的V块,得到每个Q块的输出。 - 重算(Recomputation):
在反向传播中,不存储所有中间状态。当需要计算梯度时,重新计算正向传播中的中间状态,从而减少内存占用。
示例说明:
假设有一个Transformer模型,输入序列长度为N=1024,特征维度为d=512。使用FlashAttention进行优化:
- Tiling:
将Q、K、V矩阵分割成16个大小为64x512的块。 - 分块计算相似度:
计算每个Q块与所有K块的点积,得到64x64的局部相似度矩阵。 - 局部Softmax:
对每个64x64的局部相似度矩阵应用Softmax,得到局部注意力权重。 - 加权求和:
使用局部注意力权重加权求和对应的V块,得到64x512的输出块。 - 拼接输出:
将所有输出块拼接起来,形成最终的输出序列。 - 重算:
在反向传播中,当需要计算某个Q块的梯度时,重新计算与该块相关的所有中间状态。
核心优势:
减少HBM访问:通过在SRAM中进行计算,FlashAttention减少了对HBM的访问次数。
内存效率:通过分块处理,FlashAttention适应了有限的内存资源,特别是对于大型序列。
灵活性:FlashAttention可以适应不同的硬件配置,通过调整块大小来优化性能。
FlashAttention通过这些策略在保持Transformer模型性能的同时,提高了模型的计算效率和内存效率。然而,这种优化可能需要特定的硬件支持,并且可能需要对模型架构进行调整以充分利用其优势。
参考资料
1 Fast and Memory-Efficient Exact Attention with IO-Awareness