深入理解DeepSpeed模型检查点机制与功能


模型检查点简介

DeepSpeed为训练过程中的模型状态提供了强大的检查点功能,使得研究者能够安全地保存和恢复训练进度,便于调试、故障恢复以及在不同环境下复现训练结果。本文将详细介绍如何使用DeepSpeed进行模型检查点的加载与保存,以及如何从ZeRO检查点中提取fp32权重,避免检查点膨胀等问题。


加载训练检查点

DeepSpeed通过deepspeed.DeepSpeedEngine.load_checkpoint方法加载训练检查点。该方法接受多个参数来控制加载行为,包括:

python 复制代码
deepspeed.DeepSpeedEngine.load_checkpoint(
    self,
    load_dir,                  # 必填,要从中加载检查点的目录
    tag=None,                  # 可选,用于唯一标识检查点的标签,默认尝试加载'latest'文件中的标签
    load_module_strict=True,   # 可选,是否严格要求模块状态字典与检查点中的键匹配
    load_optimizer_states=True, # 可选,是否从检查点加载训练优化器状态(如ADAM的动量和方差)
    load_lr_scheduler_states=True, # 可选,是否添加来自检查点的学习率调度器状态
    load_module_only=False,     # 可选,是否仅加载模型权重(如warmstarting)
    custom_load_fn=None         # 可选,自定义模型加载函数
)

返回值包括已加载检查点的路径和客户端状态字典。需要注意的是,在ZeRO-3模式下,不能直接在.save_checkpoint()之后立即调用.load_checkpoint(),因为模型已被分区,且.load_checkpoint()需要原始未分区的模型。


保存训练检查点

使用deepspeed.DeepSpeedEngine.save_checkpoint方法保存训练检查点,所有进程都必须调用此方法,而不仅仅是rank为0的进程。这是因为每个进程都需要保存其主权重以及调度器和优化器的状态。

python 复制代码
deepspeed.DeepSpeedEngine.save_checkpoint(
    self,
    save_dir,                 # 必填,保存检查点的目录
    tag=None,                 # 可选,检查点标签,如果不提供则使用全局步数作为唯一标识符
    client_state={},          # 可选,用于在客户端代码中保存所需训练状态的状态字典
    save_latest=True,          # 可选,是否保存指向最新检查点的'latest'文件
    exclude_frozen_parameters=False  # 可选,是否排除冻结参数
)

ZeRO检查点fp32权重恢复

DeepSpeed提供了一套方法,可以从保存的ZeRO检查点的优化器状态中提取fp32权重。例如,使用deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint方法可以将ZeRO 2或3的检查点转换为单一的fp32合并状态字典,然后加载到非DeepSpeed环境下的模型中。

另外,deepspeed.utils.zero_to_fp32.load_state_dict_from_zero_checkpoint方法更为简便,它可以将提供的模型置于CPU上,将ZeRO检查点转换为fp32状态字典并加载到模型中。注意,一旦执行这个操作,模型就不能再在相同的DeepSpeed上下文中使用,需要重新初始化DeepSpeed引擎。


避免ZeRO检查点膨胀问题

有时,使用torch.save()创建的ZeRO阶段1和2的检查点可能会比预期更大。这个问题是由于ZeRO的张量展平与PyTorch张量存储管理之间的交互导致的。为了避免这种情况,可以使用DeepSpeed的deepspeed.checkpoint.utils.clone_tensors_for_torch_save工具函数。例如,在创建HuggingFace模型检查点时,可以先克隆张量再进行保存,有效减小检查点大小。


总结起来,DeepSpeed的检查点机制为训练提供了极大的灵活性和便利性,无论是加载、保存还是从ZeRO检查点中提取fp32权重,都能确保训练过程的安全性和可靠性。同时,通过合理利用DeepSpeed提供的工具函数,还可以有效地解决检查点膨胀问题,实现高效模型管理和迁移。

相关推荐
武子康4 分钟前
大数据-212 数据挖掘 机器学习理论 - 无监督学习算法 KMeans 基本原理 簇内误差平方和
大数据·人工智能·学习·算法·机器学习·数据挖掘
passer__jw76733 分钟前
【LeetCode】【算法】283. 移动零
数据结构·算法·leetcode
Ocean☾39 分钟前
前端基础-html-注册界面
前端·算法·html
顶呱呱程序1 小时前
2-143 基于matlab-GUI的脉冲响应不变法实现音频滤波功能
算法·matlab·音视频·matlab-gui·音频滤波·脉冲响应不变法
爱吃生蚝的于勒1 小时前
深入学习指针(5)!!!!!!!!!!!!!!!
c语言·开发语言·数据结构·学习·计算机网络·算法
羊小猪~~1 小时前
数据结构C语言描述2(图文结合)--有头单链表,无头单链表(两种方法),链表反转、有序链表构建、排序等操作,考研可看
c语言·数据结构·c++·考研·算法·链表·visual studio
王哈哈^_^2 小时前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
星沁城2 小时前
240. 搜索二维矩阵 II
java·线性代数·算法·leetcode·矩阵
脉牛杂德2 小时前
多项式加法——C语言
数据结构·c++·算法
legend_jz2 小时前
STL--哈希
c++·算法·哈希算法