保存模型检查点

在训练深度学习模型的过程中,保存模型检查点是一个非常重要的步骤。它不仅可以防止训练过程中出现意外中断导致的损失,还能方便我们后续对模型进行评估和测试。

所谓先睹为快,因此我们先看代码:

py 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
    model.eval() # 将模型切换到评估模式
    moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
    ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径

    if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    torch.save(state_dict, ckp) # 保存模型状态字典到文件
    model.train() # 将模型切换回训练模式

整体分析

这段代码的主要目的是在训练过程中定期保存模型的状态字典。它通过判断当前训练步数是否满足保存间隔条件,并在满足条件时执行模型的保存操作。此外,代码还考虑了分布式训练的情况,确保只有主进程(rank 0)执行保存操作,以避免多个进程重复保存相同的模型。

逐行解释

python 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
  • 条件判断 :这行代码检查两个条件是否都满足:
    • (step + 1) % args.save_interval == 0:检查当前训练步数 step + 1 是否是保存间隔 args.save_interval 的整数倍。如果是,则表示达到了保存模型的条件。
    • (not ddp or dist.get_rank() == 0):检查是否处于分布式训练模式(ddp 表示是否使用分布式数据并行)。如果未使用分布式训练(not ddp),则直接满足条件;如果使用了分布式训练,则只有主进程(dist.get_rank() == 0)满足条件。这样可以确保只有主进程负责保存模型,避免多个进程重复保存。
python 复制代码
model.eval() # 将模型切换到评估模式
  • 模型模式切换 :将模型设置为评估模式。在评估模式下,模型的行为会有一些变化,例如:
    • Dropout 层:在训练模式下,Dropout 层会随机丢弃一部分神经元,以防止过拟合。但在评估模式下,Dropout 层不会丢弃任何神经元,以确保模型的输出是确定性的。
    • BatchNorm 层:在训练模式下,BatchNorm 层会使用当前批次的统计信息(均值和方差)进行归一化。而在评估模式下,BatchNorm 层会使用训练过程中累积的全局统计信息,以确保模型的输出一致。
python 复制代码
moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
  • 构造保存路径后缀 :根据模型是否使用了 MoE(Mixture of Experts)架构来构造保存路径的后缀。
    • lm_config.use_moe:这是一个布尔值,表示模型是否启用了 MoE 架构。
    • 如果启用了 MoE 架构,moe_path 被设置为 '_moe';否则,设置为一个空字符串。
python 复制代码
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径
  • 构造保存路径 :构造一个保存模型检查点的文件路径。
    • args.save_dir:保存目录的路径。
    • pretrain_{lm_config.dim}:文件名的前缀,其中 lm_config.dim 表示模型的隐藏层维度。
    • {moe_path}:根据是否使用 MoE 架构添加的后缀。
    • .pth:文件扩展名,表示这是一个 PyTorch 的模型文件。
python 复制代码
if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
  • 检查模型类型 :检查模型是否是 torch.nn.parallel.DistributedDataParallel 类型。DistributedDataParallel 是 PyTorch 中用于分布式训练的一个包装类,它会将模型的参数和梯度分布在多个 GPU 上。
    • 如果模型是 DistributedDataParallel 类型,则需要保存其内部的 module 的状态字典,因为 DistributedDataParallel 是对原始模型的一个包装。
python 复制代码
state_dict = model.module.state_dict()
  • 获取状态字典 :如果模型是 DistributedDataParallel 类型,则调用 model.module.state_dict() 获取模型的状态字典。状态字典包含了模型的所有参数和缓冲区。
python 复制代码
else:
    state_dict = model.state_dict()
  • 获取状态字典 :如果模型不是 DistributedDataParallel 类型,则直接调用 model.state_dict() 获取模型的状态字典。
python 复制代码
torch.save(state_dict, ckp) # 保存模型状态字典到文件
  • 保存模型 :使用 torch.save() 函数将模型的状态字典保存到之前构造的文件路径 ckp 中。这样就可以在后续的训练或推理中加载该模型。
python 复制代码
model.train() # 将模型切换回训练模式
  • 恢复模型模式:将模型重新切换回训练模式。在训练模式下,模型的行为会恢复到适合训练的状态,例如 Dropout 层会重新启用随机丢弃,BatchNorm 层会使用当前批次的统计信息。

总结

这段代码的主要逻辑是:

  1. 条件判断:检查是否满足保存模型的条件(步数间隔和分布式训练的主进程判断)。
  2. 模型模式切换:将模型切换到评估模式,以确保保存的模型状态是确定性的。
  3. 构造保存路径:根据模型的配置构造保存路径。
  4. 保存模型状态字典:根据模型是否是分布式数据并行模型,获取相应的状态字典并保存到文件中。
  5. 恢复模型模式:将模型切换回训练模式,以便继续训练。

保存模型检查点在训练循环中非常重要,它确保了模型可以在训练过程中定期保存,以便后续的恢复、评估或部署。

相关推荐
西安邮电大学3 分钟前
有关数组的经典算法题
java·后端·其他·算法·面试
触底反弹43 分钟前
一文彻底搞懂 JavaScript 栈和队列(建议收藏)
javascript·算法·面试
AI人工智能+电脑小能手1 小时前
【大白话说Java面试题 第113题】【并发篇】第13题:说一下乐观锁的优点和缺点?
java·开发语言·面试
Mahir081 小时前
HashMap 底层原理深度解密:从数据结构到 JDK1.7/1.8 演进全解
java·后端·面试·hashmap
uhakadotcom1 小时前
get_event_loop(),和 get_running_loop() + ThreadPoolExecutor 有啥区别
后端·面试·github
牛油果子哥q2 小时前
二叉树(Binary Tree)零基础精讲,树基础概念、树形分类、核心性质、递归/层序遍历、完整代码与面试考点全解
c++·面试·数据挖掘
牛油果子哥q2 小时前
队列(Queue)深度精讲,先进先出原理、顺序/链式/循环队列、STL queue底层、栈队列互模拟与面试考点全解
开发语言·c++·面试
Mahir083 小时前
ConcurrentHashMap 底层原理深度解密:从分段锁到 CAS + 红黑树的演进全解
java·面试·concurhashmap
刀法如飞3 小时前
《理解道德经》简单版第 3 章:不尚贤,使民不争
面试·程序员·创业
kyriewen13 小时前
手写 Promise.all、race、any:不到 30 行代码,解决并发异步的所有姿势
前端·javascript·面试