PyTorch分布式数据加载学习 DistributedSampler

[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler - 罗西的思考 - 博客园

初始化

python 复制代码
class DistributedSampler(Sampler[T_co]):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # 向上取整
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

如果不drop_last,那就

python 复制代码
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)

向上取整。

迭代

在迭代的时候,如果不能整除,那就把indices的前几个样本复制一遍:

python 复制代码
    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:  # 一般进入这里,不会丢掉剩下的训练数据
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]   # 把indices的前几个复制一次
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

最关键的还是这行:

python 复制代码
indices = indices[self.rank:self.total_size:self.num_replicas]

规定了每个rank的取数据的索引,起始索引是rank,每间隔num_replicas取一个

shuffle数据集

每次epoch都会shuffle数据集,但是不同进程如何保持shuffle之后数据集一致性

++DistributedSampler 使用当前的epoch作为随机数种子,在计算index之前就进行配置,从而保证不同进程都使用同样的随机数种子++,这样shuffle出来的数据就能确保一致。

python 复制代码
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), ...,
                            sampler=sampler)
	for epoch in range(start_epoch, n_epochs):
    	if is_distributed:
        	sampler.set_epoch(epoch) # 这设置epoch
        train(loader)

设置 random 种子的具体使用是在迭代函数之中:

python 复制代码
    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch) # 这里设置随机种子
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

在 PyTorch 中,torch.randperm(n) 函数用于生成一个从 0n-1 的随机排列的整数序列。这个函数是非常有用的,尤其是在需要随机打乱数据或索引时,比如在训练机器学习模型时打乱数据顺序,以确保模型训练的泛化能力。

相关推荐
Yeliang Wu8 天前
LLaMA-Factory 训练方法原理及实践(Ubuntu 22.04)
微调·分布式训练·量化·llamafactory
Yeliang Wu9 天前
LLaMA-Factory 分布式训练实践
大模型·微调·分布式训练·llamafactory·调优算法
Yeliang Wu9 天前
从原理到部署:LLaMA Factory 量化实战(Ubuntu 22.04)——PTQ/GPTQ/AWQ 等 9 种方法
大模型·微调·分布式训练·llamafactory·调优算法
Yeliang Wu9 天前
LLaMA-Factory 加速技术全解析:FlashAttention/Unsloth/Liger Kernel 原理与 Ubuntu22.04 实践指南
微调·分布式训练·llamafactory·调优算法
howard200523 天前
5.1 Hive加载数据实战
hive·数据加载
Xxtaoaooo2 个月前
原生多模态AI架构:统一训练与跨模态推理的系统实现与性能优化
人工智能·架构·分布式训练·多模态·模型优化
爱分享的飘哥4 个月前
第七十章:告别“手写循环”噩梦!Trainer结构搭建:PyTorch Lightning让你“一键炼丹”!
人工智能·pytorch·分布式训练·lightning·accelerate·训练框架·trainer
IT Panda6 个月前
[分布式并行策略] 数据并行 DP/DDP/FSDP/ZeRO
pytorch·分布式训练·dp·deepspeed·ddp·fsdp·zero
leo03086 个月前
torch.distributed.launch 、 torchrun 和 torch.distributed.run 无法与 nohup 兼容
人工智能·pytorch·python·深度学习·分布式训练
yuanlulu1 年前
llamafactory使用8张昇腾910b算力卡lora微调训练qwen2-72b大模型
lora·llm·transformer·分布式训练·大语言模型·huggingface·多卡训练