PyTorch FSDP: Saving and Merging Sharded Weights

Note: the approach in this article only applies to models trained with PyTorch FSDP1, with the state dict type set to SHARDED_STATE_DICT.

When training a model with FSDP, the model weights are usually sharded as well to save GPU memory. When saving a checkpoint, each process typically saves only the shard of the weights it holds, which is much faster than first gathering the full weights on the main process and saving from there. Once the checkpoint has been written as shards, the shards must be merged back into a single state dict before the model can be used for inference. The code below shows how to save sharded weights and how to merge them; it is largely based on the accelerate source code.

python
import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils


def save_fsdp_model(model: FSDP, fsdp_ckpt_path: str):
    # Adapted from accelerate/utils/fsdp_utils.py:save_fsdp_model.
    # Must be called on every rank: each rank writes only the weight shards it holds.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        os.makedirs(fsdp_ckpt_path, exist_ok=True)

        state_dict = {"model": model.state_dict()}
        dist_cp.save(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(fsdp_ckpt_path),
            planner=DefaultSavePlanner(),
        )


def merge_fsdp_weights(fsdp_ckpt_path: str, save_path: str):
    # Adapted from accelerate/utils/fsdp_utils.py:merge_fsdp_weights.
    # Runs in a single process (no_dist=True), so no distributed launcher is needed.
    state_dict = {}
    dist_cp_format_utils._load_state_dict(
        state_dict,
        storage_reader=dist_cp.FileSystemReader(fsdp_ckpt_path),
        planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),
        no_dist=True,
    )

    # Unwrap the top-level key if the loaded state dict is nested like {"model": {...}}
    if len(state_dict.keys()) == 1:
        state_dict = state_dict[list(state_dict)[0]]

    torch.save(state_dict, save_path)
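
For reference, here is a minimal usage sketch. `MyModel`, the checkpoint paths, and the FSDP wrapping details are hypothetical placeholders; the save call must run on every rank of the distributed training job, while the merge and the final load can run in a single ordinary Python process.

python
# Minimal usage sketch; MyModel and all paths are hypothetical placeholders.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# --- During training (launched with torchrun, executed on every rank) ---
fsdp_model = FSDP(MyModel().cuda())
# ... training loop ...
save_fsdp_model(fsdp_model, "checkpoints/step_1000")  # each rank writes its own shards

# --- After training (a single plain Python process) ---
merge_fsdp_weights("checkpoints/step_1000", "checkpoints/merged_model.pt")

# --- Inference: load the merged full state dict into an unwrapped model ---
model = MyModel()
model.load_state_dict(torch.load("checkpoints/merged_model.pt", map_location="cpu"))
model.eval()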