「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。# 「赤兔」Chitu 框架深度解读(十二):分布式并行初始化与管理

大模型训练和推理通常需要在多个设备(GPU/NPU)上并行进行。「赤兔」Chitu 框架支持多种并行策略,包括张量并行 (TP)、流水线并行 (PP)、数据并行 (DP) 和专家并行 (EP)。其分布式并行环境的初始化和管理由 distributed/parallel_state.pydistributed/comm_group.py 模块负责。

核心概念:CommGroup

distributed/comm_group.py 定义了 CommGroup 类,它是对 PyTorch ProcessGroup 的封装和扩展。

  • 初始化 : CommGroup 根据传入的 rank_list (一个包含多个 Rank 列表的列表,每个子列表代表一个通信组) 和当前进程的全局 rank 来创建对应的 ProcessGroup
  • 关键属性 :
    • group: 底层的 PyTorch ProcessGroup
    • cpu_group: 对应的 CPU ProcessGroup (用于 CPU 上的集合通信)。
    • ranks_in_group: 当前 CommGroup 包含的所有 Rank 列表。
    • group_size: 当前进程所在通信组的大小。
    • rank_in_group: 当前进程在所在通信组内的局部 Rank。
    • is_first_rank/is_last_rank: 判断当前进程是否是组内的第一个/最后一个 Rank。
  • 通信操作封装 : 提供了对 torch.distributed 常用通信原语(如 broadcast, all_reduce, all_gather, reduce_scatter 等)的封装,自动传入正确的 group 参数。

CommGroup 的设计简化了在不同并行维度上进行通信的操作,使得上层代码无需手动管理多个 ProcessGroup 对象。

并行状态管理 (parallel_state.py)

distributed/parallel_state.py 负责初始化和维护不同并行维度的 CommGroup 实例,并提供全局访问接口。

  • 全局变量 : 定义了 _WORLD_GROUP, _TP_GROUP, _PP_GROUP, _DP_GROUP, _EP_GROUP 等全局变量,用于存储各个并行维度的 CommGroup 实例。
  • 初始化函数 (initialize_parallel_groups) : 这是并行设置的核心入口。
    • 输入 : TP, PP, DP, EP 的大小 (tp_size, pp_size, dp_size, ep_size)。
    • 获取环境信息 : 获取全局 rank, local_rank, world_size
    • 按序初始化 : 依次调用 initialize_world_group, initialize_tp_group, initialize_pp_group, initialize_dp_group, initialize_ep_group
    • 初始化逻辑 : 每个 initialize_*_group 函数根据并行维度的大小和当前 rank 计算出该维度对应的 rank_list,然后创建 CommGroup 实例并赋值给相应的全局变量。例如:
      • initialize_tp_group: world_size 被划分为 world_size // tp_size 个 TP 组,每个组包含 tp_size 个连续的 Rank。
      • initialize_pp_group: world_size 被划分为 world_size // pp_size 个 PP 组,每个组包含跨 TP 和 DP 维度、间隔为 num_pp_groups 的 Rank。
      • initialize_dp_group: 类似 PP 组的划分方式。
      • initialize_ep_group: 逻辑稍复杂。如果 ep_size > 1
        • tp_size == ep_sizedp_size == 1,则 EP 组直接复用 TP 组 (_EP_GROUP = _TP_GROUP)。
        • dp_size == ep_sizetp_size == 1,则 EP 组直接复用 DP 组 (_EP_GROUP = _DP_GROUP)。
        • 否则,创建新的 EP 通信组,通常是连续的 Rank 组成。
        • 如果 ep_size == 1,则每个 Rank 自己构成一个 EP 组。
    • 特殊处理 : initialize_pp_group 中包含了针对 Ascend NPU 的特殊处理,为流水线相邻 Stage 之间创建了额外的 Pair Group (_PP_PAIR_GROUP_DICT),可能是为了优化 P2P 通信。
  • 访问接口 : 提供 get_world_group(), get_tp_group(), get_pp_group(), get_dp_group(), get_ep_group(), get_tp_size(), get_dp_size(), get_ep_size() 等函数,方便全局访问并行状态信息和通信组。
  • 销毁 : destroy_parallel_groups() 负责销毁创建的通信组。

使用流程

  1. 在程序启动时,根据配置确定 TP, PP, DP, EP 的大小。
  2. 调用 initialize_parallel_groups 初始化所有并行通信组。
  3. 在模型代码或算子实现中,通过 get_tp_group(), get_ep_group() 等接口获取相应的 CommGroup
  4. 调用 CommGroup 实例提供的通信方法(如 tp_group.all_reduce(tensor))执行集合通信。

总结

「赤兔」的分布式并行管理模块设计清晰,通过 CommGroup 封装了底层的通信细节,并通过 parallel_state 模块提供了统一的初始化入口和全局访问接口。这种设计使得在代码中实现和管理复杂的混合并行策略(如 TP+PP+DP+EP)变得更加方便和规范。对 EP 组复用 TP/DP 组以及为 NPU 创建 PP Pair Group 的特殊处理,也体现了其在特定场景下的优化考虑。

相关推荐
2501_9481201517 分钟前
区块链与人工智能融合的隐私保护技术
人工智能·区块链
Liue612312315 小时前
基于YOLOv26的口罩佩戴检测与识别系统实现与优化
人工智能·yolo·目标跟踪
小二·6 小时前
Python Web 开发进阶实战 :AI 原生数字孪生 —— 在 Flask + Three.js 中构建物理世界实时仿真与优化平台
前端·人工智能·python
chinesegf7 小时前
文本嵌入模型的比较(一)
人工智能·算法·机器学习
珠海西格电力7 小时前
零碳园区的能源结构优化需要哪些技术支持?
大数据·人工智能·物联网·架构·能源
Black蜡笔小新7 小时前
视频汇聚平台EasyCVR打造校园消防智能监管新防线
网络·人工智能·音视频
珠海西格电力科技7 小时前
双碳目标下,微电网为何成为能源转型核心载体?
网络·人工智能·物联网·云计算·智慧城市·能源
2501_941837267 小时前
【计算机视觉】基于YOLOv26的交通事故检测与交通状况分析系统详解_1
人工智能·yolo·计算机视觉
HyperAI超神经7 小时前
加州大学构建基于全连接神经网络的片上光谱仪,在芯片级尺寸上实现8纳米的光谱分辨率
人工智能·深度学习·神经网络·机器学习·ai编程
badfl7 小时前
AI漫剧技术方案拆解:NanoBanana+Sora视频生成全流程
人工智能·ai·ai作画