在训练深度学习模型的过程中,保存模型检查点是一个非常重要的步骤。它不仅可以防止训练过程中出现意外中断导致的损失,还能方便我们后续对模型进行评估和测试。
所谓先睹为快,因此我们先看代码:
py
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
model.eval() # 将模型切换到评估模式
moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径
if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, ckp) # 保存模型状态字典到文件
model.train() # 将模型切换回训练模式
整体分析
这段代码的主要目的是在训练过程中定期保存模型的状态字典。它通过判断当前训练步数是否满足保存间隔条件,并在满足条件时执行模型的保存操作。此外,代码还考虑了分布式训练的情况,确保只有主进程(rank 0)执行保存操作,以避免多个进程重复保存相同的模型。
逐行解释
python
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
- 条件判断 :这行代码检查两个条件是否都满足:
(step + 1) % args.save_interval == 0
:检查当前训练步数step + 1
是否是保存间隔args.save_interval
的整数倍。如果是,则表示达到了保存模型的条件。(not ddp or dist.get_rank() == 0)
:检查是否处于分布式训练模式(ddp
表示是否使用分布式数据并行)。如果未使用分布式训练(not ddp
),则直接满足条件;如果使用了分布式训练,则只有主进程(dist.get_rank() == 0
)满足条件。这样可以确保只有主进程负责保存模型,避免多个进程重复保存。
python
model.eval() # 将模型切换到评估模式
- 模型模式切换 :将模型设置为评估模式。在评估模式下,模型的行为会有一些变化,例如:
- Dropout 层:在训练模式下,Dropout 层会随机丢弃一部分神经元,以防止过拟合。但在评估模式下,Dropout 层不会丢弃任何神经元,以确保模型的输出是确定性的。
- BatchNorm 层:在训练模式下,BatchNorm 层会使用当前批次的统计信息(均值和方差)进行归一化。而在评估模式下,BatchNorm 层会使用训练过程中累积的全局统计信息,以确保模型的输出一致。
python
moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
- 构造保存路径后缀 :根据模型是否使用了 MoE(Mixture of Experts)架构来构造保存路径的后缀。
lm_config.use_moe
:这是一个布尔值,表示模型是否启用了 MoE 架构。- 如果启用了 MoE 架构,
moe_path
被设置为'_moe'
;否则,设置为一个空字符串。
python
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径
- 构造保存路径 :构造一个保存模型检查点的文件路径。
args.save_dir
:保存目录的路径。pretrain_{lm_config.dim}
:文件名的前缀,其中lm_config.dim
表示模型的隐藏层维度。{moe_path}
:根据是否使用 MoE 架构添加的后缀。.pth
:文件扩展名,表示这是一个 PyTorch 的模型文件。
python
if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
- 检查模型类型 :检查模型是否是
torch.nn.parallel.DistributedDataParallel
类型。DistributedDataParallel
是 PyTorch 中用于分布式训练的一个包装类,它会将模型的参数和梯度分布在多个 GPU 上。- 如果模型是
DistributedDataParallel
类型,则需要保存其内部的module
的状态字典,因为DistributedDataParallel
是对原始模型的一个包装。
- 如果模型是
python
state_dict = model.module.state_dict()
- 获取状态字典 :如果模型是
DistributedDataParallel
类型,则调用model.module.state_dict()
获取模型的状态字典。状态字典包含了模型的所有参数和缓冲区。
python
else:
state_dict = model.state_dict()
- 获取状态字典 :如果模型不是
DistributedDataParallel
类型,则直接调用model.state_dict()
获取模型的状态字典。
python
torch.save(state_dict, ckp) # 保存模型状态字典到文件
- 保存模型 :使用
torch.save()
函数将模型的状态字典保存到之前构造的文件路径ckp
中。这样就可以在后续的训练或推理中加载该模型。
python
model.train() # 将模型切换回训练模式
- 恢复模型模式:将模型重新切换回训练模式。在训练模式下,模型的行为会恢复到适合训练的状态,例如 Dropout 层会重新启用随机丢弃,BatchNorm 层会使用当前批次的统计信息。
总结
这段代码的主要逻辑是:
- 条件判断:检查是否满足保存模型的条件(步数间隔和分布式训练的主进程判断)。
- 模型模式切换:将模型切换到评估模式,以确保保存的模型状态是确定性的。
- 构造保存路径:根据模型的配置构造保存路径。
- 保存模型状态字典:根据模型是否是分布式数据并行模型,获取相应的状态字典并保存到文件中。
- 恢复模型模式:将模型切换回训练模式,以便继续训练。
保存模型检查点在训练循环中非常重要,它确保了模型可以在训练过程中定期保存,以便后续的恢复、评估或部署。