LLM分布式训练---混合并行(2D 3D)

近年来,随着Transformer、MOE架构的提出,使得深度学习模型轻松突破上万亿规模参数,传统的单机单卡模式已经无法满足超大模型进行训练的要求。因此,我们需要基于单机多卡、甚至是多机多卡进行分布式大模型的训练。

而利用AI集群,使深度学习算法更好地从大量数据中高效地训练出性能优良的大模型是分布式机器学习的首要目标。为了实现该目标,一般需要根据硬件资源与数据/模型规模的匹配情况,考虑对计算任务、训练数据和模型进行划分,从而进行分布式存储和分布式训练。

因此,分布式训练相关技术值得我们进行深入分析其背后的机理。

但是在进行上百亿/千亿级以上参数规模的超大模型预训练时,通常会组合多种并行技术一起使用。

常见的分布式并行技术组合

1. DP + PP

下图演示了如何将 DP 与 PP 结合起来使用。

这里重要的是要了解 DP rank 0 是看不见 GPU2 的, 同理,DP rank 1 是看不到 GPU3 的。

对于 DP 而言,只有 GPU 0 和 1,并向它们供给数据。GPU0 使用 PP 将它的一些负载转移到 GPU2。同样地, GPU1 也会将它的一些负载转移到 GPU3 。

由于每个维度至少需要 2 个 GPU;因此,这儿至少需要 4 个 GPU。

2. 3D 并行(DP + PP + TP)

而为了更高效地训练,可以将 PP、TP 和 DP 相结合,被业界称为 3D 并行,如下图所示。

由于每个维度至少需要 2 个 GPU,因此在这里你至少需要 8 个 GPU 才能实现完整的 3D 并行。

3. ZeRO --- DP + PP + TP

作为 DeepSpeed 的主要功能之一,它是 DP 的超级可伸缩增强版,并启发了 PyTorch FSDP 的诞生。通常它是一个独立的功能,不需要 PP 或 TP。但它也可以与 PP、TP 结合使用。

当 ZeRO-DP 与 PP 和 TP 结合使用时,通常只启用 ZeRO 阶段 1(只对优化器状态进行分片)

而 ZeRO 阶段 2 还会对梯度进行分片,ZeRO 阶段 3 还会对模型权重进行分片。

虽然理论上可以将 ZeRO 阶段 2 与 流水线并行PP一起使用,但它会对性能产生不良影响。每个 micro batch 都需要一个额外的 reduce-scatter 通信来在分片之前聚合梯度,这会增加潜在的显著通信开销。根据流水线并行的性质,我们会使用小的 micro batch ,并把重点放在算术强度 (micro batch size) 与最小化流水线气泡 (micro batch 的数量) 两者间折衷。因此,增加的通信开销会损害流水线并行。

此外,由于 PP,层数已经比正常情况下少,因此并不会节省很多内存。PP 已经将梯度大小减少了 1/PP,因此在此基础之上的梯度分片和纯 DP 相比节省不了多少内存。

除此之外,我们也可以采用 DP + TP 进行组合、也可以使用 PP + TP 进行组合,还可以使用 ZeRO3 代替 DP + PP + TP,ZeRO3 本质上是DP+MP的组合,并且无需对模型进行过多改造,使用更方便。

【深度学习】【分布式训练】一文捋顺千亿模型训练技术:流水线并行、张量并行和3D并行 - 知乎 (zhihu.com)

用通俗易懂的方式讲解大模型分布式训练并行技术:多维混合并行-CSDN博客

相关推荐
风吹夏回13 天前
RabbitMQ 核心术语 + Python pika 方法完整讲解
分布式·python·rabbitmq
风吹夏回13 天前
RabbitMQ 三种模式入门:HelloWorld、WorkQueue、PubSub
分布式·rabbitmq·ruby
霸道流氓气质13 天前
分布式追踪与 RequestId 传播完全指南
分布式
cheems952713 天前
[RabbitMQ高级特性] 消息确认机制:从 Ready / Unacked 到 basicAck、basicReject、basicNack 的底层拆解
分布式·rabbitmq·ruby
枫华落尽13 天前
【Hadoop01-完全分布式运行模式】
分布式
隔壁阿布都13 天前
ShedLock 分布式定时任务锁框架介绍
spring boot·分布式
文艺倾年13 天前
【强化学习】数学推导专题,20W字总结(十五)
人工智能·分布式·大模型·强化学习·vibecoding
ACP广源盛1392462567313 天前
GSV9001S@ACP#1080P 级视频处理芯片,物理 AI 普及终端的高性价比选择
大数据·人工智能·分布式·嵌入式硬件·spark
guslegend13 天前
第1章:初始Kafka
分布式·kafka
ACP广源盛1392462567313 天前
GSV5600@ACP#多接口协议转换芯片,物理 AI 便携终端的互联核心
大数据·人工智能·分布式·嵌入式硬件·spark