PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回

PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last

很多人在接触深度学习的过程往往都是从自己的笔记本开始的,但是从接触工作后,更多的是通过分布式的训练来模型。由于经验的不足常常会遇到分布式训练"玄学卡死":多卡的训练偶发在 epoch 尾部停住不动,并且GPU 利用率掉到 0%,日志无异常。为了解决首次接触分布式训练的人的疑问,本文从bug现象以及调试逐一分析。

❓ Bug 现象

在我们进行多卡训练的时候,偶尔会出现随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。偶尔能看到 NCCL 打印(并不总出现):

cmd 复制代码
NCCL WARN Reduce failed: ... Async operation timed out

接着通过kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce*上(DistributedDataParallel 内部)。

但是这个现象在关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。

📽️ 场景重现

当我们的问题在单卡不会出现,但是多卡会出现问题的时候,问题点集中在数据的问题上。主要原因以下:

1️⃣ shuffle=TrueDistributedSampler 混用(会被忽略但容易误导)。

2️⃣ drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致 (当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。

3️⃣ 每个 epoch 忘记调用 sampler.set_epoch(epoch) ,导致各 rank 的随机顺序不一致

以下是笔者在多卡训练遇到的问题代码

python 复制代码
import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class DummyDS(Dataset):
    def __init__(self, N=1003):  # 刻意设成非 world_size 整数倍
        self.N = N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))   # 模拟有时会丢弃某些样本的增强(省略)
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    ds = DummyDS()

    sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ❌ drop_last=False
    # ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
    loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    model = DDP(model, device_ids=[device.index])
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(5):
        # ❌ 忘记 sampler.set_epoch(epoch)
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            opt.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()      # 🔥 偶发卡在这里(allreduce)
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} done")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

🔴触发条件(满足一两个就可能复现):

1️⃣ len(dataset) 不是 world_size 的整数倍。

2️⃣ 动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。

3️⃣ 忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

4️⃣ drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。

5️⃣ 某些自定义 collate_fn 在"空 batch"时直接 continue

✔️ Debug

1️⃣ 先确认"各 rank 步数一致"

在训练 loop 里加统计(不要只在 rank0 打印):

python 复制代码
from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
    steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每个 rank 各自 print,检查是否相等

我的现象 :有的 epoch,rank0 比 rank1 多 1--2 个 step

2️⃣开启 NCCL 调试

在启动前设置:

python 复制代码
export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。

3️⃣检查 Sampler 与 DataLoader 参数
  • DistributedSampler 必须 搭配 sampler.set_epoch(epoch)
  • DataLoader 里不要再写 shuffle=True
  • 若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。

🟢 解决方案(修复版)

  • 严格对齐 Sampler 语义 + 丢最后不齐整的 batch
python 复制代码
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset

class DummyDS(Dataset):
    def __init__(self, N=1003): self.N=N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

    ds = DummyDS()
    # 关键 1:使用 DistributedSampler,统一交给它洗牌
    sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ✅
    # 关键 2:DataLoader 里不要再写 shuffle
    loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如无动态分支,关掉更稳更快
    opt = torch.optim.SGD(ddp.parameters(), lr=0.1)

    for epoch in range(5):
        sampler.set_epoch(epoch)  # ✅ 关键 3:每个 epoch 设置不同随机种子
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            loss = torch.nn.functional.cross_entropy(ddp(x), y)
            loss.backward()
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} ok")

    dist.barrier()  # ✅ 收尾同步,避免 rank 提前退出
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
  • 必须保留最后一批

如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小

  1. Padding/Repeat :在 collate_fn 里把最后一批补齐到一致大小
  2. EvenlyDistributedSampler :自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。

示例(最简单的"循环补齐"):

python 复制代码
class EvenSampler(DistributedSampler):
    def __iter__(self):
        # 先拿到原始 index,再做均匀补齐
        indices = list(super().__iter__())
        # 使得 len(indices) 可整除 num_replicas
        rem = len(indices) % self.num_replicas
        if rem != 0:
            pad = self.num_replicas - rem
            indices += indices[:pad]     # 简单重复前几个样本
        return iter(indices)

总结

以上是这次 DDP 卡死问题从现象 → 排查 → 解决 的完整记录。这个坑非常高频 ,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch引发各 rank 步数不一致,导致 allreduce 永久等待。

相关推荐
搞科研的小刘选手20 小时前
【厦门大学主办】第六届计算机科学与管理科技国际学术会议(ICCSMT 2025)
人工智能·科技·计算机网络·计算机·云计算·学术会议
fanstuck20 小时前
深入解析 PyPTO Operator:以 DeepSeek‑V3.2‑Exp 模型为例的实战指南
人工智能·语言模型·aigc·gpu算力
萤丰信息20 小时前
智慧园区能源革命:从“耗电黑洞”到零碳样本的蜕变
java·大数据·人工智能·科技·安全·能源·智慧园区
世洋Blog20 小时前
更好的利用ChatGPT进行项目的开发
人工智能·unity·chatgpt
serve the people1 天前
机器学习(ML)和人工智能(AI)技术在WAF安防中的应用
人工智能·机器学习
0***K8921 天前
前端机器学习
人工智能·机器学习
陈天伟教授1 天前
基于学习的人工智能(5)机器学习基本框架
人工智能·学习·机器学习
m0_650108241 天前
PaLM-E:具身智能的多模态语言模型新范式
论文阅读·人工智能·机器人·具身智能·多模态大语言模型·palm-e·大模型驱动
zandy10111 天前
2025年11月AI IDE权深度测榜:深度分析不同场景的落地选型攻略
ide·人工智能·ai编程·ai代码·腾讯云ai代码助手
欢喜躲在眉梢里1 天前
CANN 异构计算架构实操指南:从环境部署到 AI 任务加速全流程
运维·服务器·人工智能·ai·架构·计算