【NCCL】8 PAT AllGather 设备端实现详解3

PatAGAlgorithm::getNextOp() 详解

PatAGAlgorithm::getNextOp() 函数的算法规划是 PAT (Pipelined Allgather Tree) 算法的核心调度器。

算法概述

PAT AllGather 使用**二叉树(binomial tree)结构,通过 聚合(aggregation)流水线(pipelining)**优化,在 O(log N) 步骤内完成 AllGather 操作。

关键成员变量

状态变量

cpp 复制代码
int a;              // 当前聚合步骤内的子步骤索引
int as;             // 聚合步骤索引(aggregated steps)
int phase;          // 当前阶段(0, 1, 2)
int scale;          // 缩放因子(用于阶段2)
size_t offset;      // 当前数据块偏移

配置参数

cpp 复制代码
int aggFactor;      // 聚合因子(决定聚合多少步骤)
int aggDelta;       // 聚合增量(nrPow2 / aggFactor)
int parallelFactor; // 并行因子(工作组数)
int postFreq;       // 发布频率(等于 aggFactor)
int nrPow2;         // >= nranks 的最小 2 的幂

算法初始化

构造函数 (PatAGAlgorithm())

cpp 复制代码
PatAGAlgorithm(int stepSize, int stepDepth, int maxParallelFactor, 
               size_t offset, size_t end, size_t count, int chunkCount, 
               int rank, int nranks) {
  parallelFactor = maxParallelFactor;
  aggDelta = nrPow2 = (1<<log2Up(nranks));  // 例如:nranks=6 → nrPow2=8
  
  // 计算聚合因子(基于缓冲区大小)
  aggFactor = 1;
  size_t channelSize = end-offset;
  while (stepSize / (channelSize*sizeof(T)*aggFactor) >= 2 && aggFactor < nranks/2) {
    aggFactor *= 2;  // 增加聚合
    aggDelta /= 2;   // 减少增量
  }
  postFreq = aggFactor;
  if (postFreq < parallelFactor) parallelFactor = postFreq;
  
  // 根据步骤深度进一步调整聚合
  int d = stepDepth;
  while (d > 1 && aggFactor < nranks/2) {
    d /= 2;
    aggFactor *= 2;
    aggDelta /= 2;
  }
  
  asDim = log2Up(aggDelta);  // as 的维度
  reset();
}

聚合因子的作用

  • 将多个小步骤聚合成一个大步骤
  • 减少同步开销
  • 提高缓冲区利用率

三个阶段(Phase)

PAT AllGather 分为三个主要阶段:

Phase 0: 初始接收阶段

cpp 复制代码
else if (phase == 0) {
  int s = a*aggDelta + as;  // 计算距离
  if (s >= nranks) skip = 1;
  
  int recvDataRank = (rank + s) % nranks;  // 接收数据的源 rank
  ps->outIx = recvDataRank * count + offset;  // 输出位置
  
  // 只接收,不发送
  ps->sendDim = -1;
  ps->recvDim = 0;  // 从维度 0(距离 1)接收
  ps->inpIx = 0;
  ps->sendOffset = -1;
  ps->recvOffset = (a % postFreq) * nelem;  // 接收缓冲区偏移
  ps->stepOffset = 0;
  
  // 决定是否发布接收完成
  ps->postRecv = (a % postFreq == postFreq-1) || ((a+1)*aggDelta+as >= nranks) ? 1 : 0;
  ps->postSend = 0;
}

Phase 0 的目的

  • 从最近的邻居(距离 1)接收初始数据
  • 为后续的二叉树传输做准备

Phase 1: 主要二叉树传输阶段

cpp 复制代码
else if (phase == 1) {
  int s = a*aggDelta + as;  // 计算距离
  if (s >= nranks) skip = 1;
  
  // 计算发送维度(log2(距离))
  ps->sendDim = firstBitSet(s, nrPow2);  // 例如:s=4 → sendDim=2
  s -= (1<<ps->sendDim);  // 减去发送距离
  
  int sendDataRank = (rank + nranks + s) % nranks;
  ps->outIx = sendDataRank * count + offset;
  
  // 计算接收维度
  ps->recvDim = s ? firstBitSet(s, nrPow2) : -1;
  
  // 设置偏移和发布标志
  ps->sendOffset = ps->recvOffset = (a % postFreq) * nelem;
  ps->postSend = (a % postFreq == postFreq-1) || ((a+1)*aggDelta+as >= nranks) ? 1 : 0;
  ps->postRecv = (ps->sendDim == 0) && ((a % postFreq == postFreq-1) || ((a+1)*aggDelta+as-1 >= nranks)) ? 1 : 0;
  ps->stepOffset = (ps->sendDim == 0) ? 0 : a/postFreq;
  
  // 特殊情况:接收维度为 -1
  if (ps->recvDim == -1) {
    ps->recvOffset = -1;
    ps->postRecv = 0;
  } else if (as - (1<<ps->sendDim) == 0) {
    // 调整接收偏移
    int foffset = (a*aggDelta) >> (ps->recvDim+1);
    ps->recvOffset = (foffset%postFreq)*nelem;
    ps->postRecv = (ps->sendDim == 0) && ((foffset % postFreq == postFreq-1) || ((((foffset+1)*2)+1)<<ps->recvDim) >= nranks) ? 1 : 0;
    ps->stepOffset = (ps->sendDim == 0) ? 0 : foffset/postFreq;
  }
  
  // 边界情况:确保至少接收一次
  if (s < nranks && ps->sendDim == 0 && skip) {
    ps->sendDim = -1;
    ps->sendOffset = -1;
    ps->postSend = 0;
    skip = 0;
  }
}

Phase 1 的核心逻辑

  • 同时进行接收和发送
  • 使用二叉树结构(通过 firstBitSet 计算维度)
  • 支持聚合多个步骤

Phase 2: 缩放传输阶段

cpp 复制代码
else if (phase == 2) {
  int s = (2*a+1)*scale*aggDelta;  // 计算距离(奇数倍)
  ps->postSend = (a % postFreq == postFreq-1) || ((2*(a+1)+1)*scale*aggDelta >= nranks) ? 1 : 0;
  ps->postRecv = 0;
  
  if (s >= nranks) skip = 1;
  
  ps->sendDim = firstBitSet(s, nrPow2);
  s -= (1<<ps->sendDim);
  ps->sendOffset = (a%postFreq) * nelem;
  ps->stepOffset = a / postFreq;
  
  int sendDataRank = (rank + nranks + s) % nranks;
  ps->outIx = sendDataRank * count + offset;
  
  ps->recvDim = s ? firstBitSet(s, nrPow2) : -1;
  if (ps->recvDim == -1) {
    ps->recvOffset = -1;
  } else {
    s -= (1<<ps->recvDim);
    int foffset = (a*2*scale*aggDelta) >> (ps->recvDim+1);
    ps->recvOffset = (foffset%postFreq)*nelem;
    ps->stepOffset = foffset / postFreq;
  }
}

Phase 2 的目的

  • 处理更大距离的传输(奇数倍)
  • 使用缩放因子 scale 逐步减小
  • 填补二叉树的空隙

阶段转换逻辑

阶段转换条件 (getNextOp())

cpp 复制代码
a++;  // 递增子步骤
if (a >= lastA && a >= parallelFactor) {
  int p = phase;
  if (p == 2) scale /= 2;  // Phase 2 时缩小 scale
  
  // 阶段转换状态机
  phase =
    p == 2 ? scale ? 2 : 1 :           // Phase 2: scale>0 继续 Phase 2,否则转 Phase 1
    p == 1 ? as % 2 == 1 ? 0 : 1 :     // Phase 1: as 奇数转 Phase 0,偶数继续 Phase 1
    1;                                  // Phase 0: 转 Phase 1
  
  // 更新 as(聚合步骤索引)
  if (p == 0 || (p == 1 && as % 2 == 0)) 
    as = nextAs();  // 计算下一个 as 值
  
  // 检查是否完成所有数据
  if (p == 0 && as == aggDelta/2) {
    offset += chunkCount;  // 移动到下一个数据块
    if (offset >= end) {
      ps->last = 2;  // 全局完成
    } else {
      reset();  // 重置,处理下一个数据块
    }
  } else {
    resetA();  // 重置 a,继续当前数据块
  }
}

提前结束条件

cpp 复制代码
else if (phase == 0 && as == 1 && offset + chunkCount >= end && 
         a-1 >= ((lastA-1) / parallelFactor) * parallelFactor) {
  ps->last = 1;  // 本组完成(但其他组可能还在运行)
}

关键辅助函数

1. nextAs() - 计算下一个聚合步骤索引

cpp 复制代码
__device__ __host__ int nextAs() {
  for (int d=0; d<asDim; d++) {
    int p = 1<<d;  // 2^d
    bitCount[d]--;
    if (bitCount[d] == 0) {
      v ^= p;  // 翻转第 d 位
      bitCount[d] = p;  // 重置计数
      if ((v&p) == 0) {  // 如果第 d 位变为 0
        bitCount[d] += firstBitSet(bitZeroStep[d], asDim) - 1;
        if (bitCount[d] == 0) {
          v ^= p;
          bitCount[d] = p;
        }
        bitZeroStep[d]++;
      }
    }
  }
  return v;
}

nextAs() 的作用

  • 生成一个特殊的序列,用于遍历所有可能的 as
  • 使用位操作和计数器实现高效的序列生成
  • 确保覆盖所有必要的通信模式

2. mirror() - 位镜像

cpp 复制代码
__device__ __host__ int mirror(int i, int max) {
  int ret = 0;
  for (int mask=1, imask=max/2; mask<max; mask<<=1, imask>>=1) {
    if ((i&mask)) ret += imask;  // 如果第 k 位为 1,则在结果的 (log2(max)-k-1) 位置 1
  }
  return ret;
}

示例

复制代码
max = 8 (二进制: 1000)
i = 3 (二进制: 011)
mirror(3, 8) = 6 (二进制: 110)

解释:
  i:   0 1 1
       ↓ ↓ ↓
  ret: 1 1 0  = 6

3. firstBitSet() - 找到最低位的 1

cpp 复制代码
__device__ __host__ int firstBitSet(int i, int max) {
  int ffs = __ffs(i);  // CUDA 内置函数,返回最低位 1 的位置(1-based)
  return ffs ? ffs-1 : max;  // 转换为 0-based
}

示例

复制代码
i = 12 (二进制: 1100)
firstBitSet(12, 32) = 2  // 第 2 位(从 0 开始)是最低位的 1

算法执行流程示例

示例:8 个节点,aggFactor=1(无聚合)

初始化
复制代码
nranks = 8
nrPow2 = 8
aggFactor = 1
aggDelta = 8
parallelFactor = 1
Rank 0 的执行序列

Phase 0 (初始接收):

复制代码
Step 0: s=0, recvDataRank=0, recvDim=0
  → 从 Rank 1 接收数据(距离 1)

Phase 1 (二叉树传输):

复制代码
Step 1: s=1, sendDim=0, recvDim=-1
  → 发送到 Rank 7(距离 1),无接收

Step 2: s=2, sendDim=1, recvDim=0
  → 发送到 Rank 6(距离 2),从 Rank 2 接收(距离 2)

Step 3: s=3, sendDim=0, recvDim=1
  → 发送到 Rank 5(距离 1),从 Rank 3 接收(距离 3)

Step 4: s=4, sendDim=2, recvDim=0
  → 发送到 Rank 4(距离 4),从 Rank 4 接收(距离 4)

Step 5: s=5, sendDim=0, recvDim=2
  → 发送到 Rank 3(距离 1),从 Rank 5 接收(距离 5)

Step 6: s=6, sendDim=1, recvDim=0
  → 发送到 Rank 2(距离 2),从 Rank 6 接收(距离 6)

Step 7: s=7, sendDim=0, recvDim=1
  → 发送到 Rank 1(距离 1),从 Rank 7 接收(距离 7)

距离计算公式

Phase 0
cpp 复制代码
s = a*aggDelta + as
recvPeer = (rank + s) % nranks
Phase 1
cpp 复制代码
s = a*aggDelta + as
sendDim = firstBitSet(s, nrPow2)  // log2(最低位的 1)
sendPeer = (rank - (1<<sendDim)) % nranks
recvPeer = (rank + (s - (1<<sendDim))) % nranks
Phase 2
cpp 复制代码
s = (2*a+1)*scale*aggDelta
sendDim = firstBitSet(s, nrPow2)
sendPeer = (rank - (1<<sendDim)) % nranks
recvPeer = (rank + (s - (1<<sendDim) - (1<<recvDim))) % nranks

聚合机制

聚合的目的

将多个小步骤聚合成一个大步骤,减少同步开销:

复制代码
不聚合(aggFactor=1):
  Step 0: 发送 1 个块
  Step 1: 发送 1 个块
  Step 2: 发送 1 个块
  ...
  每步都需要同步

聚合(aggFactor=4):
  Step 0-3: 发送 4 个块(一次性)
  Step 4-7: 发送 4 个块
  ...
  每 4 步同步一次

发布频率(postFreq)

cpp 复制代码
ps->postSend = (a % postFreq == postFreq-1) || ... ? 1 : 0;
ps->postRecv = (a % postFreq == postFreq-1) || ... ? 1 : 0;
  • postFreq 控制多久发布一次步骤完成
  • 等于 aggFactor
  • 减少对等节点间的同步次数

缓冲区偏移计算

接收偏移

cpp 复制代码
ps->recvOffset = (a % postFreq) * nelem;

发送偏移

cpp 复制代码
ps->sendOffset = (a % postFreq) * nelem;

步骤偏移

cpp 复制代码
ps->stepOffset = (ps->sendDim == 0) ? 0 : a/postFreq;

作用

  • 在循环缓冲区中定位数据
  • 支持流水线操作
  • 避免缓冲区冲突

数据流示例(4 个节点)

初始状态

复制代码
Rank 0: [A, _, _, _]
Rank 1: [_, B, _, _]
Rank 2: [_, _, C, _]
Rank 3: [_, _, _, D]

Phase 0(初始接收)

复制代码
Step 0 (s=0, recvDim=0):
  Rank 0 ← Rank 1: [A, B, _, _]
  Rank 1 ← Rank 0: [A, B, _, _]
  Rank 2 ← Rank 3: [_, _, C, D]
  Rank 3 ← Rank 2: [_, _, C, D]

Phase 1(二叉树传输)

复制代码
Step 1 (s=1, sendDim=0, recvDim=-1):
  Rank 0 → Rank 3: 发送 [A, B]
  Rank 1 → Rank 2: 发送 [A, B]
  Rank 2 → Rank 1: 发送 [C, D]
  Rank 3 → Rank 0: 发送 [C, D]

Step 2 (s=2, sendDim=1, recvDim=0):
  Rank 0 → Rank 2: 发送 [A, B, C, D],← Rank 2: 接收 [C, D]
  Rank 1 → Rank 3: 发送 [A, B, C, D],← Rank 3: 接收 [C, D]
  Rank 2 → Rank 0: 发送 [A, B, C, D],← Rank 0: 接收 [A, B]
  Rank 3 → Rank 1: 发送 [A, B, C, D],← Rank 1: 接收 [A, B]

结果:所有节点都有 [A, B, C, D]

优化技术

1. 流水线并行

  • 多个工作组并行处理不同步骤
  • parallelFactor 控制并行度
  • 每组从 step = group 开始,step += nGroups 递增

2. 聚合优化

  • 减少同步次数
  • 提高缓冲区利用率
  • 通过 aggFactorpostFreq 控制

3. 跳过优化

cpp 复制代码
if (a >= lastA) skip = 1;  // 超出范围,跳过
if (s >= nranks) skip = 1;  // 距离超出节点数,跳过

4. 边界处理

cpp 复制代码
if (s < nranks && ps->sendDim == 0 && skip) {
  // 确保至少接收一次,即使不需要发送
  ps->sendDim = -1;
  skip = 0;
}

算法复杂度分析

时间复杂度

  • 步骤数: O(log N)(二叉树深度)
  • 每步操作: O(1)(并行执行)
  • 总时间: O(log N)

空间复杂度

  • 共享内存: O(log N)(存储对等节点信息)
  • 步骤缓冲 : O(1)(循环使用 NCCL_SHMEM_PAT_STEPS

总结

getNextOp() 的核心功能

  1. 计算距离 (s): 根据 a, as, aggDelta, scale 计算通信距离
  2. 确定维度 : 使用 firstBitSet() 将距离转换为二叉树维度
  3. 计算对等节点: 根据维度和 rank 计算发送/接收的对等节点
  4. 设置偏移: 计算缓冲区偏移,支持流水线和聚合
  5. 控制发布: 决定何时通知对等节点步骤完成
  6. 阶段转换: 管理 Phase 0 → 1 → 2 的转换
  7. 边界处理 : 处理 nranks 不是 2 的幂的情况

这个算法通过精心设计的状态机和位操作,实现了高效的多节点 AllGather 操作,特别适合每节点 1 GPU 的场景。

相关推荐
predawnlove6 天前
【NCCL】4 AllGather-PAT算法
算法·gpu·nccl
predawnlove6 天前
【NCCL】5 GPU 间链路 Preconnect 机制
gpu·nccl
predawnlove7 天前
【NCCL】3. ncclPrepareTasks 到 scheduleCollTasksToPlan 的衔接机制
gpu·nccl·通信库
Luchang-Li4 个月前
sglang pytorch NCCL hang分析
pytorch·python·nccl
小马敲马5 个月前
[4.2-2] NCCL新版本的register如何实现的?
开发语言·c++·人工智能·算法·性能优化·nccl
caodongwang8 个月前
【NCCL】transport建立(一)
p2p·rdma·nccl·transport
跑步去兜风1 年前
RCCL/NCCL中的Transports方式选择:P2P or SHM or NET
服务器·p2p·nccl·shm·rccl
Eloudy1 年前
NCCL 中的一些辅助debug 知识点
nvlink·nccl
Hi202402171 年前
将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap
pytorch·python·性能优化·分布式训练·nccl·融合算子