PyTorch DDP Workflow, SyncBN, and ShuffleBN

Overall Framework

python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train(local_rank, world_size):
    # initialize the process group (rank == local_rank holds only for single-node jobs)
    dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)

    # data loader: DistributedSampler gives each process its own shard of the data
    dataset = MyDataset()  # placeholder dataset
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler)

    # model wrapped with DDP
    model = MyModel().to(local_rank)  # placeholder model
    model = DDP(model, device_ids=[local_rank])

    # optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # training loop (epochs / compute_loss are placeholders)
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for batch in dataloader:
            outputs = model(batch)  # move batch to the current device as needed
            loss = compute_loss(outputs)
            loss.backward()  # DDP all-reduces gradients during backward
            optimizer.step()
            optimizer.zero_grad()

        # save the model (main process only)
        if local_rank == 0:
            torch.save(model.module.state_dict(), "model.pth")

    # clean up the process group
    dist.destroy_process_group()
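
To run this on a single machine, a minimal entry point can drive train() with torch.multiprocessing.spawn, one process per GPU. This is just a sketch; the MASTER_ADDR/MASTER_PORT values are placeholders for the rendezvous address read by init_process_group:

python
import os
import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    # rendezvous info read by init_process_group (placeholder values)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    world_size = torch.cuda.device_count()
    # mp.spawn passes the process index as the first argument of train()
    mp.spawn(train, args=(world_size,), nprocs=world_size)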

Launch Command

bash
python3 -m torch.distributed.run --nproc_per_node 2 --nnodes 1 --node_rank $RANK --master_addr $ADDR --master_port $PORT train.py

RANK, ADDR, and PORT have to be configured on each node you launch from (train.py above stands in for your actual training script); each GPU corresponds to exactly one process.
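
When launched this way, each spawned process can read its identity from the environment variables that torch.distributed.run sets; a minimal sketch:

python
import os
import torch
import torch.distributed as dist

# torch.distributed.run / torchrun exports these for every process it spawns
local_rank = int(os.environ["LOCAL_RANK"])   # GPU index on this node
rank = int(os.environ["RANK"])               # global process index
world_size = int(os.environ["WORLD_SIZE"])   # total number of processes

# with the default env:// init method, rank and world_size are picked up
# automatically from the environment
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")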

The Difference Between SyncBN and ShuffleBN

  • SyncBN effectively pools the data from all GPUs to compute a single larger, more complete mean and variance.
  • ShuffleBN shuffles the samples across GPUs before they enter the encoder and unshuffles them afterwards, preventing the encoder's BN layers from taking a shortcut based on per-GPU batch statistics.

How to Handle BN: SyncBN

When the model contains BN layers, you should use SyncBN; the code looks roughly like this:

python
model = MyModel()  # the original model, containing regular BN layers
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)  # swap in SyncBN (do this before wrapping with DDP)

The communication primitive behind SyncBN is all_gather: each GPU's locally computed mean and variance are gathered and merged into global statistics. In PyTorch DDP this keeps BN behavior consistent across GPUs, and because only small per-channel statistics are exchanged rather than full activations, the communication cost stays negligible.
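
As a rough illustration of the idea (this is not PyTorch's actual SyncBatchNorm kernel, just a sketch assuming equal batch sizes per rank):

python
import torch
import torch.distributed as dist

def sync_batch_stats(x):
    """Combine per-GPU batch statistics via all_gather.
    x: local activations of shape (N, C); assumes equal N on every rank."""
    world_size = dist.get_world_size()

    # per-channel statistics computed from the local batch only
    local_mean = x.mean(dim=0)          # E[x],   shape (C,)
    local_sqmean = (x * x).mean(dim=0)  # E[x^2], shape (C,)

    # exchange only these small per-channel vectors, not the activations
    stats = torch.stack([local_mean, local_sqmean])  # (2, C)
    gathered = [torch.empty_like(stats) for _ in range(world_size)]
    dist.all_gather(gathered, stats)

    # merge into global statistics: var = E[x^2] - E[x]^2 (biased estimate)
    all_stats = torch.stack(gathered)   # (world_size, 2, C)
    global_mean = all_stats[:, 0].mean(dim=0)
    global_var = all_stats[:, 1].mean(dim=0) - global_mean ** 2
    return global_mean, global_var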

ShuffleBN in MoCo v1

MoCo v1 uses exactly this Shuffling BN trick. BN statistics are normally computed on the current GPU only, and the per-batch mean and variance leak enough intra-batch information that the model can easily find a shortcut to the "correct" answer. Shuffling BN shuffles the sample order across GPUs before the forward pass, computes on the shuffled batches, and then restores the original order afterwards.

Looking at the MoCo v1 implementation, the key batch simply goes through _batch_shuffle_ddp before the encoder and _batch_unshuffle_ddp after it:

python
import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
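
For context, here is a minimal training-step sketch using this module, assuming the DDP setup from the framework above (local_rank and loader come from that context, and resnet50 is just one choice of base encoder):

python
import torch
import torch.nn as nn
import torchvision.models as models

# assumes dist.init_process_group(...) has already run, as in the framework above
model = MoCo(models.resnet50).cuda(local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
criterion = nn.CrossEntropyLoss().cuda(local_rank)
optimizer = torch.optim.SGD(model.parameters(), lr=0.03,
                            momentum=0.9, weight_decay=1e-4)

for im_q, im_k in loader:  # two augmented views of the same images
    logits, labels = model(im_q.cuda(local_rank), im_k.cuda(local_rank))
    loss = criterion(logits, labels)  # InfoNCE as cross-entropy over (1+K) logits
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()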