「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。# 「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。

相关推荐
后端小肥肠3 小时前
效率狂飙!n8n 无人值守工作流,每天自动把领域最新热点做成小红书卡片存本地
人工智能·agent·mcp
CoderLiu3 小时前
LLM API 成本的 3 个秘密:如何让服务商为你的复杂推理买单
人工智能·llm
siriuuus3 小时前
MySQL 慢查询日志及优化
mysql·1024程序员节
筵陌3 小时前
MYSQL表的操作
数据库·mysql·1024程序员节
AI人工智能+3 小时前
智能文本抽取:通过OCR、自然语言处理等多项技术,将非结构化文档转化为可读、可分析的数据资产
人工智能·nlp·ocr·文本抽取
这张生成的图像能检测吗3 小时前
(论文速读)Anyattack: 面向视觉语言模型的大规模自监督对抗性攻击
人工智能·语言模型·clip·视觉语言模型·对抗攻击
gorgeous(๑>؂<๑)3 小时前
【DeepSeek-OCR系列第一篇】Language Modelling with Pixels【ICLR23】
人工智能·语言模型·自然语言处理·ocr
开放知识图谱3 小时前
论文浅尝 | LightPROF:一种轻量级推理框架,用于大型语言模型在知识图谱上的应用(AAAI2025)
人工智能·语言模型·自然语言处理·知识图谱
vlln3 小时前
【论文速读】LLM+AL: 用符号逻辑校准语言模型的规划能力
人工智能·语言模型·自然语言处理