LLM分布式训练---混合并行(2D 3D)

近年来,随着Transformer、MOE架构的提出,使得深度学习模型轻松突破上万亿规模参数,传统的单机单卡模式已经无法满足超大模型进行训练的要求。因此,我们需要基于单机多卡、甚至是多机多卡进行分布式大模型的训练。

而利用AI集群,使深度学习算法更好地从大量数据中高效地训练出性能优良的大模型是分布式机器学习的首要目标。为了实现该目标,一般需要根据硬件资源与数据/模型规模的匹配情况,考虑对计算任务、训练数据和模型进行划分,从而进行分布式存储和分布式训练。

因此,分布式训练相关技术值得我们进行深入分析其背后的机理。

但是在进行上百亿/千亿级以上参数规模的超大模型预训练时,通常会组合多种并行技术一起使用。

常见的分布式并行技术组合

1. DP + PP

下图演示了如何将 DP 与 PP 结合起来使用。

这里重要的是要了解 DP rank 0 是看不见 GPU2 的, 同理,DP rank 1 是看不到 GPU3 的。

对于 DP 而言,只有 GPU 0 和 1,并向它们供给数据。GPU0 使用 PP 将它的一些负载转移到 GPU2。同样地, GPU1 也会将它的一些负载转移到 GPU3 。

由于每个维度至少需要 2 个 GPU;因此,这儿至少需要 4 个 GPU。

2. 3D 并行(DP + PP + TP)

而为了更高效地训练,可以将 PP、TP 和 DP 相结合,被业界称为 3D 并行,如下图所示。

由于每个维度至少需要 2 个 GPU,因此在这里你至少需要 8 个 GPU 才能实现完整的 3D 并行。

3. ZeRO --- DP + PP + TP

作为 DeepSpeed 的主要功能之一,它是 DP 的超级可伸缩增强版,并启发了 PyTorch FSDP 的诞生。通常它是一个独立的功能,不需要 PP 或 TP。但它也可以与 PP、TP 结合使用。

当 ZeRO-DP 与 PP 和 TP 结合使用时,通常只启用 ZeRO 阶段 1(只对优化器状态进行分片)

而 ZeRO 阶段 2 还会对梯度进行分片,ZeRO 阶段 3 还会对模型权重进行分片。

虽然理论上可以将 ZeRO 阶段 2 与 流水线并行PP一起使用,但它会对性能产生不良影响。每个 micro batch 都需要一个额外的 reduce-scatter 通信来在分片之前聚合梯度,这会增加潜在的显著通信开销。根据流水线并行的性质,我们会使用小的 micro batch ,并把重点放在算术强度 (micro batch size) 与最小化流水线气泡 (micro batch 的数量) 两者间折衷。因此,增加的通信开销会损害流水线并行。

此外,由于 PP,层数已经比正常情况下少,因此并不会节省很多内存。PP 已经将梯度大小减少了 1/PP,因此在此基础之上的梯度分片和纯 DP 相比节省不了多少内存。

除此之外,我们也可以采用 DP + TP 进行组合、也可以使用 PP + TP 进行组合,还可以使用 ZeRO3 代替 DP + PP + TP,ZeRO3 本质上是DP+MP的组合,并且无需对模型进行过多改造,使用更方便。

【深度学习】【分布式训练】一文捋顺千亿模型训练技术:流水线并行、张量并行和3D并行 - 知乎 (zhihu.com)

用通俗易懂的方式讲解大模型分布式训练并行技术:多维混合并行-CSDN博客

相关推荐
嘉禾望岗50321 小时前
spark算子类型
大数据·分布式·spark
大厂技术总监下海21 小时前
来自美团生产环境的实战派:开源CAT监控,如何保障超大规模分布式系统可观测性?
分布式·开源
大厂技术总监下海1 天前
深入 Apache Dubbo 架构:解读一个开源高性能 RPC 框架的设计哲学与核心源码
分布式·微服务
前端不太难1 天前
不写 Socket,也能做远程任务?HarmonyOS 分布式任务同步实战
分布式·华为·harmonyos
回家路上绕了弯1 天前
Spring Retry框架实战指南:优雅处理分布式系统中的瞬时故障
分布式·后端
前端不太难1 天前
HarmonyOS 分布式开发第一课:设备间协同调试实战
分布式·华为·harmonyos
AutoMQ1 天前
当 Kafka 架构显露“疲态”:共享存储领域正迎来创新变革
分布式·架构·kafka
程序员阿鹏1 天前
RabbitMQ持久化到磁盘中有个节点断掉了怎么办?
java·开发语言·分布式·后端·spring·缓存·rabbitmq
独自破碎E1 天前
Kafka的索引设计有什么亮点?
数据库·分布式·kafka
武子康1 天前
Java-218 RocketMQ Java API 实战:同步/异步 Producer 与 Pull/Push Consumer
java·大数据·分布式·消息队列·rocketmq·java-rocketmq·mq