PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last
很多人在接触深度学习的过程往往都是从自己的笔记本开始的,但是从接触工作后,更多的是通过分布式的训练来模型。由于经验的不足常常会遇到分布式训练"玄学卡死":多卡的训练偶发在 epoch 尾部停住不动,并且GPU 利用率掉到 0%,日志无异常。为了解决首次接触分布式训练的人的疑问,本文从bug现象以及调试逐一分析。
❓ Bug 现象
在我们进行多卡训练的时候,偶尔会出现随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi
显示两卡功耗接近空闲。偶尔能看到 NCCL 打印(并不总出现):
cmd
NCCL WARN Reduce failed: ... Async operation timed out
接着通过kill -SIGQUIT
打印 Python 栈后发现停在 反向传播的梯度 allreduce*上(DistributedDataParallel
内部)。
但是这个现象在关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。
📽️ 场景重现
当我们的问题在单卡不会出现,但是多卡会出现问题的时候,问题点集中在数据的问题上。主要原因以下:
1️⃣ shuffle=True
与 DistributedSampler
混用(会被忽略但容易误导)。
2️⃣ drop_last=False
时,最后一个小批的样本数在不同 rank 上可能不一致 (当 len(dataset)
不是 world_size
的整数倍且某些数据被过滤/增强丢弃时尤其明显)。
3️⃣ 每个 epoch 忘记调用 sampler.set_epoch(epoch)
,导致各 rank 的随机顺序不一致。
以下是笔者在多卡训练遇到的问题代码
python
import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
class DummyDS(Dataset):
def __init__(self, N=1003): # 刻意设成非 world_size 整数倍
self.N = N
def __len__(self): return self.N
def __getitem__(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,)) # 模拟有时会丢弃某些样本的增强(省略)
return x, y
def setup():
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
def main():
setup()
rank = dist.get_rank()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
ds = DummyDS()
sampler = DistributedSampler(ds, shuffle=True, drop_last=False) # ❌ drop_last=False
# ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)
model = torch.nn.Linear(3*224*224, 10).to(device)
model = DDP(model, device_ids=[device.index])
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(5):
# ❌ 忘记 sampler.set_epoch(epoch)
for x, y in loader:
x = x.view(x.size(0), -1).to(device)
y = y.to(device)
opt.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward() # 🔥 偶发卡在这里(allreduce)
opt.step()
if rank == 0:
print(f"epoch {epoch} done")
dist.destroy_process_group()
if __name__ == "__main__":
main()
🔴触发条件(满足一两个就可能复现):
1️⃣ len(dataset)
不是 world_size
的整数倍。
2️⃣ 动态数据过滤/增强(例如有时返回 None
或丢样),导致各 rank 实际步数不同。
3️⃣ 忘记 sampler.set_epoch(epoch)
,各 rank 洗牌次序不同。
4️⃣ drop_last=False
,导致最后一个 batch 在各 rank 的样本数不同。
5️⃣ 某些自定义 collate_fn 在"空 batch"时直接 continue
。
✔️ Debug
1️⃣ 先确认"各 rank 步数一致"
在训练 loop 里加统计(不要只在 rank0 打印):
python
from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每个 rank 各自 print,检查是否相等
我的现象 :有的 epoch,rank0 比 rank1 多 1--2 个 step。
2️⃣开启 NCCL 调试
在启动前设置:
python
export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1
再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。
3️⃣检查 Sampler 与 DataLoader 参数
DistributedSampler
必须 搭配sampler.set_epoch(epoch)
。- DataLoader 里不要再写
shuffle=True
。 - 若数据不可整除,优先
drop_last=True
;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。
🟢 解决方案(修复版)
- 严格对齐 Sampler 语义 + 丢最后不齐整的 batch
python
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset
class DummyDS(Dataset):
def __init__(self, N=1003): self.N=N
def __len__(self): return self.N
def __getitem__(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,))
return x, y
def setup():
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
def main():
setup()
rank = dist.get_rank()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
ds = DummyDS()
# 关键 1:使用 DistributedSampler,统一交给它洗牌
sampler = DistributedSampler(ds, shuffle=True, drop_last=True) # ✅
# 关键 2:DataLoader 里不要再写 shuffle
loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)
model = torch.nn.Linear(3*224*224, 10).to(device)
ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False) # 如无动态分支,关掉更稳更快
opt = torch.optim.SGD(ddp.parameters(), lr=0.1)
for epoch in range(5):
sampler.set_epoch(epoch) # ✅ 关键 3:每个 epoch 设置不同随机种子
for x, y in loader:
x = x.view(x.size(0), -1).to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
opt.zero_grad(set_to_none=True)
loss = torch.nn.functional.cross_entropy(ddp(x), y)
loss.backward()
opt.step()
if rank == 0:
print(f"epoch {epoch} ok")
dist.barrier() # ✅ 收尾同步,避免 rank 提前退出
dist.destroy_process_group()
if __name__ == "__main__":
main()
- 必须保留最后一批
如果确实不能 drop_last=True
(例如小数据集),可考虑对齐 batch 大小:
- Padding/Repeat :在
collate_fn
里把最后一批补齐到一致大小; - EvenlyDistributedSampler :自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。
示例(最简单的"循环补齐"):
python
class EvenSampler(DistributedSampler):
def __iter__(self):
# 先拿到原始 index,再做均匀补齐
indices = list(super().__iter__())
# 使得 len(indices) 可整除 num_replicas
rem = len(indices) % self.num_replicas
if rem != 0:
pad = self.num_replicas - rem
indices += indices[:pad] # 简单重复前几个样本
return iter(indices)
总结
以上是这次 DDP 卡死问题从现象 → 排查 → 解决 的完整记录。这个坑非常高频 ,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗。最终定位是 DistributedSampler
使用不当 + drop_last=False
+ 忘记 set_epoch
引发各 rank 步数不一致,导致 allreduce 永久等待。