保存模型检查点

在训练深度学习模型的过程中,保存模型检查点是一个非常重要的步骤。它不仅可以防止训练过程中出现意外中断导致的损失,还能方便我们后续对模型进行评估和测试。

所谓先睹为快,因此我们先看代码:

py 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
    model.eval() # 将模型切换到评估模式
    moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
    ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径

    if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    torch.save(state_dict, ckp) # 保存模型状态字典到文件
    model.train() # 将模型切换回训练模式

整体分析

这段代码的主要目的是在训练过程中定期保存模型的状态字典。它通过判断当前训练步数是否满足保存间隔条件,并在满足条件时执行模型的保存操作。此外,代码还考虑了分布式训练的情况,确保只有主进程(rank 0)执行保存操作,以避免多个进程重复保存相同的模型。

逐行解释

python 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
  • 条件判断 :这行代码检查两个条件是否都满足:
    • (step + 1) % args.save_interval == 0:检查当前训练步数 step + 1 是否是保存间隔 args.save_interval 的整数倍。如果是,则表示达到了保存模型的条件。
    • (not ddp or dist.get_rank() == 0):检查是否处于分布式训练模式(ddp 表示是否使用分布式数据并行)。如果未使用分布式训练(not ddp),则直接满足条件;如果使用了分布式训练,则只有主进程(dist.get_rank() == 0)满足条件。这样可以确保只有主进程负责保存模型,避免多个进程重复保存。
python 复制代码
model.eval() # 将模型切换到评估模式
  • 模型模式切换 :将模型设置为评估模式。在评估模式下,模型的行为会有一些变化,例如:
    • Dropout 层:在训练模式下,Dropout 层会随机丢弃一部分神经元,以防止过拟合。但在评估模式下,Dropout 层不会丢弃任何神经元,以确保模型的输出是确定性的。
    • BatchNorm 层:在训练模式下,BatchNorm 层会使用当前批次的统计信息(均值和方差)进行归一化。而在评估模式下,BatchNorm 层会使用训练过程中累积的全局统计信息,以确保模型的输出一致。
python 复制代码
moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
  • 构造保存路径后缀 :根据模型是否使用了 MoE(Mixture of Experts)架构来构造保存路径的后缀。
    • lm_config.use_moe:这是一个布尔值,表示模型是否启用了 MoE 架构。
    • 如果启用了 MoE 架构,moe_path 被设置为 '_moe';否则,设置为一个空字符串。
python 复制代码
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径
  • 构造保存路径 :构造一个保存模型检查点的文件路径。
    • args.save_dir:保存目录的路径。
    • pretrain_{lm_config.dim}:文件名的前缀,其中 lm_config.dim 表示模型的隐藏层维度。
    • {moe_path}:根据是否使用 MoE 架构添加的后缀。
    • .pth:文件扩展名,表示这是一个 PyTorch 的模型文件。
python 复制代码
if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
  • 检查模型类型 :检查模型是否是 torch.nn.parallel.DistributedDataParallel 类型。DistributedDataParallel 是 PyTorch 中用于分布式训练的一个包装类,它会将模型的参数和梯度分布在多个 GPU 上。
    • 如果模型是 DistributedDataParallel 类型,则需要保存其内部的 module 的状态字典,因为 DistributedDataParallel 是对原始模型的一个包装。
python 复制代码
state_dict = model.module.state_dict()
  • 获取状态字典 :如果模型是 DistributedDataParallel 类型,则调用 model.module.state_dict() 获取模型的状态字典。状态字典包含了模型的所有参数和缓冲区。
python 复制代码
else:
    state_dict = model.state_dict()
  • 获取状态字典 :如果模型不是 DistributedDataParallel 类型,则直接调用 model.state_dict() 获取模型的状态字典。
python 复制代码
torch.save(state_dict, ckp) # 保存模型状态字典到文件
  • 保存模型 :使用 torch.save() 函数将模型的状态字典保存到之前构造的文件路径 ckp 中。这样就可以在后续的训练或推理中加载该模型。
python 复制代码
model.train() # 将模型切换回训练模式
  • 恢复模型模式:将模型重新切换回训练模式。在训练模式下,模型的行为会恢复到适合训练的状态,例如 Dropout 层会重新启用随机丢弃,BatchNorm 层会使用当前批次的统计信息。

总结

这段代码的主要逻辑是:

  1. 条件判断:检查是否满足保存模型的条件(步数间隔和分布式训练的主进程判断)。
  2. 模型模式切换:将模型切换到评估模式,以确保保存的模型状态是确定性的。
  3. 构造保存路径:根据模型的配置构造保存路径。
  4. 保存模型状态字典:根据模型是否是分布式数据并行模型,获取相应的状态字典并保存到文件中。
  5. 恢复模型模式:将模型切换回训练模式,以便继续训练。

保存模型检查点在训练循环中非常重要,它确保了模型可以在训练过程中定期保存,以便后续的恢复、评估或部署。

相关推荐
智商低情商凑1 小时前
CAS(Compare And Swap)
java·jvm·面试
uhakadotcom2 小时前
人工智能如何改变医疗行业:简单易懂的基础介绍与实用案例
算法·面试·github
zizisuo2 小时前
面试篇:Spring Boot
spring boot·面试·职场和发展
uhakadotcom4 小时前
企业智能体网络(Agent Mesh)入门指南:基础知识与实用示例
后端·面试·github
独孤歌5 小时前
告别频繁登录:打造用户无感的 Token 刷新机制
安全·面试
Eliauk__5 小时前
深入剖析 Vue 双向数据绑定机制 —— 从响应式原理到 v-model 实现全解析
前端·javascript·面试
慕仲卿5 小时前
模型初始化:加载分词器和模型
面试
海底火旺5 小时前
寻找缺失的最小正整数:从暴力到最优的算法演进
javascript·算法·面试
顾林海5 小时前
深入探究 Android Native 代码的崩溃捕获机制
android·面试·性能优化
慕仲卿5 小时前
缩放器和优化器的定义
面试