[Source Code Analysis] PyTorch Distributed (1) --- Data Loading: DistributedSampler - 罗西的思考 - 博客园
Initialization
```python
class DistributedSampler(Sampler[T_co]):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # Excerpt: the code that fills in num_replicas and rank from
        # torch.distributed when they are None is omitted here.
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # round up
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
```
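When drop_last is False, num_samples is rounded up, so total_size can be larger than the dataset length:
```python
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
```
When drop_last is True and the length is not evenly divisible, the other branch works out to rounding down instead.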
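To make the two branches concrete, here is a minimal sketch that reproduces only the arithmetic from `__init__`. The helper name `samples_per_rank` is hypothetical and is not part of PyTorch.

```python
import math
from typing import Tuple

def samples_per_rank(dataset_len: int, num_replicas: int, drop_last: bool) -> Tuple[int, int]:
    """Return (num_samples, total_size) using the same formulas as __init__."""
    if drop_last and dataset_len % num_replicas != 0:
        # For a non-divisible length this equals floor(dataset_len / num_replicas).
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    return num_samples, num_samples * num_replicas

print(samples_per_rank(10, 4, drop_last=False))  # (3, 12): two indices are duplicated
print(samples_per_rank(10, 4, drop_last=True))   # (2, 8): two indices are dropped
print(samples_per_rank(12, 4, drop_last=True))   # (3, 12): divisible, nothing dropped
```

With 10 samples and 4 replicas, every rank gets 3 samples when padding is allowed and 2 when the tail is dropped.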
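Iteration
During iteration, if the dataset length is not evenly divisible by num_replicas (and drop_last is False), the first few entries of indices are appended again as padding: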
```python
def __iter__(self) -> Iterator[T_co]:
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
    else:
        indices = list(range(len(self.dataset)))
    if not self.drop_last:  # the usual path: leftover samples are kept, not dropped
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]  # repeat the first few indices once
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
    else:
        # remove tail of data to make it evenly divisible.
        indices = indices[:self.total_size]
    assert len(indices) == self.total_size
    # subsample
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples
    return iter(indices)
```
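The padding step can be tried in isolation. The sketch below copies only the `drop_last=False` branch into a standalone function; `pad_indices` is a hypothetical name used for illustration, not a PyTorch API.

```python
import math
from typing import List

def pad_indices(indices: List[int], total_size: int) -> List[int]:
    """Repeat entries from the front until the list has total_size entries."""
    padded = list(indices)
    padding_size = total_size - len(padded)
    if padding_size <= len(padded):
        padded += padded[:padding_size]
    else:
        # Only reachable when num_replicas exceeds the dataset length.
        padded += (padded * math.ceil(padding_size / len(padded)))[:padding_size]
    return padded

print(pad_indices([3, 0, 4, 1, 2], total_size=6))  # [3, 0, 4, 1, 2, 3]
print(pad_indices([0, 1], total_size=5))           # [0, 1, 0, 1, 0]
```

The second call shows the rarely used `else` branch: with 2 samples and 5 replicas, the whole list has to be repeated more than once to fill the padding.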
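The key line is this one:
```python
indices = indices[self.rank:self.total_size:self.num_replicas]
```
It determines which indices each rank reads: the slice starts at index rank and takes every num_replicas-th element after that, so the ranks get interleaved, non-overlapping positions of the same list.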
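A small sketch of that slice on a toy list, assuming the list has already been padded to a multiple of `num_replicas`. The helper `shard_for_rank` is hypothetical and exists only for this illustration.

```python
from typing import List

def shard_for_rank(indices: List[int], rank: int, num_replicas: int) -> List[int]:
    """Same slice as in __iter__: start at `rank`, step by `num_replicas`."""
    total_size = len(indices)  # assumed to be a multiple of num_replicas
    return indices[rank:total_size:num_replicas]

padded = [3, 0, 4, 1, 2, 3]  # six entries, so two ranks get three each
for rank in range(2):
    print(rank, shard_for_rank(padded, rank, num_replicas=2))
# 0 [3, 4, 2]
# 1 [0, 1, 3]
```

Rank 0 reads positions 0, 2, 4 and rank 1 reads positions 1, 3, 5. The duplicated index 3 at the end is the padding entry, which is why it shows up on both ranks.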
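Shuffling the dataset
The dataset is reshuffled every epoch, but how do different processes keep the shuffled order consistent with each other?
DistributedSampler seeds its random generator with seed + epoch before computing the indices. Both values are the same in every process, so every process draws the same permutation. This is also why sampler.set_epoch(epoch) has to be called at the start of each epoch: without it, epoch stays at 0 and every epoch reuses the same order.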
```python
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), ...,
                    sampler=sampler)
for epoch in range(start_epoch, n_epochs):
    if is_distributed:
        sampler.set_epoch(epoch)  # set the epoch here
    train(loader)
```
The random seed itself is applied inside the __iter__ method:
```python
def __iter__(self) -> Iterator[T_co]:
    if self.shuffle:
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # the random seed is set here
        indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
    else:
        indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
```
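In PyTorch, torch.randperm(n) returns a random permutation of the integers from 0 to n-1. It is useful whenever data or indices need to be shuffled, for example to shuffle the sample order during training so that the model does not pick up on the order in which the data happens to be stored.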
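The determinism argument can be sketched without a GPU or a process group. The example below is an assumption-laden stand-in: it uses the standard library's `random.Random` in place of `torch.Generator` plus `torch.randperm`, so the actual permutations differ from what PyTorch would produce, but the property being shown (same seed and epoch give the same order) is the same. The function name `epoch_permutation` is hypothetical.

```python
import random
from typing import List

def epoch_permutation(n: int, seed: int, epoch: int) -> List[int]:
    """Stand-in for torch.randperm(n, generator=g) with g seeded by seed + epoch."""
    rng = random.Random(seed + epoch)  # stdlib substitute for torch.Generator
    perm = list(range(n))
    rng.shuffle(perm)
    return perm

# Two "processes" that agree on seed and epoch compute the same order
# independently, with no communication between them.
rank0_view = epoch_permutation(8, seed=0, epoch=3)
rank1_view = epoch_permutation(8, seed=0, epoch=3)
print(rank0_view == rank1_view)  # True

# The result is always a permutation of 0..n-1, whatever the epoch.
print(sorted(epoch_permutation(8, seed=0, epoch=4)) == list(range(8)))  # True
```

Because each rank then slices this shared permutation at a different offset, the ranks stay disjoint without ever exchanging indices.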