PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回

PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last

很多人在接触深度学习的过程往往都是从自己的笔记本开始的,但是从接触工作后,更多的是通过分布式的训练来模型。由于经验的不足常常会遇到分布式训练"玄学卡死":多卡的训练偶发在 epoch 尾部停住不动,并且GPU 利用率掉到 0%,日志无异常。为了解决首次接触分布式训练的人的疑问,本文从bug现象以及调试逐一分析。

❓ Bug 现象

在我们进行多卡训练的时候,偶尔会出现随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。偶尔能看到 NCCL 打印(并不总出现):

cmd 复制代码
NCCL WARN Reduce failed: ... Async operation timed out

接着通过kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce*上(DistributedDataParallel 内部)。

但是这个现象在关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。

📽️ 场景重现

当我们的问题在单卡不会出现,但是多卡会出现问题的时候,问题点集中在数据的问题上。主要原因以下:

1️⃣ shuffle=TrueDistributedSampler 混用(会被忽略但容易误导)。

2️⃣ drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致 (当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。

3️⃣ 每个 epoch 忘记调用 sampler.set_epoch(epoch) ,导致各 rank 的随机顺序不一致

以下是笔者在多卡训练遇到的问题代码

python 复制代码
import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class DummyDS(Dataset):
    def __init__(self, N=1003):  # 刻意设成非 world_size 整数倍
        self.N = N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))   # 模拟有时会丢弃某些样本的增强(省略)
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    ds = DummyDS()

    sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ❌ drop_last=False
    # ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
    loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    model = DDP(model, device_ids=[device.index])
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(5):
        # ❌ 忘记 sampler.set_epoch(epoch)
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            opt.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()      # 🔥 偶发卡在这里(allreduce)
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} done")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

🔴触发条件(满足一两个就可能复现):

1️⃣ len(dataset) 不是 world_size 的整数倍。

2️⃣ 动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。

3️⃣ 忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

4️⃣ drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。

5️⃣ 某些自定义 collate_fn 在"空 batch"时直接 continue

✔️ Debug

1️⃣ 先确认"各 rank 步数一致"

在训练 loop 里加统计(不要只在 rank0 打印):

python 复制代码
from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
    steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每个 rank 各自 print,检查是否相等

我的现象 :有的 epoch,rank0 比 rank1 多 1--2 个 step

2️⃣开启 NCCL 调试

在启动前设置:

python 复制代码
export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。

3️⃣检查 Sampler 与 DataLoader 参数
  • DistributedSampler 必须 搭配 sampler.set_epoch(epoch)
  • DataLoader 里不要再写 shuffle=True
  • 若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。

🟢 解决方案(修复版)

  • 严格对齐 Sampler 语义 + 丢最后不齐整的 batch
python 复制代码
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset

class DummyDS(Dataset):
    def __init__(self, N=1003): self.N=N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

    ds = DummyDS()
    # 关键 1:使用 DistributedSampler,统一交给它洗牌
    sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ✅
    # 关键 2:DataLoader 里不要再写 shuffle
    loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如无动态分支,关掉更稳更快
    opt = torch.optim.SGD(ddp.parameters(), lr=0.1)

    for epoch in range(5):
        sampler.set_epoch(epoch)  # ✅ 关键 3:每个 epoch 设置不同随机种子
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            loss = torch.nn.functional.cross_entropy(ddp(x), y)
            loss.backward()
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} ok")

    dist.barrier()  # ✅ 收尾同步,避免 rank 提前退出
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
  • 必须保留最后一批

如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小

  1. Padding/Repeat :在 collate_fn 里把最后一批补齐到一致大小
  2. EvenlyDistributedSampler :自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。

示例(最简单的"循环补齐"):

python 复制代码
class EvenSampler(DistributedSampler):
    def __iter__(self):
        # 先拿到原始 index,再做均匀补齐
        indices = list(super().__iter__())
        # 使得 len(indices) 可整除 num_replicas
        rem = len(indices) % self.num_replicas
        if rem != 0:
            pad = self.num_replicas - rem
            indices += indices[:pad]     # 简单重复前几个样本
        return iter(indices)

总结

以上是这次 DDP 卡死问题从现象 → 排查 → 解决 的完整记录。这个坑非常高频 ,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch引发各 rank 步数不一致,导致 allreduce 永久等待。

相关推荐
熊文豪1 分钟前
蓝耘MaaS驱动PandaWiki:零基础搭建AI智能知识库完整指南
人工智能·pandawiki·蓝耘maas
没有梦想的咸鱼185-1037-166319 分钟前
【遥感技术】从CNN到Transformer:基于PyTorch的遥感影像、无人机影像的地物分类、目标检测、语义分割和点云分类
pytorch·python·深度学习·机器学习·数据分析·cnn·transformer
whaosoft-14324 分钟前
51c视觉~合集2~目标跟踪
人工智能
cyyt39 分钟前
深度学习周报(9.15~9.21)
人工智能·深度学习·量子计算
Deepoch1 小时前
Deepoc具身智能模型:为传统机器人注入“灵魂”,重塑建筑施工现场安全新范式
人工智能·科技·机器人·人机交互·具身智能
吃饭睡觉发paper2 小时前
High precision single-photon object detection via deep neural networks,OE2024
人工智能·目标检测·计算机视觉
醉方休2 小时前
TensorFlow.js高级功能
javascript·人工智能·tensorflow
云宏信息2 小时前
赛迪顾问《2025中国虚拟化市场研究报告》解读丨虚拟化市场迈向“多元算力架构”,国产化与AI驱动成关键变量
网络·人工智能·ai·容器·性能优化·架构·云计算
红苕稀饭6662 小时前
VideoChat-Flash论文阅读
人工智能·深度学习·机器学习
周杰伦_Jay2 小时前
【图文详解】强化学习核心框架、数学基础、分类、应用场景
人工智能·科技·算法·机器学习·计算机视觉·分类·数据挖掘