保存模型检查点

在训练深度学习模型的过程中,保存模型检查点是一个非常重要的步骤。它不仅可以防止训练过程中出现意外中断导致的损失,还能方便我们后续对模型进行评估和测试。

所谓先睹为快,因此我们先看代码:

py 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
    model.eval() # 将模型切换到评估模式
    moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
    ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径

    if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    torch.save(state_dict, ckp) # 保存模型状态字典到文件
    model.train() # 将模型切换回训练模式

整体分析

这段代码的主要目的是在训练过程中定期保存模型的状态字典。它通过判断当前训练步数是否满足保存间隔条件,并在满足条件时执行模型的保存操作。此外,代码还考虑了分布式训练的情况,确保只有主进程(rank 0)执行保存操作,以避免多个进程重复保存相同的模型。

逐行解释

python 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
  • 条件判断 :这行代码检查两个条件是否都满足:
    • (step + 1) % args.save_interval == 0:检查当前训练步数 step + 1 是否是保存间隔 args.save_interval 的整数倍。如果是,则表示达到了保存模型的条件。
    • (not ddp or dist.get_rank() == 0):检查是否处于分布式训练模式(ddp 表示是否使用分布式数据并行)。如果未使用分布式训练(not ddp),则直接满足条件;如果使用了分布式训练,则只有主进程(dist.get_rank() == 0)满足条件。这样可以确保只有主进程负责保存模型,避免多个进程重复保存。
python 复制代码
model.eval() # 将模型切换到评估模式
  • 模型模式切换 :将模型设置为评估模式。在评估模式下,模型的行为会有一些变化,例如:
    • Dropout 层:在训练模式下,Dropout 层会随机丢弃一部分神经元,以防止过拟合。但在评估模式下,Dropout 层不会丢弃任何神经元,以确保模型的输出是确定性的。
    • BatchNorm 层:在训练模式下,BatchNorm 层会使用当前批次的统计信息(均值和方差)进行归一化。而在评估模式下,BatchNorm 层会使用训练过程中累积的全局统计信息,以确保模型的输出一致。
python 复制代码
moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
  • 构造保存路径后缀 :根据模型是否使用了 MoE(Mixture of Experts)架构来构造保存路径的后缀。
    • lm_config.use_moe:这是一个布尔值,表示模型是否启用了 MoE 架构。
    • 如果启用了 MoE 架构,moe_path 被设置为 '_moe';否则,设置为一个空字符串。
python 复制代码
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径
  • 构造保存路径 :构造一个保存模型检查点的文件路径。
    • args.save_dir:保存目录的路径。
    • pretrain_{lm_config.dim}:文件名的前缀,其中 lm_config.dim 表示模型的隐藏层维度。
    • {moe_path}:根据是否使用 MoE 架构添加的后缀。
    • .pth:文件扩展名,表示这是一个 PyTorch 的模型文件。
python 复制代码
if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
  • 检查模型类型 :检查模型是否是 torch.nn.parallel.DistributedDataParallel 类型。DistributedDataParallel 是 PyTorch 中用于分布式训练的一个包装类,它会将模型的参数和梯度分布在多个 GPU 上。
    • 如果模型是 DistributedDataParallel 类型,则需要保存其内部的 module 的状态字典,因为 DistributedDataParallel 是对原始模型的一个包装。
python 复制代码
state_dict = model.module.state_dict()
  • 获取状态字典 :如果模型是 DistributedDataParallel 类型,则调用 model.module.state_dict() 获取模型的状态字典。状态字典包含了模型的所有参数和缓冲区。
python 复制代码
else:
    state_dict = model.state_dict()
  • 获取状态字典 :如果模型不是 DistributedDataParallel 类型,则直接调用 model.state_dict() 获取模型的状态字典。
python 复制代码
torch.save(state_dict, ckp) # 保存模型状态字典到文件
  • 保存模型 :使用 torch.save() 函数将模型的状态字典保存到之前构造的文件路径 ckp 中。这样就可以在后续的训练或推理中加载该模型。
python 复制代码
model.train() # 将模型切换回训练模式
  • 恢复模型模式:将模型重新切换回训练模式。在训练模式下,模型的行为会恢复到适合训练的状态,例如 Dropout 层会重新启用随机丢弃,BatchNorm 层会使用当前批次的统计信息。

总结

这段代码的主要逻辑是:

  1. 条件判断:检查是否满足保存模型的条件(步数间隔和分布式训练的主进程判断)。
  2. 模型模式切换:将模型切换到评估模式,以确保保存的模型状态是确定性的。
  3. 构造保存路径:根据模型的配置构造保存路径。
  4. 保存模型状态字典:根据模型是否是分布式数据并行模型,获取相应的状态字典并保存到文件中。
  5. 恢复模型模式:将模型切换回训练模式,以便继续训练。

保存模型检查点在训练循环中非常重要,它确保了模型可以在训练过程中定期保存,以便后续的恢复、评估或部署。

相关推荐
榮十一1 小时前
100道Java面试SQL题及答案
java·sql·面试
LYFlied1 小时前
【每日算法】LeetCode 20. 有效的括号
数据结构·算法·leetcode·面试
WYiQIU1 小时前
从今天开始备战1月中旬的前端寒假实习需要准备什么?(飞书+github+源码+题库含答案)
前端·javascript·面试·职场和发展·前端框架·github·飞书
LYFlied1 小时前
【每日算法】LeetCode 76. 最小覆盖子串
数据结构·算法·leetcode·面试·职场和发展
努力学算法的蒟蒻2 小时前
day36(12.17)——leetcode面试经典150
算法·leetcode·面试
用户479492835691511 小时前
面试官问"try-catch影响性能吗",我用数据打脸
前端·javascript·面试
沐雪架构师11 小时前
大模型Agent面试精选15题(第四辑)-Agent与RAG(检索增强生成)结合的高频面试题
面试·职场和发展
未若君雅裁11 小时前
JVM面试篇总结
java·jvm·面试
YoungHong199212 小时前
面试经典150题[072]:从前序与中序遍历序列构造二叉树(LeetCode 105)
leetcode·面试·职场和发展
用户479492835691515 小时前
改了CSS刷新没反应-你可能不懂HTTP缓存
前端·javascript·面试