PyTorch分布式数据加载学习 DistributedSampler

[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler - 罗西的思考 - 博客园

初始化

python 复制代码
class DistributedSampler(Sampler[T_co]):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # 向上取整
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

如果不drop_last,那就

python 复制代码
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)

向上取整。

迭代

在迭代的时候,如果不能整除,那就把indices的前几个样本复制一遍:

python 复制代码
    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:  # 一般进入这里,不会丢掉剩下的训练数据
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]   # 把indices的前几个复制一次
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

最关键的还是这行:

python 复制代码
indices = indices[self.rank:self.total_size:self.num_replicas]

规定了每个rank的取数据的索引,起始索引是rank,每间隔num_replicas取一个

shuffle数据集

每次epoch都会shuffle数据集,但是不同进程如何保持shuffle之后数据集一致性

++DistributedSampler 使用当前的epoch作为随机数种子,在计算index之前就进行配置,从而保证不同进程都使用同样的随机数种子++,这样shuffle出来的数据就能确保一致。

python 复制代码
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), ...,
                            sampler=sampler)
	for epoch in range(start_epoch, n_epochs):
    	if is_distributed:
        	sampler.set_epoch(epoch) # 这设置epoch
        train(loader)

设置 random 种子的具体使用是在迭代函数之中:

python 复制代码
    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch) # 这里设置随机种子
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

在 PyTorch 中,torch.randperm(n) 函数用于生成一个从 0n-1 的随机排列的整数序列。这个函数是非常有用的,尤其是在需要随机打乱数据或索引时,比如在训练机器学习模型时打乱数据顺序,以确保模型训练的泛化能力。

相关推荐
藓类少女8 天前
【深度学习】使用硬件加速模型训练速度
人工智能·深度学习·分布式训练·gpu
Amd7942 个月前
Nuxt.js 应用中的 app:beforeMount 钩子详解
生命周期·初始化·数据加载·钩子·用户认证·nuxtjs·应用优化
爱喝白开水a3 个月前
AI大语言模型之分布式训练概述
人工智能·llm·分布式训练·embedding·ai大模型·计算机技术·大模型训练
AI_小站3 个月前
LLM分布式预训练浅析
人工智能·llm·分布式训练·大语言模型·agent·计算机技术·大模型应用
Hi202402175 个月前
将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap
pytorch·python·性能优化·分布式训练·nccl·融合算子
XMoyas5 个月前
deepspeed docker集群实现多机多卡训练----问题记录及解决方案资源汇总
docker·大模型·分布式训练·deepspeed·多机多卡
hjxu20166 个月前
大模型训练框架DeepSpeed使用入门(1): 训练设置
大模型·分布式训练·deepspeed
华为云开发者联盟6 个月前
理论+实践,带你了解分布式训练
机器学习·分布式训练·大语言模型·华为云开发者联盟
李一二7 个月前
Pytorch DistributedDataParallel(DDP)教程二:快速入门实践篇
分布式训练·ddp