整体框架
python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
def train(local_rank, world_size, rank=None):
    """Run the training loop of one DDP worker process (one process per GPU).

    Args:
        local_rank: index of the GPU on this machine used by this process.
        world_size: total number of worker processes across all machines.
        rank: global rank of this process in the process group. Defaults to
            the ``RANK`` environment variable (set by
            ``torch.distributed.run``) and falls back to ``local_rank``,
            which is only correct for a single-machine job.

    ``MyDataset``, ``MyModel``, ``compute_loss`` and ``epochs`` are
    placeholders that must be defined at module level by the caller.
    """
    if rank is None:
        rank = int(os.environ.get("RANK", local_rank))

    # Pin this process to its GPU before creating the NCCL process group so
    # NCCL and any later `.cuda()` calls use this rank's device.
    torch.cuda.set_device(local_rank)
    # The process-group rank is the *global* rank; local_rank only selects
    # the GPU. Passing local_rank here breaks as soon as nnodes > 1.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Data: DistributedSampler gives each rank its own shard of the data.
        dataset = MyDataset()
        sampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset, sampler=sampler)

        # Model, wrapped in DDP so gradients are averaged across ranks.
        model = MyModel().to(local_rank)
        model = DDP(model, device_ids=[local_rank])

        optimizer = torch.optim.Adam(model.parameters())

        for epoch in range(epochs):
            # Reseed the sampler so every epoch uses a different shuffle.
            sampler.set_epoch(epoch)
            for batch in dataloader:
                outputs = model(batch)
                loss = compute_loss(outputs)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        # Save once per job (global rank 0) rather than once per machine.
        # model.module is the unwrapped model, so the checkpoint keys carry
        # no "module." prefix.
        if rank == 0:
            torch.save(model.module.state_dict(), "model.pth")
    finally:
        # Always tear down the process group, even if training raised.
        dist.destroy_process_group()
分布式命令
bash
# Run this once on every machine: it starts --nproc_per_node worker processes
# (one per GPU). --node_rank differs per machine (0 .. nnodes-1), while
# --master_addr / --master_port are identical on all machines (node 0's address).
# train.py is a placeholder for your own training script; torchrun requires it.
python3 -m torch.distributed.run --nproc_per_node 2 --nnodes 1 --node_rank $RANK --master_addr $ADDR --master_port $PORT train.py
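其中 --node_rank 需要在每台机器（节点）上各自设置（从 0 开始编号，单机时为 0）；--master_addr 和 --master_port 在所有节点上必须相同，指向 0 号节点的 IP 和一个空闲端口。命令最后要跟上要运行的训练脚本（如 train.py）。torch.distributed.run 会在每个节点上启动 --nproc_per_node 个进程，每张卡对应一个进程，并通过环境变量 LOCAL_RANK、RANK、WORLD_SIZE 告诉每个进程它自己的编号。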
SyncBN和ShuffleBN的区别
- SyncBN 在各卡之间同步 BN 的统计量：把各卡上算出的均值、方差（以及样本数）汇总起来，得到相当于在所有卡的数据拼成的大 batch 上计算的均值和方差，比每张卡只用自己那一小部分数据算出的统计量更准确。
- ShuffleBN 是在样本过 Encoder 之前先在多卡之间打乱顺序（shuffle），过完 Encoder 之后再恢复原来的顺序（unshuffle）；MoCo 中只对 key encoder 的输入这样做。这样可以防止模型借助 Encoder 中 BN 的 batch 内统计量“偷懒”（走捷径）。
BN如何处理:SyncBN
涉及到BN时,应该使用SyncBN,代码大致如下:
python
# Build the plain model (it still contains ordinary BatchNorm layers), swap
# every BatchNorm layer for SyncBatchNorm, and move the result to `device`.
# Per the PyTorch docs, do this conversion before wrapping the model in DDP.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(MyModel()).to(device)
SyncBN 在前向传播时使用的通信原语是 all_gather。每张卡先算出本地的均值、方差相关统计量和样本数，再通过 all_gather 收集所有卡的统计量，合并成全局的均值和方差。反向传播时则用 all_reduce 汇总梯度相关的统计量。各卡之间传输的只是每个通道的少量统计量，而不是特征图本身，通信量很小。因此它既保证了多卡训练时 BN 统计量的一致性，又不会成为明显的性能瓶颈。
MoCo V1中的ShuffleBN
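MoCo V1 中就使用了 Shuffling BN 的操作。BN 通常是在当前 GPU 上用当前 batch 的均值和方差做归一化的。如果 query 和它对应的正样本 key 在同一张卡上、由同一组样本的 batch 统计量来归一化，BN 就会在 batch 内的样本之间“泄露”信息。模型很容易借此走捷径，找到一个 loss 很低但表示质量很差的解（也就是“作弊”）。Shuffling BN 的做法是：在计算 key 之前，先把所有卡上的样本合在一起打乱顺序，再分发到各卡上过 key encoder，算完之后再按原顺序合回来。这样 query 和它的正样本 key 所用的 BN 统计量就来自不同的样本子集。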
来看一下 MoCo V1 中的实现。其实就是让 key 在过 Encoder 之前先调用 _batch_shuffle_ddp 打乱，过完 Encoder 之后再调用 _batch_unshuffle_ddp 恢复原来的顺序：
python
import torch
import torch.nn as nn
class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    forward(im_q, im_k) returns (logits, labels): column 0 of logits holds
    the positive (query, key) similarity and the remaining K columns hold the
    similarities to the negative keys stored in the queue, so labels is all
    zeros.

    The key path uses torch.distributed (broadcast, get_rank, and all_gather
    via the module-level helper concat_all_gather), so a default process
    group must be initialised before forward() is called.
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        base_encoder: callable that builds an encoder; called as
            base_encoder(num_classes=dim). When mlp=True the encoder must have
            an `fc` layer whose weight.shape[1] is used as the hidden width.
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: if True, prepend Linear(dim_mlp, dim_mlp) + ReLU to the `fc`
            head of both encoders (default: False)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            # The input width of the existing fc layer becomes the hidden width.
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        # The key encoder starts as an exact copy of the query encoder and is
        # excluded from backprop; afterwards it only changes through
        # _momentum_update_key_encoder.
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        # queue: (dim, K) buffer whose K columns are the stored negative keys,
        # each column normalised to unit length (normalize along dim=0).
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # queue_ptr: column index at which the next batch of keys is written.
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        param_k <- m * param_k + (1 - m) * param_q for every parameter pair.
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Write this step's keys into the queue, used as a ring buffer.

        keys: (N, C) tensor of this rank's keys. Keys from all ranks are
        gathered first, so every rank writes the same global batch. The
        columns overwritten are [ptr, ptr + global_batch), i.e. the oldest
        entries once the queue has wrapped around.
        """
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # The global batch must divide K so a single write never runs past the
        # last column (there is no wrap-around inside one write).
        # NOTE(review): assert is skipped under `python -O`.
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        # keys are rows (N x C); the queue stores them as columns.
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Gathers x from every rank and applies one random permutation to the
        global batch. The permutation is drawn on rank 0 and broadcast so that
        all ranks agree. Returns (this rank's slice of the shuffled batch,
        idx_unshuffle), where idx_unshuffle is the inverse permutation to
        pass to _batch_unshuffle_ddp.

        Assumes every rank holds the same number of samples: the global batch
        is split into equal slices with view(num_gpus, -1).
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        # NOTE(review): .cuda() uses the current CUDA device; assumes each rank
        # has already selected its GPU (e.g. torch.cuda.set_device).
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        # Every rank drew its own permutation above; only rank 0's is kept.
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        # argsort of a permutation is its inverse: idx_shuffle[idx_unshuffle[i]] == i
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***

        x is this rank's output for its shuffled slice. idx_unshuffle is the
        inverse permutation returned by _batch_shuffle_ddp. The method gathers
        all outputs and returns the rows for the samples this rank originally
        held, in their original order.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
            logits: N x (1 + K), already divided by the temperature T
            targets: the `labels` tensor, N zeros (the positive is column 0)
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            # Keys are encoded on a shuffled batch, so the BN statistics used
            # for a key generally come from a different subset of samples
            # than those used for its query.
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        # clone().detach(): the queue is modified in place by
        # _dequeue_and_enqueue below, so autograd saves a snapshot instead.
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        # NOTE(review): .cuda() uses the current CUDA device.
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        # Done after the logits are computed, so this batch's own keys are
        # not used as its negatives.
        self._dequeue_and_enqueue(k)

        return logits, labels
# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """Collect ``tensor`` from every rank and stack the copies along dim 0.

    Every rank must call this with a tensor of identical shape and dtype
    (``all_gather`` requires it). The output is ordered by rank: the rows
    contributed by rank ``r`` form the ``r``-th chunk.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per rank; all_gather overwrites each one in place.
    gathered = [torch.ones_like(tensor) for _rank in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)