[Source Code Analysis] Model Parallel Distributed Training Megatron (2) --- Overall Architecture
Contents
- [Source Code Analysis] Model Parallel Distributed Training Megatron (2) --- Overall Architecture
  - 0x00 Abstract
  - 0x01 Startup
    - 1.1 Distributed launch
    - 1.2 Building the basics
      - 1.2.1 Getting the model
      - 1.2.2 Getting the dataset
      - 1.2.3 The step function
        - 1.2.3.1 Broadcasting data
  - 0x02 Pretrain
  - 0x03 Initialization
    - 3.1 initialize_megatron
    - 3.2 Initializing the distributed environment
    - 3.3 Initializing process-group global variables
  - 0x04 Setting up the model
    - 4.1 setup_model_and_optimizer
    - 4.2 Model
      - 4.2.1 BertModel
      - 4.2.2 Language model
      - 4.2.3 ParallelTransformer
        - 4.2.3.1 Getting the number of layers
        - 4.2.3.2 Forward propagation
    - 4.3 get_model
  - 0x05 Data parallelism
    - 5.1 Setting up the data
    - 5.2 DDP
      - 5.2.1 Definition
      - 5.2.2 Initialization
      - 5.2.3 Memory
      - 5.2.4 Supporting functions
      - 5.2.5 Gradient reduction
  - 0x06 Training
    - 6.1 Training loop
    - 6.2 Training step
    - 6.3 Getting the schedule
  - 0xFF References
### 0x00 Abstract

NVIDIA Megatron is a PyTorch-based distributed training framework for training very large Transformer language models. It reproduces GPT-3 by combining data parallelism, tensor parallelism, and pipeline parallelism, and its underlying mechanisms are well worth a close look.

This series contains roughly 6-7 articles that study Megatron through its papers and source code. This article gives an overview of Megatron's basic architecture.

Other articles in this series:

- [[Source Code Analysis] Model Parallel Distributed Training Megatron (1) --- Papers & Foundations](https://www.cnblogs.com/rossiXYZ/p/15840803.html)
### 0x01 Startup
#### 1.1 Distributed launch
The launch script is examples/pretrain_bert_distributed.sh, which uses torch.distributed.launch to start multiple processes. The actual training code lives in pretrain_bert.py.
Since GPUS_PER_NODE is 8, nproc_per_node is 8, so 8 processes are started on this machine, and each process holds **a part of the model**. The rank of each process is assigned automatically by torch.distributed.launch (via the elastic launcher).
```shell
#!/bin/bash

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=\
```
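Each of these 8 processes receives its identity through environment variables set by the launcher. The following is a minimal sketch (not Megatron code) of what every launched process can then do with them; the device-selection line mirrors what Megatron's `_initialize_distributed` does when no explicit local rank is given.

```python
import os
import torch

# The launcher (torch.distributed.launch / torchrun) exports RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT for every process it spawns.
rank = int(os.environ["RANK"])              # global rank, 0 .. WORLD_SIZE-1
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes (8 in this example)

# One GPU per process on the node: pick the local device from the global rank.
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)

# Default init method is env://, which reads MASTER_ADDR / MASTER_PORT.
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
print(f"process {rank}/{world_size} bound to GPU {device}")
```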
#### 3.3 Initializing process-group global variables

The groups that the current rank belongs to are recorded in the following global variables (their definitions are listed in the code block below):

- _TENSOR_MODEL_PARALLEL_GROUP: the intra-layer (tensor) model parallel group that the current rank belongs to.
- _PIPELINE_MODEL_PARALLEL_GROUP: the inter-layer (pipeline) model parallel group that the current rank belongs to.
- _MODEL_PARALLEL_GROUP: the model parallel group (both tensor and pipeline) that the current rank belongs to.
- _EMBEDDING_GROUP: the process group for the embeddings.
- _DATA_PARALLEL_GROUP: the data parallel group that the current rank belongs to.
  - For example, with a data parallel degree of 2, the groups are [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]; the short sketch right after this list reproduces this grouping.
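Below is a minimal, self-contained sketch modeled on the group-building loop in Megatron's initialize_model_parallel. It reproduces the grouping above under the assumption of 16 GPUs with tensor-model-parallel size 2 and pipeline-model-parallel size 4 (hence data-parallel size 2):

```python
# Sketch: enumerate data-parallel groups for world_size=16, tensor size 2, pipeline size 4.
world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)  # 2

num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size  # 4

data_parallel_groups = []
for i in range(pipeline_model_parallel_size):
    start_rank = i * num_pipeline_model_parallel_groups
    end_rank = (i + 1) * num_pipeline_model_parallel_groups
    for j in range(tensor_model_parallel_size):
        # Ranks in the same pipeline stage that hold the same tensor slice
        # but different data shards form one data-parallel group.
        data_parallel_groups.append(list(range(start_rank + j, end_rank, tensor_model_parallel_size)))

print(data_parallel_groups)
# [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
```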
```python
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
```
### 0x04 Setting up the model

In Pretrain, the following call sets up the model, the optimizer, and so on.

```python
# Model, optimizer, and learning rate: use model_provider to build the model, optimizer and LR scheduler
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                           model_type)
```
#### 4.1 setup_model_and_optimizer

The setup_model_and_optimizer method builds the model and the optimizer; the key piece is get_model.
```python
def setup_model_and_optimizer(model_provider_func, model_type):
    """Setup model and optimizer."""
    args = get_args()
    model = get_model(model_provider_func, model_type)
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
       and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler
```
#### 4.2 Model
##### 4.2.1 BertModel
Let us first look at BertModel's initialization function, skipping its other member functions. It mainly calls get_language_model.
```python
class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        # Get the language model
        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:  # the last stage is handled specially
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'
```
##### 4.2.2 Language model
get_language_model builds a TransformerLanguageModel.
```python
def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
```
TransformerLanguageModel is the concrete language model; its most important component is the ParallelTransformer. Which sub-modules are built depends on the configuration passed in:
- If this is the first stage (pre_process is set), an embedding layer is added.
- For the intermediate layers, the corresponding ParallelTransformer is generated, as an encoder or as a decoder.
- If this is the last stage (post_process is set), a Pooler is added; the outer BertModel also handles this case specially.
```python
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process
            )
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            # Temporary assertion until we verify correctness of pipeline parallelism
            # implementation of T5.
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'
```
##### 4.2.3 ParallelTransformer
This class calls ParallelTransformerLayer to build the concrete Transformer layers; we will analyze ParallelTransformerLayer in a later article.
In other words, a ParallelTransformer contains multiple Transformer layers, and each layer is a ParallelTransformerLayer.
```python
class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(  # number of layers owned by this stage
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(  # return one Transformer layer
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(  # build num_layers Transformer layers
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)
```
The logic so far is shown below; we assume the stage holds two Transformer layers:

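As a sanity check on the interleaved ("virtual pipeline") layer assignment described in the comments of ParallelTransformer.__init__, here is a tiny standalone sketch (plain Python, not Megatron code) that reproduces the 8-layer / 2-stage / 4-chunk example:

```python
# Sketch: reproduce the layer-to-chunk assignment from the comments above,
# assuming num_layers=8, 2 pipeline stages, virtual_pipeline_model_parallel_size=4.
num_layers = 8
pipeline_world_size = 2
virtual_size = 4

layers_per_stage = num_layers // pipeline_world_size   # 4
layers_per_chunk = layers_per_stage // virtual_size    # 1

for pipeline_rank in range(pipeline_world_size):
    chunks = []
    for virtual_rank in range(virtual_size):
        offset = virtual_rank * (num_layers // virtual_size) + pipeline_rank * layers_per_chunk
        chunks.append(list(range(offset, offset + layers_per_chunk)))
    print(f"Stage {pipeline_rank}: {chunks}")

# Stage 0: [[0], [2], [4], [6]]
# Stage 1: [[1], [3], [5], [7]]
```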
###### 4.2.3.1 Getting the number of layers
A key point here is computing how many layers this rank should own under parallelism. If the model has 64 layers in total and the pipeline depth is 16, then each pipeline stage holds 4 layers, so this sub-model has 4 layers.
```python
def get_num_layers(args, is_encoder_and_decoder_model):
    """Compute the number of transformer layers resident on the current rank."""
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None
            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
            if is_pipeline_stage_before_split():
                num_layers = args.num_layers // num_ranks_in_encoder
            else:
                num_layers = args.num_layers // num_ranks_in_decoder
        else:
            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
    else:
        num_layers = args.num_layers
    return num_layers
```
get_pipeline_model_parallel_world_size returns the world size of the pipeline parallel group, i.e. the pipeline depth.
```python
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
```
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is the pipeline depth p, i.e. the model is cut vertically p-1 times. For example, with 12 layers in total, cutting vertically 5 times gives 6 stages, and each stage holds 2 layers.
###### 4.2.3.2 Forward propagation
Next let's look at the forward function. It mainly calls the forward method of each inner ParallelTransformerLayer, with special handling for the first and last stages.
```python
def forward(self, hidden_states, attention_mask,
            encoder_output=None, enc_dec_attn_mask=None,
            inference_params=None):

    if self.pre_process:
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        # If the input flag for fp32 residual connection is set, convert for float.
        if self.fp32_residual_connection:
            hidden_states = hidden_states.transpose(0, 1).contiguous().float()
        # Otherwise, leave it as is.
        else:
            hidden_states = hidden_states.transpose(0, 1).contiguous()
    else:
        # See set_input_tensor()
        hidden_states = self.input_tensor

    if encoder_output is not None:
        encoder_output = encoder_output.transpose(0, 1).contiguous()

    if self.activations_checkpoint_method is not None:
        hidden_states = self._checkpointed_forward(hidden_states,
                                                   attention_mask,
                                                   encoder_output,
                                                   enc_dec_attn_mask)
    else:
        for index in range(self.num_layers):
            layer = self._get_layer(index)
            hidden_states = layer(  # call the forward function of ParallelTransformerLayer
                hidden_states,
                attention_mask,
                encoder_output=encoder_output,
                enc_dec_attn_mask=enc_dec_attn_mask,
                inference_params=inference_params)

    # Final layer norm.
    if self.post_process:
        # Reverting data format change [s b h] --> [b s h].
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        output = self.final_layernorm(hidden_states)
    else:
        output = hidden_states

    return output
```
#### 4.3 get_model
Now let's return to get_model and lay out the model-construction flow.
BERT consists of multiple identical transformer layers, so it is partitioned directly by layer. As mentioned earlier, our example launches 8 processes, each holding a sub-model, i.e. a subset of the original BERT model's layers. But how does each process know how many layers its sub-model contains? The answer is that the various process groups have already been created, so get_model can decide based on the current process-group configuration. Within a single process, the model is obtained as follows:
- If virtual pipelining is configured, iterate over the virtual size and build that many models (BertModel).
- Otherwise, if the model is encoder_and_decoder, configure it according to the split rank.
- Set the tensor model parallel attributes.
- Move the model onto the GPU.
- If data parallelism is needed, wrap the model with DDP.
The code is as follows:
```python
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:  # virtual pipelining is configured, discussed later
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):  # iterate over the virtual chunks
            # Set the virtual rank, mainly to decide whether this is the first/last stage
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(  # build the underlying model, e.g. BertModel
                pre_process=pre_process,
                post_process=post_process
            )
            this_model.model_type = model_type
            model.append(this_model)  # append a new BertModel to the model list
    else:
        pre_process = mpu.is_pipeline_first_stage()  # is this the first stage?
        post_process = mpu.is_pipeline_last_stage()  # is this the last stage?
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                pre_process = rank == 0 or rank == split_rank  # is this the first stage?
                post_process = (rank == (split_rank - 1)) or (  # is this the last stage?
                    rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(  # build the underlying model
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(  # build the underlying model
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # GPU allocation.
    for model_module in model:  # move this model onto the GPU
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:  # wrap with DDP if data parallelism is needed
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = [torchDDP(model_module, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
                     for model_module in model]

        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]

        else:
            raise NotImplementedError('Unknown DDP implementation specified: '
                                      '{}. Exiting.'.format(args.DDP_impl))

    return model
```
Within a single process the logic is roughly as follows; here torchDDP means that the module inside BertModel is wrapped with torchDDP.

### 0x05 Data parallelism
#### 5.1 Setting up the data
The build_train_valid_test_data_iterators method processes the data and provides three datasets: train, valid, and test.
```python
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """XXX"""
    args = get_args()
    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:
        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type

    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
                              else iter(cyclic_iter(train_dataloader))
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
                              else iter(cyclic_iter(valid_dataloader))
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
                             else iter(cyclic_iter(test_dataloader))
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator
```
#### 5.2 DDP
In get_model, the following code applies DDP.
```python
from megatron.model import DistributedDataParallel as LocalDDP
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

if wrap_with_ddp:
    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = [torchDDP(model_module, device_ids=[i], output_device=i,
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]

    elif args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_local_ddp)
                 for model_module in model]

    else:
        raise NotImplementedError('Unknown DDP implementation specified: '
                                  '{}. Exiting.'.format(args.DDP_impl))
```
So let's look at Megatron's own DDP implementation.
##### 5.2.1 Definition
For the definition, the class docstring is the main thing to read: contiguous memory is used to store and accumulate gradients, and all tensors of a given dtype share one buffer, so each buffer can be all-reduced as a whole.
```python
class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with contiguous buffers options to store and accumulate gradients.
    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32)

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation
            and the gradient all-reduce all in float32. If this option is
            true, we require `use_contiguous_buffers` to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """
```
##### 5.2.2 Initialization
The purpose of the initialization method is to store gradients of the same dtype in contiguous memory.
```python
def __init__(self, module,
             accumulate_allreduce_grads_in_fp32,
             use_contiguous_buffers):

    super(DistributedDataParallel, self).__init__(module)

    self.accumulate_allreduce_grads_in_fp32 \
        = accumulate_allreduce_grads_in_fp32
    self.use_contiguous_buffers = use_contiguous_buffers

    # If we are using fp32-accumulate-allreduce explicitly
    # this means we need main grads in a continous buffer.
    if self.accumulate_allreduce_grads_in_fp32:
        assert self.use_contiguous_buffers

    # ===================================
    # Rest of this part applies only to
    # the case we use continuous buffers.
    # ===================================
    self._grad_buffers = None
    if self.use_contiguous_buffers:  # only the contiguous-buffer case is considered here
        self._grad_buffers = {}  # the buffers, keyed by dtype

        # Simple function to define buffer type.
        def _get_buffer_type(param):  # return the buffer dtype for a parameter
            return torch.float if \
                self.accumulate_allreduce_grads_in_fp32 else param.dtype

        # First calculate total number of elements per type.
        type_num_elements = {}
        for param in self.module.parameters():  # iterate over model parameters
            if param.requires_grad:  # only parameters that need gradients
                dtype = _get_buffer_type(param)  # get the parameter's buffer dtype
                type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                           + param.data.nelement()  # add this parameter's element count
        # type_num_elements now holds the number of elements per dtype

        # Allocate the buffer.
        for dtype, num_elements in type_num_elements.items():  # iterate over dtypes
            self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)  # allocate memory

        # Assume the back prop order is reverse the params order,
        # store the start index for the gradients.
        for param in self.module.parameters():  # iterate over model parameters
            if param.requires_grad:  # only parameters that need gradients
                dtype = _get_buffer_type(param)
                type_num_elements[dtype] -= param.data.nelement()  # decrement to get the start index
                # main_grad is a view into the MemoryBuffer at that start index
                param.main_grad = self._grad_buffers[dtype].get(
                    param.data.shape, type_num_elements[dtype])

        # Backward hook.
        # Accumalation function for the gradients. We need
        # to store them so they don't go out of scope.
        self.grad_accs = []
        # Loop over all the parameters in the model.
        for param in self.module.parameters():  # iterate over model parameters
            if param.requires_grad:  # only parameters that need gradients
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node for this parameter
                grad_acc.register_hook(self._make_param_hook(param))  # register a backward hook
                self.grad_accs.append(grad_acc)  # keep references so the hooks stay alive (book-keeping)
```
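The expand_as / next_functions trick above is subtle, so here is a tiny standalone sketch (illustrative only, not Megatron code) of the same pattern applied to a single toy parameter. It assumes a PyTorch version where the hook on the AccumulateGrad node fires after the gradient has been written to param.grad, which is what Megatron's LocalDDP relies on:

```python
import torch

# One toy parameter; main_grad stands in for the view into the contiguous MemoryBuffer.
param = torch.nn.Parameter(torch.randn(4))
param.main_grad = torch.zeros_like(param.data)

# expand_as exposes a grad_fn; its next_functions[0][0] is the AccumulateGrad node of `param`.
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

def param_hook(*unused):
    # Copy the freshly accumulated gradient into the buffer, then free .grad.
    param.main_grad.add_(param.grad.data)
    param.grad = None

grad_acc.register_hook(param_hook)

loss = (param * 2).sum()
loss.backward()
print(param.main_grad)  # the gradient (all twos) now lives in main_grad
print(param.grad)       # None - the per-parameter grad tensor has been released
```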
##### 5.2.3 Memory
MemoryBuffer is the memory abstraction.
```python
class MemoryBuffer:

    def __init__(self, numel, dtype):
        self.numel = numel
        self.dtype = dtype
        self.data = torch.zeros(self.numel,  # allocate the underlying memory
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()  # locate this tensor inside the buffer
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        buffer_tensor = self.data[start_index:end_index]  # slice out the storage
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor
```
##### 5.2.4 Supporting functions
Below are two supporting functions: one copies a gradient into the buffer, the other zeroes the buffers.
```python
def _make_param_hook(self, param):
    """Create the all-reduce hook for backprop."""
    # Hook used for back-prop.
    def param_hook(*unused):
        # Add the gradient to the buffer.
        if param.grad.data is not None:
            param.main_grad.add_(param.grad.data)  # copy the gradient into the contiguous buffer
            # Now we can deallocate grad memory.
            param.grad = None
    return param_hook

def zero_grad_buffer(self):
    """Set the grad buffer data to zero. Needs to be called at the
    beginning of each iteration."""
    assert self._grad_buffers is not None, 'buffers are not initialized.'
    for _, buffer_ in self._grad_buffers.items():
        buffer_.zero()
```
Suppose the model has 6 parameters, 3 in fp32 and 3 in fp16; they are then packed into two contiguous MemoryBuffers, one per dtype (a small sketch reproducing this follows below).

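Here is a small standalone sketch (CPU-only, simplified from the __init__ logic above; the parameter shapes are made up) showing how 6 parameters, 3 fp32 and 3 fp16, end up packed into two dtype-keyed flat buffers, with each parameter's main_grad being a view into its buffer:

```python
import torch

# Hypothetical parameters: 3 fp32 and 3 fp16 (shapes chosen arbitrarily for illustration).
params = [torch.nn.Parameter(torch.randn(4, 4)),            # fp32
          torch.nn.Parameter(torch.randn(8)),                # fp32
          torch.nn.Parameter(torch.randn(2, 3)),             # fp32
          torch.nn.Parameter(torch.randn(4, 4).half()),      # fp16
          torch.nn.Parameter(torch.randn(8).half()),         # fp16
          torch.nn.Parameter(torch.randn(2, 3).half())]      # fp16

# Pass 1: count elements per dtype (accumulate_allreduce_grads_in_fp32 assumed False here).
type_num_elements = {}
for p in params:
    type_num_elements[p.dtype] = type_num_elements.get(p.dtype, 0) + p.data.nelement()

# Allocate one flat buffer per dtype (a CPU stand-in for MemoryBuffer).
buffers = {dtype: torch.zeros(n, dtype=dtype) for dtype, n in type_num_elements.items()}

# Pass 2: carve out views from the end of each buffer toward the start
# (Megatron assumes backprop visits parameters in reverse order).
for p in params:
    type_num_elements[p.dtype] -= p.data.nelement()
    start = type_num_elements[p.dtype]
    p.main_grad = buffers[p.dtype][start:start + p.data.nelement()].view(p.data.shape)

print({str(k): v.numel() for k, v in buffers.items()})
# {'torch.float32': 30, 'torch.float16': 30}
```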
##### 5.2.5 Gradient reduction
allreduce_gradients is the API that DDP exposes; it is called later in train_step.
```python
def allreduce_gradients(self):
    """Reduce gradients across data parallel ranks."""
    # If we have buffers, simply reduce the data in the buffer.
    if self._grad_buffers is not None:
        # Contiguous buffers
        for _, buffer_ in self._grad_buffers.items():  # iterate over the per-dtype buffers
            buffer_.data /= mpu.get_data_parallel_world_size()
            torch.distributed.all_reduce(  # all-reduce the whole buffer at once
                buffer_.data, group=mpu.get_data_parallel_group())
    else:
        # Otherwise, bucketize and all-reduce
        buckets = {}  # without contiguous buffers, fall back to bucketing
        # Pack the buckets.
        for param in self.module.parameters():  # iterate over the gradients
            if param.requires_grad and param.grad is not None:
                tp = param.data.type()
                if tp not in buckets:
                    buckets[tp] = []
                buckets[tp].append(param)  # gradients of the same dtype go into the same bucket
                param.main_grad = param.grad

        # For each bucket, all-reduce and copy all-reduced grads.
        for tp in buckets:
            bucket = buckets[tp]
            grads = [param.grad.data for param in bucket]  # collect the gradients in this bucket
            coalesced = _flatten_dense_tensors(grads)  # flatten them into one tensor
            coalesced /= mpu.get_data_parallel_world_size()
            torch.distributed.all_reduce(  # all-reduce
                coalesced, group=mpu.get_data_parallel_group())
            for buf, synced in zip(grads, _unflatten_dense_tensors(
                    coalesced, grads)):
                buf.copy_(synced)
```
At run time, an AllReduce is performed separately on each of the two dtype-specific contiguous buffers.

### 0x06 Training
Pretrain calls train to run the training.
```python
if args.do_train and args.train_iters > 0:
    iteration = train(forward_step_func,
                      model, optimizer, lr_scheduler,
                      train_data_iterator, valid_data_iterator)
```
#### 6.1 Training loop
train follows the usual routine; most of it can be understood from the function names alone.
```python
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,  # run one training step
                       train_data_iterator,
                       model,
                       optimizer,
                       lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.exit_signal_handler:
            signal_handler = get_signal_handler()
            if any(signal_handler.signals_received()):
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
                sys.exit()

        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            sys.exit()

    return iteration
```
#### 6.2 Training step
train_step calls get_forward_backward_func to obtain the schedule; since pipeline parallelism is used, a schedule is needed to organize how training actually proceeds.
```python
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    # Get the training schedule
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(  # run the forward and backward passes
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        for model_module in model:
            model_module.allreduce_gradients()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:  # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # Update parameters.
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
```
#### 6.3 Getting the schedule
get_forward_backward_func returns the pipeline schedule, which comes in two flavors, flush and interleaving; we will analyze both schedules in later articles.
```python
def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func
```
The training logic roughly expands to:

At this point, the analysis of Megatron's basic architecture is complete; in the next article we introduce how model parallelism is set up.
### 0xFF References
GTC 2020: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism