Megatron-LM源码系列(六):Distributed-Optimizer分布式优化器实现Part1

1. 使用说明

在megatron中指定--use-distributed-optimizer就能开启分布式优化器, 参数定义在megatron/arguments.py中。分布式优化器的思路是将训练中的优化器状态均匀地分布到不同数据并行的rank结点上,相当于开启ZERO-1的训练。

python 复制代码
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')

在使用--use-distributed-optimizer, 同时会check两个参数 args.DDP_impl == 'local'(默认开启)和args.use_contiguous_buffers_in_local_ddp(默认开启)。

python 复制代码
    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

分布式优化器节省的理论显存值依赖参数类型和梯度类型,以下是每一个parameter对应占用的理论字节数(d表示数据并行的size大小,也就是一个数据并行中的卡数, 等于 T P × P P TP \times PP TP×PP ):

训练数据类型 Non-distributed optim(单位Byte) Distributed optim(单位Byte)
float16 param, float16 grads 20 4 + 16/d
float16 param, fp32 grads 18 6 + 12/d
fp32 param, fp32 grads 16 8 + 8/d

2. 实现介绍

  • Distributed-Optimizer分布式优化器的主要实现是通过连续的grad buffer来进行的,grad buffer中用于模型状态和优化器状态之间进行parameter参数和grad梯度的通信。grad buffer中使用reduce-scatter和all-gather进行通信。

  • 数据流如下:

    1. 在每个dp的rank上计算完grad后,组成待更新的grad buffer数组
    2. 更新的时候通过reduce-scatter将grad buffer切分到各个rank上
    3. 在每个rank上完成优化器的step操作
    4. 最后将所有结果执行allgather操作得到更新后的grad buffer。
  • 以fp16类型grad为例,grad buffer分片说明如下:

    • 一共有4个参数,分别用绿/黄/蓝/红表示;总参数大小为16个fp16类型数据
    • 按DP中rank的个数对总数据均匀切分
    • 如果参数过大,每个rank可能会只包含部分参数的数据,所以要考虑参数的偏移
    • 每个DP rank中的每个param参数都对应有3个偏移,一个是world_index表示总的数据偏移,一个是local_index表示在当前rank中的数据偏移,一个是param_index相对于param来说,表示当前rank结点存的数据的偏移。
    • 以黄色参数Param1为例,在rank0存了Param1的一个元素,rank1存了Param1的4个元素;world_index来说rank0上黄色部分的元素是总数据的[3,4], rank1上黄色部分的4个元素是总数据的[4,8]; local_index来说在rank0上表示[3,4],rank1表示当前结点全部的4个元素,范围也就是[0,4];param_index来说,对于rank0上的Param1的param_index就是[0,1],在rank2上的param_index就是[1,5];
  • 关键步骤详解:

    1. 上图中每个方块看成是一个grad buffer中的一个fp16类型元素,在反向结束以后,grad buffer中有16个fp16类型的元素
    2. 在每一个DP rank上调用reduce-scatter操作
    3. 每个DP rank的grad buffer中都有4个fp16类型元素经过了reduce-scatter操作更新,没更新的12个fp16类型元素等待后续垃圾回收
    4. 每个DP rank从grad buffer中拷贝更新后的4个fp16类型元素到fp32类型的main grad buffer中,准备开始后续的更新操作,例如
      • DP rank0拷贝[0:4]个元素
      • DP rank1拷贝[4:8]个元素
      • DP rank2拷贝[8:12]个元素
      • DP rank3拷贝[12:16]个元素
    5. 执行Optimizer.step(), step()操作必须通过fp32类型来进行计算
    6. 每个DP rank从main grad buffer中拷贝step()更新后的4个fp32类型元素到fp16类型的grad buffer中
    7. 执行allgather操作, 这样每个grad buffer就都是最新更新后的数据了
    8. 基于grad buffer来更新各个模型的fp16类型的参数
    9. 开始进行下一轮的更新

3. 源码实现

3.1 程序入口

  • 初始化的入口在文件megatron/training.pyget_model函数中,在创建LocalDDP的实例中会传入args.use_contiguous_buffers_in_local_ddp
python 复制代码
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    ...
    if wrap_with_ddp:
        if args.DDP_impl == 'torch':
            ...
        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
    ...
  • 训练的入口定义在train_step函数中, 基本流程如下:
python 复制代码
def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    ...
    
    # 清除grad
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()
    ...
    
    # 执行前反向计算
    losses_reduced = forward_backward_func(...)
    ...
    
    # 对梯度执行Reduce-Scatter操作
    optimizer.reduce_model_grads(args, timers)
    ...
    
    # 更新梯度
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()
    ...
    
    # 对更新后的param执行gather操作
    if update_successful:
        optimizer.gather_model_params(args, timers)
    ...
    
    # 通过scheduler更新学习率
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1
    ...

3.2 grad buffer初始化(DistributedDataParallel类)

  • grad buffer初始化是在类DistributedDataParallel的init函数中, 源码定义在megatron/optimizer/distrib_optimizer.py文件中。
python 复制代码
class DistributedDataParallel(DistributedDataParallelBase):
    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):
  • 创建grad buffer和index map
python 复制代码
            self._grad_buffers = {}
            self._grad_buffer_param_index_map = {}
            data_parallel_world_size = mpu.get_data_parallel_world_size()
  • 按类型分别计算每个类型元素的个数,使用type_num_elements map进行存储,key是元素类型,value是类型出现的元素个数
python 复制代码
            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                               + param.data.nelement()
  • 实际开始分配grad buffer, 为了支持被DP并行数正好切分,需要先对每个类型出现的个数进行padding操作;然后通过MemoryBuffer进行存储的分配
python 复制代码
            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():

                # If using distributed optimizer, pad memory buffer to be
                # multiple of data_parallel_world_size. (This padding is done
                # due to a constraint with the reduce_scatter op, which requires
                # all tensors have equal size. See: optimizer.py.)
                num_elements_padded = data_parallel_world_size * \
                    int(math.ceil(num_elements / data_parallel_world_size))

                # Allocate grad buffer.
                self._grad_buffers[dtype] = MemoryBuffer(num_elements,
                                                         num_elements_padded,
                                                         dtype)
  • 从grad buffer中给每一个param参数分配对应的main_grad空间,在分配main_grad时根据每个param参数的类型从对应的self._grad_buffers[dtype]中得到跟param.data.shape一样的tensor,这里的tensor与grad buffer共享存储。同时grad buffer的分配是按倒序来分配的,比如self.module.parameters()中有三个参数分别是[p1, p2, p3], 在grad buffer中存储则是[p3_grad, p2_grad, p1_grad]_grad_buffer_param_index_map用来记录每个param的梯度在grad buffer中存储的起始和结束位置。
python 复制代码
            ...
            # Assume the back prop order is reverse the params order,
            # store the start index for the gradients.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    # get的第二个参数是start_index,这里的start_index是从grad_buffer从大到小来算的
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])
                    if dtype not in self._grad_buffer_param_index_map:
                        self._grad_buffer_param_index_map[dtype] = {}
                    self._grad_buffer_param_index_map[dtype][param] = (
                        type_num_elements[dtype],
                        type_num_elements[dtype] + param.data.nelement(),
                    )
  • 遍历每一个参数,对于每一个参数的grad_fn的下一个function累加grad_acc函数进行改写,由于param本身没有grad_fn,通过trick方式使用param.expand_as给param加上了grad_fn函数。
python 复制代码
            ...
            # Backward hook.
            # Accumalation function for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # 使用expand_as使param具有grad_fn.
                    param_tmp = param.expand_as(param)
                    # 获取梯度累加函数,并注册hook改写
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)
    
    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad is not None:
                # The gradient function of linear layers is fused with GEMMs
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook

4. 参考

相关推荐
渡我白衣1 小时前
多路转接之epoll:理论篇
人工智能·神经网络·网络协议·tcp/ip·自然语言处理·信息与通信·tcpdump
明月照山海-1 小时前
机器学习周报二十八
人工智能·机器学习
weixin_437497777 小时前
读书笔记:Context Engineering 2.0 (上)
人工智能·nlp
喝拿铁写前端7 小时前
前端开发者使用 AI 的能力层级——从表面使用到工程化能力的真正分水岭
前端·人工智能·程序员
goodfat7 小时前
Win11如何关闭自动更新 Win11暂停系统更新的设置方法【教程】
人工智能·禁止windows更新·win11优化工具
北京领雁科技7 小时前
领雁科技反洗钱案例白皮书暨人工智能在反洗钱系统中的深度应用
人工智能·科技·安全
落叶,听雪7 小时前
河南建站系统哪个好
大数据·人工智能·python
清月电子8 小时前
杰理AC109N系列AC1082 AC1074 AC1090 芯片停产替代及资料说明
人工智能·单片机·嵌入式硬件·物联网
Dev7z8 小时前
非线性MPC在自动驾驶路径跟踪与避障控制中的应用及Matlab实现
人工智能·matlab·自动驾驶
七月shi人8 小时前
AI浪潮下,前端路在何方
前端·人工智能·ai编程