PatAGAlgorithm::getNextOp() 详解
PatAGAlgorithm::getNextOp() 函数的算法规划是 PAT (Pipelined Allgather Tree) 算法的核心调度器。
算法概述
PAT AllGather 使用**二叉树(binomial tree)结构,通过 聚合(aggregation)和流水线(pipelining)**优化,在 O(log N) 步骤内完成 AllGather 操作。
关键成员变量
状态变量
cpp
int a; // 当前聚合步骤内的子步骤索引
int as; // 聚合步骤索引(aggregated steps)
int phase; // 当前阶段(0, 1, 2)
int scale; // 缩放因子(用于阶段2)
size_t offset; // 当前数据块偏移
配置参数
cpp
int aggFactor; // 聚合因子(决定聚合多少步骤)
int aggDelta; // 聚合增量(nrPow2 / aggFactor)
int parallelFactor; // 并行因子(工作组数)
int postFreq; // 发布频率(等于 aggFactor)
int nrPow2; // >= nranks 的最小 2 的幂
算法初始化
构造函数 (PatAGAlgorithm())
cpp
PatAGAlgorithm(int stepSize, int stepDepth, int maxParallelFactor,
size_t offset, size_t end, size_t count, int chunkCount,
int rank, int nranks) {
parallelFactor = maxParallelFactor;
aggDelta = nrPow2 = (1<<log2Up(nranks)); // 例如:nranks=6 → nrPow2=8
// 计算聚合因子(基于缓冲区大小)
aggFactor = 1;
size_t channelSize = end-offset;
while (stepSize / (channelSize*sizeof(T)*aggFactor) >= 2 && aggFactor < nranks/2) {
aggFactor *= 2; // 增加聚合
aggDelta /= 2; // 减少增量
}
postFreq = aggFactor;
if (postFreq < parallelFactor) parallelFactor = postFreq;
// 根据步骤深度进一步调整聚合
int d = stepDepth;
while (d > 1 && aggFactor < nranks/2) {
d /= 2;
aggFactor *= 2;
aggDelta /= 2;
}
asDim = log2Up(aggDelta); // as 的维度
reset();
}
聚合因子的作用:
- 将多个小步骤聚合成一个大步骤
- 减少同步开销
- 提高缓冲区利用率
三个阶段(Phase)
PAT AllGather 分为三个主要阶段:
Phase 0: 初始接收阶段
cpp
else if (phase == 0) {
int s = a*aggDelta + as; // 计算距离
if (s >= nranks) skip = 1;
int recvDataRank = (rank + s) % nranks; // 接收数据的源 rank
ps->outIx = recvDataRank * count + offset; // 输出位置
// 只接收,不发送
ps->sendDim = -1;
ps->recvDim = 0; // 从维度 0(距离 1)接收
ps->inpIx = 0;
ps->sendOffset = -1;
ps->recvOffset = (a % postFreq) * nelem; // 接收缓冲区偏移
ps->stepOffset = 0;
// 决定是否发布接收完成
ps->postRecv = (a % postFreq == postFreq-1) || ((a+1)*aggDelta+as >= nranks) ? 1 : 0;
ps->postSend = 0;
}
Phase 0 的目的:
- 从最近的邻居(距离 1)接收初始数据
- 为后续的二叉树传输做准备
Phase 1: 主要二叉树传输阶段
cpp
else if (phase == 1) {
int s = a*aggDelta + as; // 计算距离
if (s >= nranks) skip = 1;
// 计算发送维度(log2(距离))
ps->sendDim = firstBitSet(s, nrPow2); // 例如:s=4 → sendDim=2
s -= (1<<ps->sendDim); // 减去发送距离
int sendDataRank = (rank + nranks + s) % nranks;
ps->outIx = sendDataRank * count + offset;
// 计算接收维度
ps->recvDim = s ? firstBitSet(s, nrPow2) : -1;
// 设置偏移和发布标志
ps->sendOffset = ps->recvOffset = (a % postFreq) * nelem;
ps->postSend = (a % postFreq == postFreq-1) || ((a+1)*aggDelta+as >= nranks) ? 1 : 0;
ps->postRecv = (ps->sendDim == 0) && ((a % postFreq == postFreq-1) || ((a+1)*aggDelta+as-1 >= nranks)) ? 1 : 0;
ps->stepOffset = (ps->sendDim == 0) ? 0 : a/postFreq;
// 特殊情况:接收维度为 -1
if (ps->recvDim == -1) {
ps->recvOffset = -1;
ps->postRecv = 0;
} else if (as - (1<<ps->sendDim) == 0) {
// 调整接收偏移
int foffset = (a*aggDelta) >> (ps->recvDim+1);
ps->recvOffset = (foffset%postFreq)*nelem;
ps->postRecv = (ps->sendDim == 0) && ((foffset % postFreq == postFreq-1) || ((((foffset+1)*2)+1)<<ps->recvDim) >= nranks) ? 1 : 0;
ps->stepOffset = (ps->sendDim == 0) ? 0 : foffset/postFreq;
}
// 边界情况:确保至少接收一次
if (s < nranks && ps->sendDim == 0 && skip) {
ps->sendDim = -1;
ps->sendOffset = -1;
ps->postSend = 0;
skip = 0;
}
}
Phase 1 的核心逻辑:
- 同时进行接收和发送
- 使用二叉树结构(通过
firstBitSet计算维度) - 支持聚合多个步骤
Phase 2: 缩放传输阶段
cpp
else if (phase == 2) {
int s = (2*a+1)*scale*aggDelta; // 计算距离(奇数倍)
ps->postSend = (a % postFreq == postFreq-1) || ((2*(a+1)+1)*scale*aggDelta >= nranks) ? 1 : 0;
ps->postRecv = 0;
if (s >= nranks) skip = 1;
ps->sendDim = firstBitSet(s, nrPow2);
s -= (1<<ps->sendDim);
ps->sendOffset = (a%postFreq) * nelem;
ps->stepOffset = a / postFreq;
int sendDataRank = (rank + nranks + s) % nranks;
ps->outIx = sendDataRank * count + offset;
ps->recvDim = s ? firstBitSet(s, nrPow2) : -1;
if (ps->recvDim == -1) {
ps->recvOffset = -1;
} else {
s -= (1<<ps->recvDim);
int foffset = (a*2*scale*aggDelta) >> (ps->recvDim+1);
ps->recvOffset = (foffset%postFreq)*nelem;
ps->stepOffset = foffset / postFreq;
}
}
Phase 2 的目的:
- 处理更大距离的传输(奇数倍)
- 使用缩放因子
scale逐步减小 - 填补二叉树的空隙
阶段转换逻辑
阶段转换条件 (getNextOp())
cpp
a++; // 递增子步骤
if (a >= lastA && a >= parallelFactor) {
int p = phase;
if (p == 2) scale /= 2; // Phase 2 时缩小 scale
// 阶段转换状态机
phase =
p == 2 ? scale ? 2 : 1 : // Phase 2: scale>0 继续 Phase 2,否则转 Phase 1
p == 1 ? as % 2 == 1 ? 0 : 1 : // Phase 1: as 奇数转 Phase 0,偶数继续 Phase 1
1; // Phase 0: 转 Phase 1
// 更新 as(聚合步骤索引)
if (p == 0 || (p == 1 && as % 2 == 0))
as = nextAs(); // 计算下一个 as 值
// 检查是否完成所有数据
if (p == 0 && as == aggDelta/2) {
offset += chunkCount; // 移动到下一个数据块
if (offset >= end) {
ps->last = 2; // 全局完成
} else {
reset(); // 重置,处理下一个数据块
}
} else {
resetA(); // 重置 a,继续当前数据块
}
}
提前结束条件
cpp
else if (phase == 0 && as == 1 && offset + chunkCount >= end &&
a-1 >= ((lastA-1) / parallelFactor) * parallelFactor) {
ps->last = 1; // 本组完成(但其他组可能还在运行)
}
关键辅助函数
1. nextAs() - 计算下一个聚合步骤索引
cpp
__device__ __host__ int nextAs() {
for (int d=0; d<asDim; d++) {
int p = 1<<d; // 2^d
bitCount[d]--;
if (bitCount[d] == 0) {
v ^= p; // 翻转第 d 位
bitCount[d] = p; // 重置计数
if ((v&p) == 0) { // 如果第 d 位变为 0
bitCount[d] += firstBitSet(bitZeroStep[d], asDim) - 1;
if (bitCount[d] == 0) {
v ^= p;
bitCount[d] = p;
}
bitZeroStep[d]++;
}
}
}
return v;
}
nextAs() 的作用:
- 生成一个特殊的序列,用于遍历所有可能的
as值 - 使用位操作和计数器实现高效的序列生成
- 确保覆盖所有必要的通信模式
2. mirror() - 位镜像
cpp
__device__ __host__ int mirror(int i, int max) {
int ret = 0;
for (int mask=1, imask=max/2; mask<max; mask<<=1, imask>>=1) {
if ((i&mask)) ret += imask; // 如果第 k 位为 1,则在结果的 (log2(max)-k-1) 位置 1
}
return ret;
}
示例:
max = 8 (二进制: 1000)
i = 3 (二进制: 011)
mirror(3, 8) = 6 (二进制: 110)
解释:
i: 0 1 1
↓ ↓ ↓
ret: 1 1 0 = 6
3. firstBitSet() - 找到最低位的 1
cpp
__device__ __host__ int firstBitSet(int i, int max) {
int ffs = __ffs(i); // CUDA 内置函数,返回最低位 1 的位置(1-based)
return ffs ? ffs-1 : max; // 转换为 0-based
}
示例:
i = 12 (二进制: 1100)
firstBitSet(12, 32) = 2 // 第 2 位(从 0 开始)是最低位的 1
算法执行流程示例
示例:8 个节点,aggFactor=1(无聚合)
初始化
nranks = 8
nrPow2 = 8
aggFactor = 1
aggDelta = 8
parallelFactor = 1
Rank 0 的执行序列
Phase 0 (初始接收):
Step 0: s=0, recvDataRank=0, recvDim=0
→ 从 Rank 1 接收数据(距离 1)
Phase 1 (二叉树传输):
Step 1: s=1, sendDim=0, recvDim=-1
→ 发送到 Rank 7(距离 1),无接收
Step 2: s=2, sendDim=1, recvDim=0
→ 发送到 Rank 6(距离 2),从 Rank 2 接收(距离 2)
Step 3: s=3, sendDim=0, recvDim=1
→ 发送到 Rank 5(距离 1),从 Rank 3 接收(距离 3)
Step 4: s=4, sendDim=2, recvDim=0
→ 发送到 Rank 4(距离 4),从 Rank 4 接收(距离 4)
Step 5: s=5, sendDim=0, recvDim=2
→ 发送到 Rank 3(距离 1),从 Rank 5 接收(距离 5)
Step 6: s=6, sendDim=1, recvDim=0
→ 发送到 Rank 2(距离 2),从 Rank 6 接收(距离 6)
Step 7: s=7, sendDim=0, recvDim=1
→ 发送到 Rank 1(距离 1),从 Rank 7 接收(距离 7)
距离计算公式
Phase 0
cpp
s = a*aggDelta + as
recvPeer = (rank + s) % nranks
Phase 1
cpp
s = a*aggDelta + as
sendDim = firstBitSet(s, nrPow2) // log2(最低位的 1)
sendPeer = (rank - (1<<sendDim)) % nranks
recvPeer = (rank + (s - (1<<sendDim))) % nranks
Phase 2
cpp
s = (2*a+1)*scale*aggDelta
sendDim = firstBitSet(s, nrPow2)
sendPeer = (rank - (1<<sendDim)) % nranks
recvPeer = (rank + (s - (1<<sendDim) - (1<<recvDim))) % nranks
聚合机制
聚合的目的
将多个小步骤聚合成一个大步骤,减少同步开销:
不聚合(aggFactor=1):
Step 0: 发送 1 个块
Step 1: 发送 1 个块
Step 2: 发送 1 个块
...
每步都需要同步
聚合(aggFactor=4):
Step 0-3: 发送 4 个块(一次性)
Step 4-7: 发送 4 个块
...
每 4 步同步一次
发布频率(postFreq)
cpp
ps->postSend = (a % postFreq == postFreq-1) || ... ? 1 : 0;
ps->postRecv = (a % postFreq == postFreq-1) || ... ? 1 : 0;
postFreq控制多久发布一次步骤完成- 等于
aggFactor - 减少对等节点间的同步次数
缓冲区偏移计算
接收偏移
cpp
ps->recvOffset = (a % postFreq) * nelem;
发送偏移
cpp
ps->sendOffset = (a % postFreq) * nelem;
步骤偏移
cpp
ps->stepOffset = (ps->sendDim == 0) ? 0 : a/postFreq;
作用:
- 在循环缓冲区中定位数据
- 支持流水线操作
- 避免缓冲区冲突
数据流示例(4 个节点)
初始状态
Rank 0: [A, _, _, _]
Rank 1: [_, B, _, _]
Rank 2: [_, _, C, _]
Rank 3: [_, _, _, D]
Phase 0(初始接收)
Step 0 (s=0, recvDim=0):
Rank 0 ← Rank 1: [A, B, _, _]
Rank 1 ← Rank 0: [A, B, _, _]
Rank 2 ← Rank 3: [_, _, C, D]
Rank 3 ← Rank 2: [_, _, C, D]
Phase 1(二叉树传输)
Step 1 (s=1, sendDim=0, recvDim=-1):
Rank 0 → Rank 3: 发送 [A, B]
Rank 1 → Rank 2: 发送 [A, B]
Rank 2 → Rank 1: 发送 [C, D]
Rank 3 → Rank 0: 发送 [C, D]
Step 2 (s=2, sendDim=1, recvDim=0):
Rank 0 → Rank 2: 发送 [A, B, C, D],← Rank 2: 接收 [C, D]
Rank 1 → Rank 3: 发送 [A, B, C, D],← Rank 3: 接收 [C, D]
Rank 2 → Rank 0: 发送 [A, B, C, D],← Rank 0: 接收 [A, B]
Rank 3 → Rank 1: 发送 [A, B, C, D],← Rank 1: 接收 [A, B]
结果:所有节点都有 [A, B, C, D]
优化技术
1. 流水线并行
- 多个工作组并行处理不同步骤
parallelFactor控制并行度- 每组从
step = group开始,step += nGroups递增
2. 聚合优化
- 减少同步次数
- 提高缓冲区利用率
- 通过
aggFactor和postFreq控制
3. 跳过优化
cpp
if (a >= lastA) skip = 1; // 超出范围,跳过
if (s >= nranks) skip = 1; // 距离超出节点数,跳过
4. 边界处理
cpp
if (s < nranks && ps->sendDim == 0 && skip) {
// 确保至少接收一次,即使不需要发送
ps->sendDim = -1;
skip = 0;
}
算法复杂度分析
时间复杂度
- 步骤数: O(log N)(二叉树深度)
- 每步操作: O(1)(并行执行)
- 总时间: O(log N)
空间复杂度
- 共享内存: O(log N)(存储对等节点信息)
- 步骤缓冲 : O(1)(循环使用
NCCL_SHMEM_PAT_STEPS)
总结
getNextOp() 的核心功能
- 计算距离 (
s): 根据a,as,aggDelta,scale计算通信距离 - 确定维度 : 使用
firstBitSet()将距离转换为二叉树维度 - 计算对等节点: 根据维度和 rank 计算发送/接收的对等节点
- 设置偏移: 计算缓冲区偏移,支持流水线和聚合
- 控制发布: 决定何时通知对等节点步骤完成
- 阶段转换: 管理 Phase 0 → 1 → 2 的转换
- 边界处理 : 处理
nranks不是 2 的幂的情况
这个算法通过精心设计的状态机和位操作,实现了高效的多节点 AllGather 操作,特别适合每节点 1 GPU 的场景。