保存模型检查点

在训练深度学习模型的过程中,保存模型检查点是一个非常重要的步骤。它不仅可以防止训练过程中出现意外中断导致的损失,还能方便我们后续对模型进行评估和测试。

所谓先睹为快,因此我们先看代码:

py 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
    model.eval() # 将模型切换到评估模式
    moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
    ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径

    if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    torch.save(state_dict, ckp) # 保存模型状态字典到文件
    model.train() # 将模型切换回训练模式

整体分析

这段代码的主要目的是在训练过程中定期保存模型的状态字典。它通过判断当前训练步数是否满足保存间隔条件,并在满足条件时执行模型的保存操作。此外,代码还考虑了分布式训练的情况,确保只有主进程(rank 0)执行保存操作,以避免多个进程重复保存相同的模型。

逐行解释

python 复制代码
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): # 检查是否达到模型保存的间隔
  • 条件判断 :这行代码检查两个条件是否都满足:
    • (step + 1) % args.save_interval == 0:检查当前训练步数 step + 1 是否是保存间隔 args.save_interval 的整数倍。如果是,则表示达到了保存模型的条件。
    • (not ddp or dist.get_rank() == 0):检查是否处于分布式训练模式(ddp 表示是否使用分布式数据并行)。如果未使用分布式训练(not ddp),则直接满足条件;如果使用了分布式训练,则只有主进程(dist.get_rank() == 0)满足条件。这样可以确保只有主进程负责保存模型,避免多个进程重复保存。
python 复制代码
model.eval() # 将模型切换到评估模式
  • 模型模式切换 :将模型设置为评估模式。在评估模式下,模型的行为会有一些变化,例如:
    • Dropout 层:在训练模式下,Dropout 层会随机丢弃一部分神经元,以防止过拟合。但在评估模式下,Dropout 层不会丢弃任何神经元,以确保模型的输出是确定性的。
    • BatchNorm 层:在训练模式下,BatchNorm 层会使用当前批次的统计信息(均值和方差)进行归一化。而在评估模式下,BatchNorm 层会使用训练过程中累积的全局统计信息,以确保模型的输出一致。
python 复制代码
moe_path = '_moe' if lm_config.use_moe else '' # 根据是否使用 MoE 架构构造保存路径(没有用到)
  • 构造保存路径后缀 :根据模型是否使用了 MoE(Mixture of Experts)架构来构造保存路径的后缀。
    • lm_config.use_moe:这是一个布尔值,表示模型是否启用了 MoE 架构。
    • 如果启用了 MoE 架构,moe_path 被设置为 '_moe';否则,设置为一个空字符串。
python 复制代码
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth' # 构造检查点文件路径
  • 构造保存路径 :构造一个保存模型检查点的文件路径。
    • args.save_dir:保存目录的路径。
    • pretrain_{lm_config.dim}:文件名的前缀,其中 lm_config.dim 表示模型的隐藏层维度。
    • {moe_path}:根据是否使用 MoE 架构添加的后缀。
    • .pth:文件扩展名,表示这是一个 PyTorch 的模型文件。
python 复制代码
if isinstance(model, torch.nn.parallel.DistributedDataParallel): # 如果模型是分布式数据并行模型,则保存 module 的状态字典
  • 检查模型类型 :检查模型是否是 torch.nn.parallel.DistributedDataParallel 类型。DistributedDataParallel 是 PyTorch 中用于分布式训练的一个包装类,它会将模型的参数和梯度分布在多个 GPU 上。
    • 如果模型是 DistributedDataParallel 类型,则需要保存其内部的 module 的状态字典,因为 DistributedDataParallel 是对原始模型的一个包装。
python 复制代码
state_dict = model.module.state_dict()
  • 获取状态字典 :如果模型是 DistributedDataParallel 类型,则调用 model.module.state_dict() 获取模型的状态字典。状态字典包含了模型的所有参数和缓冲区。
python 复制代码
else:
    state_dict = model.state_dict()
  • 获取状态字典 :如果模型不是 DistributedDataParallel 类型,则直接调用 model.state_dict() 获取模型的状态字典。
python 复制代码
torch.save(state_dict, ckp) # 保存模型状态字典到文件
  • 保存模型 :使用 torch.save() 函数将模型的状态字典保存到之前构造的文件路径 ckp 中。这样就可以在后续的训练或推理中加载该模型。
python 复制代码
model.train() # 将模型切换回训练模式
  • 恢复模型模式:将模型重新切换回训练模式。在训练模式下,模型的行为会恢复到适合训练的状态,例如 Dropout 层会重新启用随机丢弃,BatchNorm 层会使用当前批次的统计信息。

总结

这段代码的主要逻辑是:

  1. 条件判断:检查是否满足保存模型的条件(步数间隔和分布式训练的主进程判断)。
  2. 模型模式切换:将模型切换到评估模式,以确保保存的模型状态是确定性的。
  3. 构造保存路径:根据模型的配置构造保存路径。
  4. 保存模型状态字典:根据模型是否是分布式数据并行模型,获取相应的状态字典并保存到文件中。
  5. 恢复模型模式:将模型切换回训练模式,以便继续训练。

保存模型检查点在训练循环中非常重要,它确保了模型可以在训练过程中定期保存,以便后续的恢复、评估或部署。

相关推荐
八股文领域大手子2 小时前
磁盘I/O瓶颈排查:面试通关“三部曲”心法
面试·职场和发展
大学生小郑14 小时前
Go语言八股之Mysql基础详解
mysql·面试
八股文领域大手子17 小时前
Java死锁排查:线上救火实战指南
java·开发语言·面试
XQ丶YTY18 小时前
大二java第一面小厂(挂)
java·开发语言·笔记·学习·面试
面试官E先生19 小时前
【极兔快递Java社招】一面复盘|数据库+线程池+AQS+中间件面面俱到
java·面试
独行soc1 天前
2025年渗透测试面试题总结-渗透测试红队面试九(题目+回答)
linux·安全·web安全·网络安全·面试·职场和发展·渗透测试
软件测试媛1 天前
软件测试——面试八股文(入门篇)
软件测试·面试·职场和发展
牛马baby1 天前
Java高频面试之并发编程-17
java·开发语言·面试
chenyuhao20242 天前
链表的面试题4之合并有序链表
数据结构·链表·面试·c#
PgSheep2 天前
深入理解 JVM:StackOverFlow、OOM 与 GC overhead limit exceeded 的本质剖析及 Stack 与 Heap 的差异
jvm·面试