PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回

PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last

很多人在接触深度学习的过程往往都是从自己的笔记本开始的,但是从接触工作后,更多的是通过分布式的训练来模型。由于经验的不足常常会遇到分布式训练"玄学卡死":多卡的训练偶发在 epoch 尾部停住不动,并且GPU 利用率掉到 0%,日志无异常。为了解决首次接触分布式训练的人的疑问,本文从bug现象以及调试逐一分析。

❓ Bug 现象

在我们进行多卡训练的时候,偶尔会出现随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。偶尔能看到 NCCL 打印(并不总出现):

cmd 复制代码
NCCL WARN Reduce failed: ... Async operation timed out

接着通过kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce*上(DistributedDataParallel 内部)。

但是这个现象在关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。

📽️ 场景重现

当我们的问题在单卡不会出现,但是多卡会出现问题的时候,问题点集中在数据的问题上。主要原因以下:

1️⃣ shuffle=TrueDistributedSampler 混用(会被忽略但容易误导)。

2️⃣ drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致 (当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。

3️⃣ 每个 epoch 忘记调用 sampler.set_epoch(epoch) ,导致各 rank 的随机顺序不一致

以下是笔者在多卡训练遇到的问题代码

python 复制代码
import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class DummyDS(Dataset):
    def __init__(self, N=1003):  # 刻意设成非 world_size 整数倍
        self.N = N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))   # 模拟有时会丢弃某些样本的增强(省略)
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    ds = DummyDS()

    sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ❌ drop_last=False
    # ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
    loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    model = DDP(model, device_ids=[device.index])
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(5):
        # ❌ 忘记 sampler.set_epoch(epoch)
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            opt.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()      # 🔥 偶发卡在这里(allreduce)
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} done")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

🔴触发条件(满足一两个就可能复现):

1️⃣ len(dataset) 不是 world_size 的整数倍。

2️⃣ 动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。

3️⃣ 忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

4️⃣ drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。

5️⃣ 某些自定义 collate_fn 在"空 batch"时直接 continue

✔️ Debug

1️⃣ 先确认"各 rank 步数一致"

在训练 loop 里加统计(不要只在 rank0 打印):

python 复制代码
from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
    steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每个 rank 各自 print,检查是否相等

我的现象 :有的 epoch,rank0 比 rank1 多 1--2 个 step

2️⃣开启 NCCL 调试

在启动前设置:

python 复制代码
export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。

3️⃣检查 Sampler 与 DataLoader 参数
  • DistributedSampler 必须 搭配 sampler.set_epoch(epoch)
  • DataLoader 里不要再写 shuffle=True
  • 若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。

🟢 解决方案(修复版)

  • 严格对齐 Sampler 语义 + 丢最后不齐整的 batch
python 复制代码
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset

class DummyDS(Dataset):
    def __init__(self, N=1003): self.N=N
    def __len__(self): return self.N
    def __getitem__(self, i):
        x = torch.randn(32, 3, 224, 224)
        y = torch.randint(0, 10, (32,))
        return x, y

def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def main():
    setup()
    rank = dist.get_rank()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

    ds = DummyDS()
    # 关键 1:使用 DistributedSampler,统一交给它洗牌
    sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ✅
    # 关键 2:DataLoader 里不要再写 shuffle
    loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)

    model = torch.nn.Linear(3*224*224, 10).to(device)
    ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如无动态分支,关掉更稳更快
    opt = torch.optim.SGD(ddp.parameters(), lr=0.1)

    for epoch in range(5):
        sampler.set_epoch(epoch)  # ✅ 关键 3:每个 epoch 设置不同随机种子
        for x, y in loader:
            x = x.view(x.size(0), -1).to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            opt.zero_grad(set_to_none=True)
            loss = torch.nn.functional.cross_entropy(ddp(x), y)
            loss.backward()
            opt.step()
        if rank == 0:
            print(f"epoch {epoch} ok")

    dist.barrier()  # ✅ 收尾同步,避免 rank 提前退出
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
  • 必须保留最后一批

如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小

  1. Padding/Repeat :在 collate_fn 里把最后一批补齐到一致大小
  2. EvenlyDistributedSampler :自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。

示例(最简单的"循环补齐"):

python 复制代码
class EvenSampler(DistributedSampler):
    def __iter__(self):
        # 先拿到原始 index,再做均匀补齐
        indices = list(super().__iter__())
        # 使得 len(indices) 可整除 num_replicas
        rem = len(indices) % self.num_replicas
        if rem != 0:
            pad = self.num_replicas - rem
            indices += indices[:pad]     # 简单重复前几个样本
        return iter(indices)

总结

以上是这次 DDP 卡死问题从现象 → 排查 → 解决 的完整记录。这个坑非常高频 ,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch引发各 rank 步数不一致,导致 allreduce 永久等待。

相关推荐
科技小花5 小时前
全球化深水区,数据治理成为企业出海 “核心竞争力”
大数据·数据库·人工智能·数据治理·数据中台·全球化
zhuiyisuifeng6 小时前
2026前瞻:GPTimage2镜像官网或将颠覆视觉创作
人工智能·gpt
徐健峰6 小时前
GPT-image-2 热门玩法实战(一):AI 看手相 — 一张手掌照片生成专业手相分析图
人工智能·gpt
weixin_370976356 小时前
AI的终极赛跑:进入AGI,还是泡沫破灭?
大数据·人工智能·agi
Slow菜鸟6 小时前
AI学习篇(五) | awesome-design-md 使用说明
人工智能·学习
冬奇Lab7 小时前
RAG 系列(五):Embedding 模型——语义理解的核心
人工智能·llm·aigc
深小乐7 小时前
AI 周刊【2026.04.27-05.03】:Anthropic 9000亿美元估值、英伟达死磕智能体、中央重磅定调AI
人工智能
码点滴7 小时前
什么时候用 DeepSeek V4,而不是 GPT-5/Claude/Gemini?
人工智能·gpt·架构·大模型·deepseek
狐狐生风7 小时前
LangChain 向量存储:Chroma、FAISS
人工智能·python·学习·langchain·faiss·agentai
波动几何7 小时前
CDA架构代码工坊技能cda-code-lab
人工智能