大模型分布式训练并行技术分享

目前业内解决大模型问题,基本以多节点、分布式方案为主。分布式方案具体的实施时,又分为数据并行、参数并行、流水线并行等,针对具体的业务场景采取合适的并行方案方可带来更高的效率。

后续结合业内主流的分布式框架,具体介绍各种并行的思路以及可能带来的收益。

数据并行

一些基础知识的补充:
Pytorch DDP分布式细节分享

ZeRO(零冗余优化)

零冗余优化的核心思想:用通信换显存 ,数据算完即废,等需要的时候,再同步过来。

从效果来说,零冗余优化属于数据并行+张量并行,从根本来说属于数据并行。

模型在训练时需要的显存大小,假设模型的参数W大小是phi,以byte为单位,存储如下:

结论 :优化器、模型参数、梯度是占据显存的主要数据。

将优化器、模型参数、梯度等数据进行切分可达到不同程度的显存优化,可分为zero1、zero2、zero3

zero1(优化器切分)

由于每块GPU上只保管部分optimizer states,因此只能将相应的W(蓝色部分)进行更新;需要对W做一次All-Gather,从别的GPU上把更新好的部分W取回来,额外产生单卡通讯量phi。

zero2(优化器+梯度切分)
  • 对梯度做一次Reduce-Scatter,保证每个GPU上所维持的那块梯度是聚合梯度。单卡通讯量phi。
  • 每块GPU用自己对应的O和G去更新相应的W。更新完毕后,每块GPU维持了一块更新完毕的W。同理,对W做一次All-Gather,将别的GPU算好的W同步到自己这来。单卡通讯量phi。
zero3(优化器+梯度+参数切分)
  • 做forward时,对W做一次All-Gather,取回分布在别的GPU上的W,得到一份完整的W,单卡通讯量phi 。forward做完,立刻把不是自己维护的W抛弃。
  • 做backward时,对W做一次All-Gather,取回完整的W,单卡通讯量phi。backward做完,立刻把不是自己维护的W抛弃。
  • 做完backward,算得一份完整的梯度G,对G做一次Reduce-Scatter,从别的GPU上聚合自己维护的那部分梯度,单卡通讯量phi。聚合操作结束后,立刻把不是自己维护的G抛弃。

优化效果:

用1.5倍的通讯开销,换回近60倍的显存

基于zero的实现的工具有:

  • 微软Deepspeed
  • Pytorch fsdp(1.11+)

参考论文:

zero-deepspeed.pdf

模型并行

在数据并行训练中,一个明显的特点是每个 GPU 持有整个模型权重的副本,这就带来了冗余问题。如果将模型参数、优化器等分割在一个设备整列,将有效缓解显存的压力和副本冗余。

模型并行,主流上分为张量并行和流水线并行。

张量并行为层内并行,对模型 Transformer 层内进行分割、流水线为层间并行,对模型不同的 Transformer 层间进行分割。

张量并行(TP)

张量并行可视为层内并行,可分为按行进行切分和按列进行切分,分别对应行并行(Row Parallelism)与列并行(Column Parallelism)。

受 GSPMD、Oneflow 和 TF DTensor 的启发,PyTorch 从 2.0.0 开始引入 DTensor,通过DTensor抽象,我们可以无缝构建张量并行。

参考论文:
Megatron-LM 1D 2020-03-13

流水线并行(PP)

经典的流水线并行范式有Google推出的Gpipe,和微软推出的PipeDream。两者的推出时间都在2019年左右,大体设计框架一致。主要差别为:在梯度更新上,Gpipe是同步的,PipeDream是异步的。

多维混合并行

在进行上百亿/千亿级以上参数规模的超大模型预训练时,通常会组合多种并行技术一起使用。

常见的组合方式:

DP+PP

3D 并行(DP + PP + TP)

ZeRO-DP + PP + TP

相关推荐
Coding茶水间5 小时前
基于深度学习的非机动车头盔检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
baby_hua6 小时前
20251024_PyTorch深度学习快速入门教程
人工智能·pytorch·深度学习
脸大是真的好~7 小时前
分布式锁-基于redis实现分布式锁(不推荐)- 改进利用LUA脚本(不推荐)前面都是原理 - Redisson分布式锁
redis·分布式·lua
another heaven8 小时前
【深度学习 YOLO官方模型全解析】
人工智能·深度学习·yolo
liuniansilence8 小时前
🚀 高并发场景下的救星:BullMQ如何实现智能流量削峰填谷
前端·分布式·消息队列
极度畅想10 小时前
脑电模型实战系列(三):DEAP 数据集处理与 Russell 环状模型实战(一)
深度学习·特征提取·情感计算·脑机接口 bci·deap数据集
Wang's Blog11 小时前
RabbitMQ: 实现高效消息监听之从基础到自动配置
分布式·rabbitmq
CoovallyAIHub12 小时前
从“模仿”到“进化”!华科&小米开源MindDrive:在线强化学习重塑「语言-动作」闭环驾驶
深度学习·算法·计算机视觉
OpenBayes12 小时前
Open-AutoGLM 实现手机端自主操作;PhysDrive 数据集采集真实驾驶生理信号
人工智能·深度学习·机器学习·数据集·文档转换·图片生成·蛋白质设计