深度学习专题:模型训练的数据并行(三)

深度学习专题:模型训练的数据并行(三)

ZeRO 系列优化:ZeRO-1、ZeRO-2、ZeRO-3、ZeRO-R

(一)Adam 优化器

传统的 SGD 优化器第 t+1 轮参数更新公式为:
wt+1=wt−lr⋅gt+1 w_{t+1} = w_t - lr \cdot g_{t+1} wt+1=wt−lr⋅gt+1

其中 wtw_twt 为第 t 轮参数,lrlrlr 为学习率(不变的定值),gtg_tgt 为第 t 轮参数的梯度。

这种优化策略容易遇到驻点问题:当模型参数的梯度接近于零(即 gt+1g_{t+1}gt+1 接近于零)时,参数更新会非常缓慢,甚至停滞不前。

Adam 优化器为了解决驻点问题,在 SGD 优化器的基础上,引入了动量和自适应学习率的概念。其第 t+1 轮参数更新公式为:
wt+1=wt−lr⋅mt+1vt+1+ϵ w_{t+1} = w_t - lr \cdot \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} wt+1=wt−lr⋅vt+1 +ϵmt+1

其中 mtm_tmt 为第 t 轮的动量,vtv_tvt 为第 t 轮的方差值,ϵ\epsilonϵ 为一个小的常量(防止除零错误)。特别地,动量与方差的计算公式为:
mt+1=β1mt+(1−β1)gt+11−β1t+1vt+1=β2vt+(1−β2)gt+121−β2t+1 m_{t+1} = \frac{\beta_1 m_t + (1 - \beta_1) g_{t+1}}{1 - \beta_1^{t+1}} \\ v_{t+1} = \frac{\beta_2 v_t + (1 - \beta_2) g_{t+1}^2}{1 - \beta_2^{t+1}} mt+1=1−β1t+1β1mt+(1−β1)gt+1vt+1=1−β2t+1β2vt+(1−β2)gt+12

其中 β1\beta_1β1 和 β2\beta_2β2 为两个超参数(常设为 0.9 和 0.999)。初值为 m0=0m_0 = 0m0=0 和 v0=0v_0 = 0v0=0。

(二)Adam 优化器数据并行存储代价

按照 Adam 优化器的公式,模型在训练过程中不仅要存储当前轮的参数 wtw_twt 和梯度 gtg_tgt,还要存储当前轮的动量 mtm_tmt 和方差 vtv_tvt。假设有 N 个参数,在数据并行场景下,与 SGD 优化器比较分析第 t+1 轮 epoch 的反向传播过程:

1. SGD 优化器

SGD 优化器在每个 GPU 中需要额外存储 2N2N2N 个浮点数。

存储项 1 2 ... N
参数向量 www t t ... t
梯度向量 ggg - - ... -

第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量

存储项 1 2 ... N
参数向量 www t t ... t
梯度向量 ggg t+1, k t+1, k ... t+1, k

第二步:基于 Ring All-Reduce 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量,计算梯度向量均值 gt+1g_{t+1}gt+1 并存储;

存储项 1 2 ... N
参数向量 www t t ... t
梯度向量 ggg t+1 t+1 ... t+1

第三步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1,更新参数向量 wt+1w_{t+1}wt+1。

存储项 1 2 ... N
参数向量 www t+1 t+1 ... t+1
梯度向量 ggg - - ... -
2. Adam 优化器

Adam 优化器在每个 GPU 中需要额外存储 4N4N4N 个浮点数。

存储项 1 2 ... N
参数向量 www t t ... t
动量向量 mmm t t ... t
方差向量 vvv t t ... t
梯度向量 ggg - - ... -

第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量;

存储项 1 2 ... N
参数向量 www t t ... t
动量向量 mmm t t ... t
方差向量 vvv t t ... t
梯度向量 ggg t+1, k t+1, k ... t+1, k

第二步:基于 Ring All-Reduce 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量,计算梯度向量均值 gt+1g_{t+1}gt+1 并存储;

存储项 1 2 ... N
参数向量 www t t ... t
动量向量 mmm t t ... t
方差向量 vvv t t ... t
梯度向量 ggg t+1 t+1 ... t+1

第三步:读取第 t 轮 epoch 的动量 mtm_tmt 和方差 vtv_tvt 并根据聚合后的第 t+1 轮 epoch 的梯度向量 gt+1g_{t+1}gt+1 计算第 t+1 轮 epoch 的动量 mt+1m_{t+1}mt+1 和方差 vt+1v_{t+1}vt+1 并存储;

存储项 1 2 ... N
参数向量 www t t ... t
动量向量 mmm t+1 t+1 ... t+1
方差向量 vvv t+1 t+1 ... t+1
梯度向量 ggg t+1 t+1 ... t+1

第四步:根据动量 mt+1m_{t+1}mt+1 和方差 vt+1v_{t+1}vt+1,更新参数向量 wt+1w_{t+1}wt+1。

存储项 1 2 ... N
参数向量 www t+1 t+1 ... t+1
动量向量 mmm t+1 t+1 ... t+1
方差向量 vvv t+1 t+1 ... t+1
梯度向量 ggg - - ... -
3. 分析与总结

GPU 需要存储第 t 轮 epoch 的参数向量 wtw_twt,因为在第 t+1 轮训练过程中需要根据 wtw_twt 进行前向传播、损失计算和反向传播,从而计算出第 t+1 轮 epoch 的梯度向量 gt+1g_{t+1}gt+1;

GPU 还需要存储第 t 轮 epoch 的动量向量 mtm_tmt 和方差向量 vtv_tvt,因为在第 t+1 轮训练过程中需要根据 mtm_tmt 和 vtv_tvt 计算出第 t+1 轮 epoch 的动量向量 mt+1m_{t+1}mt+1 和方差向量 vt+1v_{t+1}vt+1;

上面的两个存储项均涉及到跨 epoch 存储,所以对应的空间在整个训练的生命周期都不能被释放。

GPU 还需要临时存储第 t+1 轮 epoch 的梯度向量 gt+1g_{t+1}gt+1,因为 GPU 需要使用 Ring All-Reduce 算法在所有 GPU 之间归约梯度向量并计算均值。如果是 SGD 优化器,这个梯度向量会直接用于更新参数;如果是 Adam 优化器,这个梯度向量会用于更新动量和方差。不同的是,完成对应的更新操作后,梯度向量就不需要继续存储了,对应的空间就可以被释放了。所以梯度向量只需要临时存储,而不需要跨 epoch 存储。

(三)ZeRO-1 优化原理

为了减少 Adam 优化器的存储代价,ZeRO-1 提出了一种优化策略:​​将优化器状态在多个 GPU 之间进行分区存储​​,而不是在每个 GPU 上完整保存所有优化器状态。

在数据并行训练中,每个 GPU 通常保存完整的模型参数、梯度和优化器状态。ZeRO-1 观察到,优化器状态(如 Adam 中的动量和方差)在训练过程中主要参与参数更新计算,但不同参数的优化器状态之间是相互独立的。因此,可以将优化器状态按照参数维度进行分区,每个 GPU 只需要存储自己负责的那部分动量向量 mt+1m_{t+1}mt+1 和方差向量 vt+1v_{t+1}vt+1 并更新对应的参数,然后通过 GPU 的 All-Gather 算法同步所有更新后的参数。

1. 步骤分析

准备步骤:假设模型有 N 个参数,有 P 个 GPU 负责并行训练

  • 每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数
  • 第 k 个 GPU 将负责第 sk=1+(k−1)Ms_k = 1+(k-1)Msk=1+(k−1)M 到 ek=kMe_k=kMek=kM 个参数
存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg - - ... - - - ... - - ... -

第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg t+1, k t+1, k ... t+1, k t+1, k t+1, k ... t+1, k t+1, k ... t+1, k

第二步:基于 Ring All-Reduce 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量,计算梯度向量均值 gt+1g_{t+1}gt+1 并存储;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg t+1 t+1 ... t+1 t+1 t+1 ... t+1 t+1 ... t+1

第三步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1,更新自己负责的那一部分参数的动量和方差;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg t+1 t+1 ... t+1 t+1 t+1 ... t+1 t+1 ... t+1

第四步:每个 GPU 根据自己负责的那部分更新后的动量和方差,更新自己负责的那部分参数向量,然后释放梯度向量 gt+1g_{t+1}gt+1 占用的存储空间;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t+1 t+1 ... t+1 t ... t
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg - - ... - - - ... - - ... -

第五步:每个 GPU 与其他所有 GPU 通过 All-Gather 算法通信,得到更新后的参数向量 wt+1w_{t+1}wt+1。

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t+1 t+1 ... t+1 t+1 t+1 ... t+1 t+1 ... t+1
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg - - ... - - - ... - - ... -
2. 存储优化分析
  • 传统数据并行:每个 GPU 需要存储 4N4N4N 个浮点数;
  • ZeRO-1 优化:每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数,所以每个 GPU 需要存储 2M+2N2M + 2N2M+2N 个浮点数。当 M≪NM \ll NM≪N 时,存储量趋近于 2N2N2N,存储代价可以显著减少。

(四)ZeRO-2 优化原理

ZeRO-2 在 ZeRO-1 的基础上进一步优化,提出了​​梯度分区存储​​的策略,其核心通信模式是 ​​Reduce-Scatter​​ 而不是 All-Reduce。

在传统数据并行中,梯度同步使用 All-Reduce,每个 GPU 最终都获得完整的平均梯度。ZeRO-2 改为使用 Reduce-Scatter,每个 GPU 只获得部分梯度的聚合结果,与优化器状态的分区相对应。

1. 步骤分析

准备步骤:假设模型有 N 个参数,有 P 个 GPU 负责并行训练

  • 每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数
  • 第 k 个 GPU 将负责第 sk=1+(k−1)Ms_k = 1+(k-1)Msk=1+(k−1)M 到 ek=kMe_k=kMek=kM 个参数
存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg - - ... - - - ... - - ... -

第一步:第 k 个 GPU 读取分配给自己的训练数据,读取第 t 轮 epoch 的参数向量 wtw_twt 并进行前向传播、损失计算和反向传播,得到第 t+1 轮 epoch 的梯度向量;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg t+1, k t+1, k ... t+1, k t+1, k t+1, k ... t+1, k t+1, k ... t+1, k

第二步:基于 Reduce-Scatter 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量的第 sks_ksk 到 eke_kek 个元素,计算梯度向量均值 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 个元素并存储,然后释放梯度向量均值 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 之外的元素占用的存储空间;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg - - ... - t+1 t+1 ... t+1 - ... -

第三步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 个元素,更新自己负责的那一部分参数的动量和方差;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg - - ... - t+1 t+1 ... t+1 - ... -

第四步:每个 GPU 根据自己负责的那部分更新后的动量和方差,更新自己负责的那部分参数向量,然后释放梯度向量 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 的元素占用的存储空间;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t+1 t+1 ... t+1 t ... t
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg - - ... - - - ... - - ... -

第五步:每个 GPU 与其他所有 GPU 通过 All-Gather 算法通信,得到更新后的参数向量 wt+1w_{t+1}wt+1。

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t+1 t+1 ... t+1 t+1 t+1 ... t+1 t+1 ... t+1
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg - - ... - - - ... - - ... -
2. 存储优化分析
  • ZeRO-1 优化:每个 GPU 需要存储 2M+2N2M + 2N2M+2N 个浮点数,当 M≪NM \ll NM≪N 时,存储量趋近于 2N2N2N,存储代价可以显著减少;
  • ZeRO-2 优化:每个 GPU 需要存储 3M+N3M + N3M+N 个浮点数,当 M≪NM \ll NM≪N 时,存储量趋近于 NNN,存储代价可以显著减少。

(五)ZeRO-3 优化原理

ZeRO-3 在 ZeRO-2 的基础上进一步优化,提出了​​参数分区存储​​的策略,每个 GPU 只存储自己负责的那部分参数,这需要在训练的前向传播和反向传播之前通过 All-Gather 通信收集全部参数,从而达成一种通信换存储的优化效果。

1. 步骤分析

准备步骤:假设模型有 N 个参数,有 P 个 GPU 负责并行训练

  • 每个 GPU 负责管理 M=⌈N/P⌉M = \lceil N/P \rceilM=⌈N/P⌉ 个参数
  • 第 k 个 GPU 将负责第 sk=1+(k−1)Ms_k = 1+(k-1)Msk=1+(k−1)M 到 ek=kMe_k=kMek=kM 个参数,且只存储这些参数。
存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www - - ... - t t ... t - ... -
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg - - ... - - - ... - - ... -

第一步:第 k 个 GPU 通过 All-Gather 通信获得第 t 轮 epoch 的参数向量 wtw_twt,然后读取分配给自己的训练数据,完成前向传播,计算得到损失值;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www t t ... t t t ... t t ... t
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg - - ... - - - ... - - ... -

第二步:进行反向传播,计算得到第 t+1 轮 epoch 的梯度向量,然后释放第 t 轮参数向量的 sks_ksk 到 eke_kek 之外的元素占用的存储空间;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www - - ... - t t ... t - ... -
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg t+1, k t+1, k ... t+1, k t+1, k t+1, k ... t+1, k t+1, k ... t+1, k

第三步:基于 Reduce-Scatter 操作获取其他所有 GPU 上的第 t+1 轮 epoch 的梯度向量的第 sks_ksk 到 eke_kek 个元素,计算梯度向量均值 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 个元素并存储,然后释放梯度向量均值 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 之外的元素占用的存储空间;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www - - ... - t t ... t - ... -
动量向量 mmm - - ... - t t ... t - ... -
方差向量 vvv - - ... - t t ... t - ... -
梯度向量 ggg - - ... - t+1 t+1 ... t+1 - ... -

第四步:根据聚合后的梯度向量 gt+1g_{t+1}gt+1 的 sks_ksk 到 eke_kek 个元素,更新自己负责的那一部分参数的动量和方差;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www - - ... - t t ... t - ... -
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg - - ... - t+1 t+1 ... t+1 - ... -

第五步:每个 GPU 根据自己负责的那部分更新后的动量和方差,更新自己负责的那部分参数向量,然后释放梯度向量 gt+1g_{t+1}gt+1 的第 sks_ksk 到 eke_kek 的元素占用的存储空间;

存储项 1 2 ... sk−1s_k-1sk−1 sks_ksk sk+1s_k+1sk+1 ... eke_kek ek+1e_k+1ek+1 ... N
参数向量 www - - ... - t+1 t+1 ... t+1 - ... -
动量向量 mmm - - ... - t+1 t+1 ... t+1 - ... -
方差向量 vvv - - ... - t+1 t+1 ... t+1 - ... -
梯度向量 ggg - - ... - - - ... - - ... -
2. 存储优化分析
  • ZeRO-3 优化:每个 GPU 需要存储 4M4M4M 个浮点数,当 M≪NM \ll NM≪N 时,存储量趋近于 000,存储代价可以显著减少。

(四)ZeRO-3 层间流水线优化

由于层间流水线优化可以将多个层的前向传播和反向传播并行执行,而每个层只需要存储自己负责的那部分参数,因此 ZeRO-3 与层间流水线优化结合可以进一步减少存储代价。

1. 前向传播和反向传播的计算资料

考虑下面的模型:

  • 输入输出都是标量,分别记作 xxx 和 yyy;
  • 模型有 4 个层,其中前三个层是全连接层加上 ReLU 激活函数,最后一个层是全连接层;
  • 每个层的参数数量都是 2 个,分别是 权重 www 和 偏置 bbb;
  • 模型的前向传播:
    • 第 1 层:h1=ReLU(w1×x+b1)h_1 = \text{ReLU}(w_1 \times x + b_1)h1=ReLU(w1×x+b1)
    • 第 2 层:h2=ReLU(w2×h1+b2)h_2 = \text{ReLU}(w_2 \times h_1 + b_2)h2=ReLU(w2×h1+b2)
    • 第 3 层:h3=ReLU(w3×h2+b3)h_3 = \text{ReLU}(w_3 \times h_2 + b_3)h3=ReLU(w3×h2+b3)
    • 第 4 层:y=w4×h3+b4y = w_4 \times h_3 + b_4y=w4×h3+b4
    • 损失函数:L=12(y−x)2L = \frac{1}{2}(y - x)^2L=21(y−x)2
  • 模型的反向传播:
    • 输出层:

      • ∂L∂y=y−x\frac{\partial L}{\partial y} = y - x∂y∂L=y−x
    • 第 4 层:

      • ∂L∂w4=∂L∂y×∂y∂w4=(y−x)×h3\frac{\partial L}{\partial w_4} = \frac{\partial L}{\partial y} \times \frac{\partial y}{\partial w_4} = (y - x) \times h_3∂w4∂L=∂y∂L×∂w4∂y=(y−x)×h3
      • ∂L∂b4=∂L∂y×∂y∂b4=y−x\frac{\partial L}{\partial b_4} = \frac{\partial L}{\partial y} \times \frac{\partial y}{\partial b_4} = y - x∂b4∂L=∂y∂L×∂b4∂y=y−x
    • 第 3 层:

      • ∂L∂h3=∂L∂y×∂y∂h3=(y−x)×w4\frac{\partial L}{\partial h_3} = \frac{\partial L}{\partial y} \times \frac{\partial y}{\partial h_3} = (y - x) \times w_4∂h3∂L=∂y∂L×∂h3∂y=(y−x)×w4
      • ∂L∂w3=∂L∂h3×∂h3∂w3=∂L∂h3×h2×I(h3>0)\frac{\partial L}{\partial w_3} = \frac{\partial L}{\partial h_3} \times \frac{\partial h_3}{\partial w_3} = \frac{\partial L}{\partial h_3} \times h_2 \times \mathbb{I}(h_3 > 0)∂w3∂L=∂h3∂L×∂w3∂h3=∂h3∂L×h2×I(h3>0)
      • ∂L∂b3=∂L∂h3×∂h3∂b3=∂L∂h3×I(h3>0)\frac{\partial L}{\partial b_3} = \frac{\partial L}{\partial h_3} \times \frac{\partial h_3}{\partial b_3} = \frac{\partial L}{\partial h_3} \times \mathbb{I}(h_3 > 0)∂b3∂L=∂h3∂L×∂b3∂h3=∂h3∂L×I(h3>0)
    • 第 2 层:

      • ∂L∂h2=∂L∂h3×∂h3∂h2=∂L∂h3×w3×I(h3>0)\frac{\partial L}{\partial h_2} = \frac{\partial L}{\partial h_3} \times \frac{\partial h_3}{\partial h_2} = \frac{\partial L}{\partial h_3} \times w_3 \times \mathbb{I}(h_3 > 0)∂h2∂L=∂h3∂L×∂h2∂h3=∂h3∂L×w3×I(h3>0)

      • ∂L∂w2=∂L∂h2×∂h2∂w2=∂L∂h2×h1×I(h2>0)\frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial h_2} \times \frac{\partial h_2}{\partial w_2} = \frac{\partial L}{\partial h_2} \times h_1 \times \mathbb{I}(h_2 > 0)∂w2∂L=∂h2∂L×∂w2∂h2=∂h2∂L×h1×I(h2>0)

      • ∂L∂b2=∂L∂h2×∂h2∂b2=∂L∂h2×I(h2>0)\frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial h_2} \times \frac{\partial h_2}{\partial b_2} = \frac{\partial L}{\partial h_2} \times \mathbb{I}(h_2 > 0)∂b2∂L=∂h2∂L×∂b2∂h2=∂h2∂L×I(h2>0)

    • 第 1 层:

      • ∂L∂h1=∂L∂h2×∂h2∂h1=∂L∂h2×w2×I(h2>0)\frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial h_2} \times \frac{\partial h_2}{\partial h_1} = \frac{\partial L}{\partial h_2} \times w_2 \times \mathbb{I}(h_2 > 0)∂h1∂L=∂h2∂L×∂h1∂h2=∂h2∂L×w2×I(h2>0)

      • ∂L∂w1=∂L∂h1×∂h1∂w1=∂L∂h1×x×I(h1>0)\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial h_1} \times \frac{\partial h_1}{\partial w_1} = \frac{\partial L}{\partial h_1} \times x \times \mathbb{I}(h_1 > 0)∂w1∂L=∂h1∂L×∂w1∂h1=∂h1∂L×x×I(h1>0)

      • ∂L∂b1=∂L∂h1×∂h1∂b1=∂L∂h1×I(h1>0)\frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial h_1} \times \frac{\partial h_1}{\partial b_1} = \frac{\partial L}{\partial h_1} \times \mathbb{I}(h_1 > 0)∂b1∂L=∂h1∂L×∂b1∂h1=∂h1∂L×I(h1>0)

分析可知:

(1)第 m 层的前向传播,需要的计算资料有:

  • wmw_mwm
  • bmb_mbm
  • hm−1h_{m-1}hm−1

其中,hm−1h_{m-1}hm−1 是前一层的输出,由前一层计算结束后保存;wmw_mwm 和 bmb_mbm 是本层的参数,如果参数是分区存储的,那么参数需要通过 All-Gather 通信获取。

(2)第 m 层的反向传播,需要的计算资料有:

  • ∂L∂hm+1\frac{\partial L}{\partial h_{m+1}}∂hm+1∂L
  • wm+1w_{m+1}wm+1
  • hm−1h_{m-1}hm−1

其中,∂L∂hm+1\frac{\partial L}{\partial h_{m+1}}∂hm+1∂L 是后一层的梯度,由后一层计算结束后保存;hm−1h_{m-1}hm−1 是前一层的输出,由前一层计算结束后保存;wm+1w_{m+1}wm+1 是下一层的参数,需要通过 All-Gather 通信获取。

2. 前向传播和反向传播的层间流水线并行执行

基于上面的讨论,我们可以让前向传播和反向传播进行层间流水线并行执行,即通信过程和计算过程重叠。

假设有 4 块 GPU,每块 GPU 负责存储 1 个层的参数,每个层的参数数量都是 2 个;此外每块 GPU 都负责 1/4 的 mini batch size 的训练数据进行训练。

前向传播重叠过程:

  • 第一步:第 1 层参数的广播通信。由于第 1 层参数存储在 GPU-1 上,因此基于 All-Gather 通信将第 1 层参数广播到所有 GPU 上。
  • 第二步:第 1 层前向传播的计算,第 2 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算第 1 层的前向传播,并保存结果 h1h_1h1,然后删除上一步广播到的参数(除 GPU-1 外);②由于第 2 层参数存储在 GPU-2 上,因此基于 All-Gather 通信将第 2 层参数广播到所有 GPU 上。
  • 第三步:第 2 层前向传播的计算,第 3 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算第 2 层的前向传播,并保存结果 h2h_2h2,然后删除上一步广播到的参数(除 GPU-2 外);②由于第 3 层参数存储在 GPU-3 上,因此基于 All-Gather 通信将第 3 层参数广播到所有 GPU 上。
  • 第四步:第 3 层前向传播的计算,第 4 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算第 3 层的前向传播,并保存结果 h3h_3h3,然后删除上一步广播到的参数(除 GPU-3 外);②由于第 4 层参数存储在 GPU-4 上,因此基于 All-Gather 通信将第 4 层参数广播到所有 GPU 上。
  • 第五步:第 4 层前向传播的计算。在每个 GPU 上计算第 4 层的前向传播,并保存结果 yyy,然后删除上一步广播到的参数(除 GPU-4 外)(删除操作可选)。

反向传播重叠过程:

  • 第一步:输出层反向传播的计算,第 4 层反向传播的计算,第 4 层参数的广播通信。同时完成这两件事:①在每个 GPU 上计算输出层和第 4 层的反向传播,并保存结果 ∂L∂y\frac{\partial L}{\partial y}∂y∂L,∂L∂w4\frac{\partial L}{\partial w_4}∂w4∂L,∂L∂b4\frac{\partial L}{\partial b_4}∂b4∂L;②由于第 4 层参数存储在 GPU-4 上,因此基于 All-Gather 通信将第 4 层参数广播到所有 GPU 上(如果前向传播时没有进行删除操作,则此操作省略)。
  • 第二步:第 3 层反向传播的计算,第 3 层参数的广播通信,第 4 层参数的更新。同时完成这三件事:①在每个 GPU 上计算第 3 层的反向传播,并保存结果 ∂L∂h3\frac{\partial L}{\partial h_3}∂h3∂L,∂L∂w3\frac{\partial L}{\partial w_3}∂w3∂L,∂L∂b3\frac{\partial L}{\partial b_3}∂b3∂L,然后删除上一步广播到的参数(除 GPU-4 外);②由于第 3 层参数存储在 GPU-3 上,因此基于 All-Gather 通信将第 3 层参数广播到所有 GPU 上;③基于 Reduce-Scatter 操作将第 4 层梯度收集到 GPU-4 上,然后计算第 4 层优化器状态,更新第 4 层参数。
  • 第三步:第 2 层反向传播的计算,第 2 层参数的广播通信,第 3 层参数的更新。同时完成这三件事:①在每个 GPU 上计算第 2 层的反向传播,并保存结果 ∂L∂h2\frac{\partial L}{\partial h_2}∂h2∂L,∂L∂w2\frac{\partial L}{\partial w_2}∂w2∂L,∂L∂b2\frac{\partial L}{\partial b_2}∂b2∂L,然后删除上一步广播到的参数(除 GPU-3 外);②由于第 2 层参数存储在 GPU-2 上,因此基于 All-Gather 通信将第 2 层参数广播到所有 GPU 上;③基于 Reduce-Scatter 操作将第 3 层梯度收集到 GPU-3 上,然后计算第 3 层优化器状态,更新第 3 层参数。
  • 第四步:第 1 层反向传播的计算,第 2 层参数的更新:在每个 GPU 上计算第 1 层的反向传播,并保存结果 ∂L∂h1\frac{\partial L}{\partial h_1}∂h1∂L,∂L∂w1\frac{\partial L}{\partial w_1}∂w1∂L,∂L∂b1\frac{\partial L}{\partial b_1}∂b1∂L,然后删除上一步广播到的参数(除 GPU-2 外);基于 Reduce-Scatter 操作将第 2 层梯度收集到 GPU-2 上,然后计算第 2 层优化器状态,更新第 2 层参数。
  • 第五步:第 1 层参数的更新:在每个 GPU 上计算第 1 层优化器状态,更新第 1 层参数。

(五)ZeRO-R 优化

ZeRO-R(ZeRO-Redundancy)主要针对非参数存储进行优化,包括激活值、临时缓冲区和内存碎片等问题。

1. 激活值存储优化

在前向传播中,需要保存中间激活值用于反向传播,对于大模型和大 batch size,激活值显存占用可能过大。

传统方法:每个GPU保存完整激活值

  • GPU-0: [a1, a2, a3, a4]
  • GPU-1: [a1, a2, a3, a4]
  • GPU-2: [a1, a2, a3, a4]
  • GPU-3: [a1, a2, a3, a4]

ZeRO-R:激活值分区存储

  • GPU-0: [a1]
  • GPU-1: [a2]
  • GPU-2: [a3]
  • GPU-3: [a4]
2. 临时缓冲区优化

All-Reduce、All-Gather等操作需要临时缓冲区,这些缓冲区可能很大,且需要频繁分配释放。

传统方法:每次操作动态分配缓冲区

ZeRO-R:

  • 固定大小的持久化缓冲区
  • 每次操作复用这块缓冲区,避免频繁分配释放
3. 内存碎片优化

由于张量的频繁分配和释放,GPU显存会出现碎片化,导致即使总空闲内存足够,也无法分配大张量。

ZeRO-R 解决方案:

  • 首先尝试在现有空闲块中分配大张量
  • 如果不存在足够大的空闲块,则尝试合并相邻的小碎片,形成大碎片
  • 如果合并后仍无法满足需求,则触发垃圾回收,释放所有未使用的内存
相关推荐
Gloria_niki2 小时前
图像分割深度学习学习总结
人工智能
武子康3 小时前
AI研究-118 具身智能 Mobile-ALOHA 解读:移动+双臂模仿学习的开源方案(含论文/代码/套件链接)
人工智能·深度学习·学习·机器学习·ai·开源·模仿学习
长桥夜波3 小时前
机器学习日报12
人工智能·机器学习
AI柠檬3 小时前
机器学习:数据集的划分
人工智能·算法·机器学习
tt5555555555553 小时前
《神经网络与深度学习》学习笔记一
深度学习·神经网络·学习
诸葛务农4 小时前
光刻胶分类与特性——g/i线光刻胶及东京应化TP-3000系列胶典型配方(上)
人工智能·材料工程
mm-q29152227294 小时前
YOLOv5(PyTorch)目标检测实战:TensorRT加速部署!训练自己的数据集(Ubuntu)——(人工智能、深度学习、机器学习、神经网络)
人工智能·深度学习·机器学习
搞科研的小刘选手4 小时前
【多所高校合作】第四届图像处理、计算机视觉与机器学习国际学术会议(ICICML 2025)
图像处理·人工智能·机器学习·计算机视觉·数据挖掘·人脸识别·人机交互
FreeCode4 小时前
LangChain1.0智能体开发:消息组件(Messages)
人工智能·langchain·agent