深度学习专题:模型训练的数据并行(三)
ZeRO 系列优化:ZeRO-1、ZeRO-2、ZeRO-3、ZeRO-R
(一)Adam 优化器
传统的 SGD 优化器第 t+1 轮参数更新公式为:
wt+1=wt−lr⋅gt+1 w_{t+1} = w_t - lr \cdot g_{t+1} wt+1=wt−lr⋅gt+1
其中 wtw_twt 为第 t 轮参数,lrlrlr 为学习率(不变的定值),gtg_tgt 为第 t 轮参数的梯度。
这种优化策略容易遇到驻点问题:当模型参数的梯度接近于零(即 gt+1g_{t+1}gt+1 接近于零)时,参数更新会非常缓慢,甚至停滞不前。
Adam 优化器为了解决驻点问题,在 SGD 优化器的基础上,引入了动量和自适应学习率的概念。其第 t+1 轮参数更新公式为:
wt+1=wt−lr⋅mt+1vt+1+ϵ w_{t+1} = w_t - lr \cdot \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} wt+1=wt−lr⋅vt+1 +ϵmt+1
其中 mtm_tmt 为第 t 轮的动量,vtv_tvt 为第 t 轮的方差值,ϵ\epsilonϵ 为一个小的常量(防止除零错误)。特别地,动量与方差的计算公式为:
mt+1=β1mt+(1−β1)gt+11−β1t+1vt+1=β2vt+(1−β2)gt+121−β2t+1 m_{t+1} = \frac{\beta_1 m_t + (1 - \beta_1) g_{t+1}}{1 - \beta_1^{t+1}} \\ v_{t+1} = \frac{\beta_2 v_t + (1 - \beta_2) g_{t+1}^2}{1 - \beta_2^{t+1}} mt+1=1−β1t+1β1mt+(1−β1)gt+1vt+1=1−β2t+1β2vt+(1−β2)gt+12
其中 β1\beta_1β1 和 β2\beta_2β2 为两个超参数(常设为 0.9 和 0.999)。初值为 m0=0m_0 = 0m0=0 和 v0=0v_0 = 0v0=0。
(二)Adam 优化器数据并行存储代价
按照 Adam 优化器的公式,模型在训练过程中不仅要存储当前轮的参数 wtw_twt 和梯度 gtg_tgt,还要存储当前轮的动量 mtm_tmt 和方差 vtv_tvt。假设有 N 个参数,在数据并行场景下,与 SGD 优化器比较分析第 t+1 轮 epoch 的反向传播过程:
1. SGD 优化器
SGD 优化器在每个 GPU 中需要额外存储 2N2N2N 个浮点数。
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t | t | ... | t |
| 梯度向量 ggg | - | - | ... | - |
第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t | t | ... | t |
| 梯度向量 ggg | t+1, k | t+1, k | ... | t+1, k |
第二步:基于 Ring All-Reduce 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量,计算梯度向量均值 gt+1g_{t+1}gt+1 并存储;
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t | t | ... | t |
| 梯度向量 ggg | t+1 | t+1 | ... | t+1 |
第三步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1,更新参数向量 wt+1w_{t+1}wt+1。
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t+1 | t+1 | ... | t+1 |
| 梯度向量 ggg | - | - | ... | - |
2. Adam 优化器
Adam 优化器在每个 GPU 中需要额外存储 4N4N4N 个浮点数。
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t | t | ... | t |
| 动量向量 mmm | t | t | ... | t |
| 方差向量 vvv | t | t | ... | t |
| 梯度向量 ggg | - | - | ... | - |
第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量;
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t | t | ... | t |
| 动量向量 mmm | t | t | ... | t |
| 方差向量 vvv | t | t | ... | t |
| 梯度向量 ggg | t+1, k | t+1, k | ... | t+1, k |
第二步:基于 Ring All-Reduce 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量,计算梯度向量均值 gt+1g_{t+1}gt+1 并存储;
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t | t | ... | t |
| 动量向量 mmm | t | t | ... | t |
| 方差向量 vvv | t | t | ... | t |
| 梯度向量 ggg | t+1 | t+1 | ... | t+1 |
第三步:读取第 t 轮 epoch 的动量 mtm_tmt 和方差 vtv_tvt 并根据聚合后的第 t+1 轮 epoch 的梯度向量 gt+1g_{t+1}gt+1 计算第 t+1 轮 epoch 的动量 mt+1m_{t+1}mt+1 和方差 vt+1v_{t+1}vt+1 并存储;
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t | t | ... | t |
| 动量向量 mmm | t+1 | t+1 | ... | t+1 |
| 方差向量 vvv | t+1 | t+1 | ... | t+1 |
| 梯度向量 ggg | t+1 | t+1 | ... | t+1 |
第四步:根据动量 mt+1m_{t+1}mt+1 和方差 vt+1v_{t+1}vt+1,更新参数向量 wt+1w_{t+1}wt+1。
| 存储项 | 1 | 2 | ... | N |
|---|---|---|---|---|
| 参数向量 www | t+1 | t+1 | ... | t+1 |
| 动量向量 mmm | t+1 | t+1 | ... | t+1 |
| 方差向量 vvv | t+1 | t+1 | ... | t+1 |
| 梯度向量 ggg | - | - | ... | - |
3. 分析与总结
GPU 需要存储第 t 轮 epoch 的参数向量 wtw_twt,因为在第 t+1 轮训练过程中需要根据 wtw_twt 进行前向传播、损失计算和反向传播,从而计算出第 t+1 轮 epoch 的梯度向量 gt+1g_{t+1}gt+1;
GPU 还需要存储第 t 轮 epoch 的动量向量 mtm_tmt 和方差向量 vtv_tvt,因为在第 t+1 轮训练过程中需要根据 mtm_tmt 和 vtv_tvt 计算出第 t+1 轮 epoch 的动量向量 mt+1m_{t+1}mt+1 和方差向量 vt+1v_{t+1}vt+1;
上面的两个存储项均涉及到跨 epoch 存储,所以对应的空间在整个训练的生命周期都不能被释放。
GPU 还需要临时存储第 t+1 轮 epoch 的梯度向量 gt+1g_{t+1}gt+1,因为 GPU 需要使用 Ring All-Reduce 算法在所有 GPU 之间归约梯度向量并计算均值。如果是 SGD 优化器,这个梯度向量会直接用于更新参数;如果是 Adam 优化器,这个梯度向量会用于更新动量和方差。不同的是,完成对应的更新操作后,梯度向量就不需要继续存储了,对应的空间就可以被释放了。所以梯度向量只需要临时存储,而不需要跨 epoch 存储。
(三)ZeRO-1 优化原理
为了减少 Adam 优化器的存储代价,ZeRO-1 提出了一种优化策略:将优化器状态在多个 GPU 之间进行分区存储,而不是在每个 GPU 上完整保存所有优化器状态。
在数据并行训练中,每个 GPU 通常保存完整的模型参数、梯度和优化器状态。ZeRO-1 观察到,优化器状态(如 Adam 中的动量和方差)在训练过程中主要参与参数更新计算,但不同参数的优化器状态之间是相互独立的。因此,可以将优化器状态按照参数维度进行分区,每个 GPU 只需要存储自己负责的那部分动量向量 mt+1m_{t+1}mt+1 和方差向量 vt+1v_{t+1}vt+1 并更新对应的参数,然后通过 GPU 的 All-Gather 算法同步所有更新后的参数。
1. 步骤分析
准备步骤:假设模型有 N 个参数,有 P 个 GPU 负责并行训练
- 每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数
- 第 k 个 GPU 将负责第 sk=1+(k−1)Ms_k = 1+(k-1)Msk=1+(k−1)M 到 ek=kMe_k=kMek=kM 个参数
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | t+1, k | t+1, k | ... | t+1, k | t+1, k | t+1, k | ... | t+1, k | t+1, k | ... | t+1, k |
第二步:基于 Ring All-Reduce 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量,计算梯度向量均值 gt+1g_{t+1}gt+1 并存储;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | t+1 | t+1 | ... | t+1 | t+1 | t+1 | ... | t+1 | t+1 | ... | t+1 |
第三步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1,更新自己负责的那一部分参数的动量和方差;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | t+1 | t+1 | ... | t+1 | t+1 | t+1 | ... | t+1 | t+1 | ... | t+1 |
第四步:每个 GPU 根据自己负责的那部分更新后的动量和方差,更新自己负责的那部分参数向量,然后释放梯度向量 gt+1g_{t+1}gt+1 占用的存储空间;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t+1 | t+1 | ... | t+1 | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
第五步:每个 GPU 与其他所有 GPU 通过 All-Gather 算法通信,得到更新后的参数向量 wt+1w_{t+1}wt+1。
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t+1 | t+1 | ... | t+1 | t+1 | t+1 | ... | t+1 | t+1 | ... | t+1 |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
2. 存储优化分析
- 传统数据并行:每个 GPU 需要存储 4N4N4N 个浮点数;
- ZeRO-1 优化:每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数,所以每个 GPU 需要存储 2M+2N2M + 2N2M+2N 个浮点数。当 M≪NM \ll NM≪N 时,存储量趋近于 2N2N2N,存储代价可以显著减少。
(四)ZeRO-2 优化原理
ZeRO-2 在 ZeRO-1 的基础上进一步优化,提出了梯度分区存储的策略,其核心通信模式是 Reduce-Scatter 而不是 All-Reduce。
在传统数据并行中,梯度同步使用 All-Reduce,每个 GPU 最终都获得完整的平均梯度。ZeRO-2 改为使用 Reduce-Scatter,每个 GPU 只获得部分梯度的聚合结果,与优化器状态的分区相对应。
1. 步骤分析
准备步骤:假设模型有 N 个参数,有 P 个 GPU 负责并行训练
- 每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数
- 第 k 个 GPU 将负责第 sk=1+(k−1)Ms_k = 1+(k-1)Msk=1+(k−1)M 到 ek=kMe_k=kMek=kM 个参数
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | t+1, k | t+1, k | ... | t+1, k | t+1, k | t+1, k | ... | t+1, k | t+1, k | ... | t+1, k |
第二步:基于 Reduce-Scatter 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量的第 sks_ksk 到 eke_kek 个元素,计算梯度向量均值 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 个元素并存储,然后释放梯度向量均值 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 之外的元素占用的存储空间;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
第三步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 个元素,更新自己负责的那一部分参数的动量和方差;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
第四步:每个 GPU 根据自己负责的那部分更新后的动量和方差,更新自己负责的那部分参数向量,然后释放梯度向量 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 的元素占用的存储空间;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t+1 | t+1 | ... | t+1 | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
第五步:每个 GPU 与其他所有 GPU 通过 All-Gather 算法通信,得到更新后的参数向量 wt+1w_{t+1}wt+1。
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t+1 | t+1 | ... | t+1 | t+1 | t+1 | ... | t+1 | t+1 | ... | t+1 |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
2. 存储优化分析
- ZeRO-1 优化:每个 GPU 需要存储 2M+2N2M + 2N2M+2N 个浮点数,当 M≪NM \ll NM≪N 时,存储量趋近于 2N2N2N,存储代价可以显著减少;
- ZeRO-2 优化:每个 GPU 需要存储 3M+N3M + N3M+N 个浮点数,当 M≪NM \ll NM≪N 时,存储量趋近于 NNN,存储代价可以显著减少。
(五)ZeRO-3 优化原理
ZeRO-3 在 ZeRO-2 的基础上进一步优化,提出了参数分区存储的策略,每个 GPU 只存储自己负责的那部分参数,这需要在训练的前向传播和反向传播之前通过 All-Gather 通信收集全部参数,从而达成一种通信换存储的优化效果。
1. 步骤分析
准备步骤:假设模型有 N 个参数,有 P 个 GPU 负责并行训练
- 每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数
- 第 k 个 GPU 将负责第 sk=1+(k−1)Ms_k = 1+(k-1)Msk=1+(k−1)M 到 ek=kMe_k=kMek=kM 个参数,且只存储这些参数。
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | - | - | ... | - | t | t | ... | t | - | ... | - |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
第一步:第 k 个 GPU 通过 All-Gather 通信获得第 t 轮 epoch 的参数向量 wtw_twt,然后读取分配给自己的训练数据,完成前向传播,计算得到损失值;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | t | t | ... | t | t | t | ... | t | t | ... | t |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
第二步:进行反向传播,计算得到第 t+1 轮 epoch 的梯度向量,然后释放第 t 轮参数向量的 sks_ksk 到 eke_kek 之外的元素占用的存储空间;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | - | - | ... | - | t | t | ... | t | - | ... | - |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | t+1, k | t+1, k | ... | t+1, k | t+1, k | t+1, k | ... | t+1, k | t+1, k | ... | t+1, k |
第三步:基于 Reduce-Scatter 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量的第 sks_ksk 到 eke_kek 个元素,计算梯度向量均值 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 个元素并存储,然后释放梯度向量均值 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 之外的元素占用的存储空间;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | - | - | ... | - | t | t | ... | t | - | ... | - |
| 动量向量 mmm | - | - | ... | - | t | t | ... | t | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t | t | ... | t | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
第四步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 个元素,更新自己负责的那一部分参数的动量和方差;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | - | - | ... | - | t | t | ... | t | - | ... | - |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
第五步:每个 GPU 根据自己负责的那部分更新后的动量和方差,更新自己负责的那部分参数向量,然后释放梯度向量 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 的元素占用的存储空间;
| 存储项 | 1 | 2 | ... | sk−1s_k-1sk−1 | sks_ksk | sk+1s_k+1sk+1 | ... | eke_kek | ek+1e_k+1ek+1 | ... | N |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 参数向量 www | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 动量向量 mmm | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 方差向量 vvv | - | - | ... | - | t+1 | t+1 | ... | t+1 | - | ... | - |
| 梯度向量 ggg | - | - | ... | - | - | - | ... | - | - | ... | - |
2. 存储优化分析
- ZeRO-3 优化:每个 GPU 需要存储 4M4M4M 个浮点数,当 M≪NM \ll NM≪N 时,存储量趋近于 000,存储代价可以显著减少。
(四)ZeRO-3 层间流水线优化
由于层间流水线优化可以将多个层的前向传播和反向传播并行执行,而每个层只需要存储自己负责的那部分参数,因此 ZeRO-3 与层间流水线优化结合可以进一步减少存储代价。
1. 前向传播和反向传播的计算资料
考虑下面的模型:
- 输入输出都是标量,分别记作 xxx 和 yyy;
- 模型有 4 个层,其中前三个层是全连接层加上 ReLU 激活函数,最后一个层是全连接层;
- 每个层的参数数量都是 2 个,分别是 权重 www 和 偏置 bbb;
- 模型的前向传播:
- 第 1 层:h1=ReLU(w1×x+b1)h_1 = \text{ReLU}(w_1 \times x + b_1)h1=ReLU(w1×x+b1)
- 第 2 层:h2=ReLU(w2×h1+b2)h_2 = \text{ReLU}(w_2 \times h_1 + b_2)h2=ReLU(w2×h1+b2)
- 第 3 层:h3=ReLU(w3×h2+b3)h_3 = \text{ReLU}(w_3 \times h_2 + b_3)h3=ReLU(w3×h2+b3)
- 第 4 层:y=w4×h3+b4y = w_4 \times h_3 + b_4y=w4×h3+b4
- 损失函数:L=12(y−x)2L = \frac{1}{2}(y - x)^2L=21(y−x)2
- 模型的反向传播:
-
输出层:
- ∂L∂y=y−x\frac{\partial L}{\partial y} = y - x∂y∂L=y−x
-
第 4 层:
- ∂L∂w4=∂L∂y×∂y∂w4=(y−x)×h3\frac{\partial L}{\partial w_4} = \frac{\partial L}{\partial y} \times \frac{\partial y}{\partial w_4} = (y - x) \times h_3∂w4∂L=∂y∂L×∂w4∂y=(y−x)×h3
- ∂L∂b4=∂L∂y×∂y∂b4=y−x\frac{\partial L}{\partial b_4} = \frac{\partial L}{\partial y} \times \frac{\partial y}{\partial b_4} = y - x∂b4∂L=∂y∂L×∂b4∂y=y−x
-
第 3 层:
- ∂L∂h3=∂L∂y×∂y∂h3=(y−x)×w4\frac{\partial L}{\partial h_3} = \frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_3} = (y - x) \times w_4∂h3∂L=∂y∂L×∂h3∂y=(y−x)×w4
- ∂L∂w3=∂L∂h3×∂h3∂w3=∂L∂h3×h2×I(h3>0)\frac{\partial L}{\partial w_3} = \frac{\partial L}{\partial h_3} \times \frac{\partial h_3}{\partial w_3} = \frac{\partial L}{\partial h_3} \times h_2 \times \mathbb{I}(h_3 > 0)∂w3∂L=∂h3∂L×∂w3∂h3=∂h3∂L×h2×I(h3>0)
- ∂L∂b3=∂L∂h3×∂h3∂b3=∂L∂h3×I(h3>0)\frac{\partial L}{\partial b_3} = \frac{\partial L}{\partial h_3} \times \frac{\partial h_3}{\partial b_3} = \frac{\partial L}{\partial h_3} \times \mathbb{I}(h_3 > 0)∂b3∂L=∂h3∂L×∂b3∂h3=∂h3∂L×I(h3>0)
-
第 2 层:
-
∂L∂h2=∂L∂h3×∂h3∂h2=∂L∂h3×w3×I(h3>0)\frac{\partial L}{\partial h_2} = \frac{\partial L}{\partial h_3} \times \frac{\partial h_3}{\partial h_2} = \frac{\partial L}{\partial h_3} \times w_3 \times \mathbb{I}(h_3 > 0)∂h2∂L=∂h3∂L×∂h2∂h3=∂h3∂L×w3×I(h3>0)
-
∂L∂w2=∂L∂h2×∂h2∂w2=∂L∂h2×h1×I(h2>0)\frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial h_2} \times \frac{\partial h_2}{\partial w_2} = \frac{\partial L}{\partial h_2} \times h_1 \times \mathbb{I}(h_2 > 0)∂w2∂L=∂h2∂L×∂w2∂h2=∂h2∂L×h1×I(h2>0)
-
∂L∂b2=∂L∂h2×∂h2∂b2=∂L∂h2×I(h2>0)\frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial h_2} \times \frac{\partial h_2}{\partial b_2} = \frac{\partial L}{\partial h_2} \times \mathbb{I}(h_2 > 0)∂b2∂L=∂h2∂L×∂b2∂h2=∂h2∂L×I(h2>0)
-
-
第 1 层:
-
∂L∂h1=∂L∂h2×∂h2∂h1=∂L∂h2×w2×I(h2>0)\frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial h_2} \times \frac{\partial h_2}{\partial h_1} = \frac{\partial L}{\partial h_2} \times w_2 \times \mathbb{I}(h_2 > 0)∂h1∂L=∂h2∂L×∂h1∂h2=∂h2∂L×w2×I(h2>0)
-
∂L∂w1=∂L∂h1×∂h1∂w1=∂L∂h1×x×I(h1>0)\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial h_1} \times \frac{\partial h_1}{\partial w_1} = \frac{\partial L}{\partial h_1} \times x \times \mathbb{I}(h_1 > 0)∂w1∂L=∂h1∂L×∂w1∂h1=∂h1∂L×x×I(h1>0)
-
∂L∂b1=∂L∂h1×∂h1∂b1=∂L∂h1×I(h1>0)\frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial h_1} \times \frac{\partial h_1}{\partial b_1} = \frac{\partial L}{\partial h_1} \times \mathbb{I}(h_1 > 0)∂b1∂L=∂h1∂L×∂b1∂h1=∂h1∂L×I(h1>0)
-
-
分析可知:
(1)第 m 层的前向传播,需要的计算资料有:
- wmw_mwm
- bmb_mbm
- hm−1h_{m-1}hm−1
其中,hm−1h_{m-1}hm−1 是前一层的输出,由前一层计算结束后保存;wmw_mwm 和 bmb_mbm 是本层的参数,如果参数是分区存储的,那么参数需要通过 All-Gather 通信获取。
(2)第 m 层的反向传播,需要的计算资料有:
- ∂L∂hm+1\frac{\partial L}{\partial h_{m+1}}∂hm+1∂L
- wm+1w_{m+1}wm+1
- hm−1h_{m-1}hm−1
其中,∂L∂hm+1\frac{\partial L}{\partial h_{m+1}}∂hm+1∂L 是后一层的梯度,由后一层计算结束后保存;hm−1h_{m-1}hm−1 是前一层的输出,由前一层计算结束后保存;wm+1w_{m+1}wm+1 是下一层的参数,需要通过 All-Gather 通信获取。
2. 前向传播和反向传播的层间流水线并行执行
基于上面的讨论,我们可以让前向传播和反向传播进行层间流水线并行执行,即通信过程和计算过程重叠。
假设有 4 块 GPU,每块 GPU 负责存储 1 个层的参数,每个层的参数数量都是 2 个;此外每块 GPU 都负责 1/4 的 mini batch size 的训练数据进行训练。
前向传播重叠过程:
- 第一步:第 1 层参数的广播通信。由于第 1 层参数存储在 GPU-1 上,因此基于 All-Gather 通信将第 1 层参数广播到所有 GPU 上。
- 第二步:第 1 层前向传播的计算,第 2 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算第 1 层的前向传播,并保存结果 h1h_1h1,然后删除上一步广播到的参数(除 GPU-1 外);②由于第 2 层参数存储在 GPU-2 上,因此基于 All-Gather 通信将第 2 层参数广播到所有 GPU 上。
- 第三步:第 2 层前向传播的计算,第 3 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算第 2 层的前向传播,并保存结果 h2h_2h2,然后删除上一步广播到的参数(除 GPU-2 外);②由于第 3 层参数存储在 GPU-3 上,因此基于 All-Gather 通信将第 3 层参数广播到所有 GPU 上。
- 第四步:第 3 层前向传播的计算,第 4 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算第 3 层的前向传播,并保存结果 h3h_3h3,然后删除上一步广播到的参数(除 GPU-3 外);②由于第 4 层参数存储在 GPU-4 上,因此基于 All-Gather 通信将第 4 层参数广播到所有 GPU 上。
- 第五步:第 4 层前向传播的计算。在每个 GPU 上计算第 4 层的前向传播,并保存结果 yyy,然后删除上一步广播到的参数(除 GPU-4 外)(删除操作可选)。
反向传播重叠过程:
- 第一步:输出层反向传播的计算,第 4 层反向传播的计算,第 4 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算输出层和第 4 层的反向传播,并保存结果 ∂L∂y\frac{\partial L}{\partial y}∂y∂L,∂L∂w4\frac{\partial L}{\partial w_4}∂w4∂L,∂L∂b4\frac{\partial L}{\partial b_4}∂b4∂L;②由于第 4 层参数存储在 GPU-4 上,因此基于 All-Gather 通信将第 4 层参数广播到所有 GPU 上(如果前向传播时没有进行删除操作,则此操作省略)。
- 第二步:第 3 层反向传播的计算,第 3 层参数的广播通信,第 4 层参数的更新。同时完成这三件事:①在每个 GPU 上计算第 3 层的反向传播,并保存结果 ∂L∂h3\frac{\partial L}{\partial h_3}∂h3∂L,∂L∂w3\frac{\partial L}{\partial w_3}∂w3∂L,∂L∂b3\frac{\partial L}{\partial b_3}∂b3∂L,然后删除上一步广播到的参数(除 GPU-4 外);②由于第 3 层参数存储在 GPU-3 上,因此基于 All-Gather 通信将第 3 层参数广播到所有 GPU 上;③基于 Reduce-Scatter 操作将第 4 层梯度收集到 GPU-4 上,然后计算第 4 层优化器状态,更新第 4 层参数。
- 第三步:第 2 层反向传播的计算,第 2 层参数的广播通信,第 3 层参数的更新。同时完成这三件事:①在每个 GPU 上计算第 2 层的反向传播,并保存结果 ∂L∂h2\frac{\partial L}{\partial h_2}∂h2∂L,∂L∂w2\frac{\partial L}{\partial w_2}∂w2∂L,∂L∂b2\frac{\partial L}{\partial b_2}∂b2∂L,然后删除上一步广播到的参数(除 GPU-3 外);②由于第 2 层参数存储在 GPU-2 上,因此基于 All-Gather 通信将第 2 层参数广播到所有 GPU 上;③基于 Reduce-Scatter 操作将第 3 层梯度收集到 GPU-3 上,然后计算第 3 层优化器状态,更新第 3 层参数。
- 第四步:第 1 层反向传播的计算,第 2 层参数的更新:在每个 GPU 上计算第 1 层的反向传播,并保存结果 ∂L∂h1\frac{\partial L}{\partial h_1}∂h1∂L,∂L∂w1\frac{\partial L}{\partial w_1}∂w1∂L,∂L∂b1\frac{\partial L}{\partial b_1}∂b1∂L,然后删除上一步广播到的参数(除 GPU-2 外);基于 Reduce-Scatter 操作将第 2 层梯度收集到 GPU-2 上,然后计算第 2 层优化器状态,更新第 2 层参数。
- 第五步:第 1 层参数的更新:在每个 GPU 上计算第 1 层优化器状态,更新第 1 层参数。
(五)ZeRO-R 优化
ZeRO-R(ZeRO-Redundancy)主要针对非参数存储进行优化,包括激活值、临时缓冲区和内存碎片等问题。
1. 激活值存储优化
在前向传播中,需要保存中间激活值用于反向传播,对于大模型和大 batch size,激活值显存占用可能过大。
传统方法:每个GPU保存完整激活值
- GPU-0: [a1, a2, a3, a4]
- GPU-1: [a1, a2, a3, a4]
- GPU-2: [a1, a2, a3, a4]
- GPU-3: [a1, a2, a3, a4]
ZeRO-R:激活值分区存储
- GPU-0: [a1]
- GPU-1: [a2]
- GPU-2: [a3]
- GPU-3: [a4]
2. 临时缓冲区优化
All-Reduce、All-Gather等操作需要临时缓冲区,这些缓冲区可能很大,且需要频繁分配释放。
传统方法:每次操作动态分配缓冲区
ZeRO-R:
- 固定大小的持久化缓冲区
- 每次操作复用这块缓冲区,避免频繁分配释放
3. 内存碎片优化
由于张量的频繁分配和释放,GPU显存会出现碎片化,导致即使总空闲内存足够,也无法分配大张量。
ZeRO-R 解决方案:
- 首先尝试在现有空闲块中分配大张量
- 如果不存在足够大的空闲块,则尝试合并相邻的小碎片,形成大碎片
- 如果合并后仍无法满足需求,则触发垃圾回收,释放所有未使用的内存