先看GPU结构,我们常说显存的时候,说的一般就是Global memory

训练的过程中,我们为了反向传播过程,必须将中间的结果(激活值)存储下来。

在训练的过程中,那些会消耗内存呢?
- model weights
- optimizer sates
- intermediate activation values
对于有N层的神经网络来说,内存的消耗是O(N)的。
检查点技术
在前向传播的时候,只选择保留部分数值,当进行反向传播时,所需要的中间值会进行重计算。


这样虽然会增减计算成本,但是也大大减少了内存占用。
模型并行
将模型进行拆分

数据并行
将minibatch 划分成更小的micobatch,训练每个batch的时候,每个工作节点获得一个microbatch,
梯度更新
各个节点之间计算出来的梯度要统一,可以使用 all-reduce或者 使用一个参数服务器用来统一更新各个节点之间的梯度。

为了加快训练,可以使得参数传递和计算过程互相掩盖

READING LIST:
- ZeRO
- Beyond Data and Model Parallelism for Deep Neural Networks
- GSPMD: General and Scalable Parallelization for ML Computation Graphs