「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。# 「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。

相关推荐
NAGNIP7 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab8 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab8 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP12 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年12 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼12 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS12 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区13 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈13 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang14 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx