PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回

PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last

很多人在接触深度学习的过程往往都是从自己的笔记本开始的,但是从接触工作后,更多的是通过分布式的训练来模型。由于经验的不足常常会遇到分布式训练"玄学卡死":多卡的训练偶发在 epoch 尾部停住不动,并且GPU 利用率掉到 0%,日志无异常。为了解决首次接触分布式训练的人的疑问,本文从bug现象以及调试逐一分析。

❓ Bug 现象

在我们进行多卡训练的时候,偶尔会出现随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。偶尔能看到 NCCL 打印(并不总出现):

cmd 复制代码
NCCL WARN Reduce failed: ... Async operation timed out

接着通过kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce*上(DistributedDataParallel 内部)。

但是这个现象在关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。

📽️ 场景重现

当我们的问题在单卡不会出现,但是多卡会出现问题的时候,问题点集中在数据的问题上。主要原因以下:

1️⃣ shuffle=TrueDistributedSampler 混用(会被忽略但容易误导)。

2️⃣ drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致 (当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。

3️⃣ 每个 epoch 忘记调用 sampler.set_epoch(epoch) ,导致各 rank 的随机顺序不一致

以下是笔者在多卡训练遇到的问题代码

python 复制代码
import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class DummyDS(Dataset):
    def __init__(self, N=1003):  # 刻意设成非 world_size 整数倍
        self.N = N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))   # 模拟有时会丢弃某些样本的增强(省略)
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    ds = DummyDS()

    sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ❌ drop_last=False
    # ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
    loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    model = DDP(model, device_ids=[device.index])
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(5):
        # ❌ 忘记 sampler.set_epoch(epoch)
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            opt.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()      # 🔥 偶发卡在这里(allreduce)
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} done")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

🔴触发条件(满足一两个就可能复现):

1️⃣ len(dataset) 不是 world_size 的整数倍。

2️⃣ 动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。

3️⃣ 忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

4️⃣ drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。

5️⃣ 某些自定义 collate_fn 在"空 batch"时直接 continue

✔️ Debug

1️⃣ 先确认"各 rank 步数一致"

在训练 loop 里加统计(不要只在 rank0 打印):

python 复制代码
from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
    steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每个 rank 各自 print,检查是否相等

我的现象 :有的 epoch,rank0 比 rank1 多 1--2 个 step

2️⃣开启 NCCL 调试

在启动前设置:

python 复制代码
export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。

3️⃣检查 Sampler 与 DataLoader 参数
  • DistributedSampler 必须 搭配 sampler.set_epoch(epoch)
  • DataLoader 里不要再写 shuffle=True
  • 若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。

🟢 解决方案(修复版)

  • 严格对齐 Sampler 语义 + 丢最后不齐整的 batch
python 复制代码
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset

class DummyDS(Dataset):
    def __init__(self, N=1003): self.N=N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

    ds = DummyDS()
    # 关键 1:使用 DistributedSampler,统一交给它洗牌
    sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ✅
    # 关键 2:DataLoader 里不要再写 shuffle
    loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如无动态分支,关掉更稳更快
    opt = torch.optim.SGD(ddp.parameters(), lr=0.1)

    for epoch in range(5):
        sampler.set_epoch(epoch)  # ✅ 关键 3:每个 epoch 设置不同随机种子
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            loss = torch.nn.functional.cross_entropy(ddp(x), y)
            loss.backward()
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} ok")

    dist.barrier()  # ✅ 收尾同步,避免 rank 提前退出
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
  • 必须保留最后一批

如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小

  1. Padding/Repeat :在 collate_fn 里把最后一批补齐到一致大小
  2. EvenlyDistributedSampler :自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。

示例(最简单的"循环补齐"):

python 复制代码
class EvenSampler(DistributedSampler):
    def __iter__(self):
        # 先拿到原始 index,再做均匀补齐
        indices = list(super().__iter__())
        # 使得 len(indices) 可整除 num_replicas
        rem = len(indices) % self.num_replicas
        if rem != 0:
            pad = self.num_replicas - rem
            indices += indices[:pad]     # 简单重复前几个样本
        return iter(indices)

总结

以上是这次 DDP 卡死问题从现象 → 排查 → 解决 的完整记录。这个坑非常高频 ,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch引发各 rank 步数不一致,导致 allreduce 永久等待。

相关推荐
用户3521802454755 小时前
GraphRAG:让 RAG 更聪明的一种新玩法
人工智能
2501_924534516 小时前
济南矩阵跃动完成千万融资!国产GEO工具能否挑战国际巨头?
大数据·人工智能
霍格沃兹_测试6 小时前
Browser Use 浏览器自动化 Agent:让浏览器自动为你工作
人工智能·测试
爱看科技6 小时前
苹果Vision Air蓝图或定档2027,三星/微美全息加速XR+AI核心生态布局卡位
人工智能·xr
AI浩6 小时前
【面试题】搜索准确性不高你怎么排查?
人工智能
小陈phd6 小时前
高级RAG策略学习(一)——自适应检索系统
人工智能·windows·语言模型
网安INF6 小时前
【论文阅读】-《Besting the Black-Box: Barrier Zones for Adversarial Example Defense》
人工智能·深度学习·网络安全·黑盒攻击
pingao1413786 小时前
景区负氧离子气象站:引领绿色旅游,畅吸清新每一刻
大数据·人工智能·旅游
AKAMAI7 小时前
部署在用户身边,将直播延迟压缩至毫秒级
人工智能·云计算·直播