「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。# 「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。

相关推荐
IT_陈寒16 小时前
JavaScript性能优化:10个V8引擎隐藏技巧让你的代码快30%
前端·人工智能·后端
Dev7z16 小时前
基于图像处理技术的智能答题卡识别与评分系统设计与实现
图像处理·人工智能
掘金安东尼17 小时前
本地模型 + 云端模型的 Hybrid Inference 架构设计:下一代智能系统的底层范式
人工智能
强盛小灵通专卖员17 小时前
煤矿传送带异物检测:深度学习引领煤矿安全新革命!
人工智能·目标检测·sci·研究生·煤矿安全·延毕·传送带
学历真的很重要17 小时前
PyTorch 零基础入门:从张量到 GPU 加速完全指南
人工智能·pytorch·后端·深度学习·语言模型·职场和发展
mit6.82417 小时前
[Column] Perplexity 如何构建 AI 版 Google | 模型无关架构 | Vespa AI检索
人工智能
xier_ran17 小时前
深度学习:梯度检验(Gradient Checking)
人工智能·深度学习·梯度检验
尼古拉斯·纯情暖男·天真·阿玮17 小时前
基于卷积神经网络的手写数字识别
人工智能·神经网络·cnn
2401_8414956417 小时前
MoE算法深度解析:从理论架构到行业实践
人工智能·深度学习·机器学习·自然语言处理·大语言模型·moe·混合专家模型
kanimito17 小时前
大语言模型入门指南:从科普到实战的技术笔记(2)
人工智能·笔记·语言模型