torch.utils.data

整体架构

平时使用 pytorch 加载数据时大概是这样的:

py 复制代码
import numpy as np
from torch.utils.data import Dataset, DataLoader

class ExampleDataset(Dataset):
	def __init__(self):
		self.data = [1, 2, 3, 4, 5]

	def __getitem__(self, idx):
		return self.data[idx]

	def __len__(self):
		return len(self.data)

def collate_fn(batch):
	return np.array(batch)

dataset = ExampleDataset()  # create the dataset
dataloader = DataLoader(
	dataset=dataset,
	batch_size=2,
	shuffle=True,
	num_workers=4,
	collate_fn=collate_fn
)
for datapoint in dataloader:
	print(datapoint)
  1. 继承 Dataset 类,定义一个迭代器,包含两个魔法方法:__getitem__(self, idx)__len__(self),分别实现如何获取一条数据和如何设定数据长度;
  2. 定义 collate_fn 函数,设定如何组织一个 batch
  3. 实例化 Dataset,并和 collate_fn 一起传入 DataLoader,参数 batch_size 设置批大小、shuffle 设置是否打乱、num_workers 设置并行加载数据的进程数。

然而,背后到底干了什么,我们不清楚,甚至遇到 DataLoader 的如 samplerbatch_samplerworker_init_fn 的其他参数,就会懵逼。那就看一看官方文档,了解一下 torch.utils.data 是如何工作的。

上图是数据加载的整体框架图,官网说 DataLoader 组合datasetsampler,多个 workers 根据 dataset 提供的数据副本sampler 提供的 keys 并行地加载数据,并通过 collate_fn 组成 batch 供用户迭代。需要注意的有:

  1. 每个 worker 持有数据的一个副本,故占用内存 "主线程内存 * num_workers";
  2. 即使用户不提供 sampler 对象 (通常不提供),DataLoader 也会根据 shuffle 参数创建一个默认的 sampler 对象 ;一旦提供了,其前路的 shuffle 参数不能为 True (不提供就好);
  3. 即使用户不提供 batch_sampler 对象 (通常不提供),DataLoader 也会根据 batch_sampler, drop_last 参数创建一个默认的 batch_sampler 对象 ;一旦提供了,其前路的 shuffle, drop_last 不能为 Truebatch_size 必须为 1 1 1,sampler 必须为 None,因为创建 BatchSampler 时已经有了这些参数;

    本质上是把创建 batch_sampler 的活拉出来由用户在 DataLoader 外自定义地做了。

Dataset

分为两种:map-styleiterable-style 。前者的数据可通过 [idx or key] 访问,后者的数据只能通过迭代器 next 一个个访问。所以上面架构中的采样器是对于 map-style 数据集说的iterable-style 的数据集的访问顺序由迭代器决定。

Sampler

torch.utils.data.Sampler 的子类或 Iterable,两个例子:

py 复制代码
class AccedingSequenceLengthSampler(tu_data.Sampler[int]):
	def __init__(self, data: List[str]) -> None:
		super().__init__()
		self.data = data

	def __len__(self) -> int:
		return len(self.data)

	def __iter__(self) -> Iterator[int]:
		"""
		:return: 实现了按数据长短顺序访问数据集
		"""
		sizes = torch.tensor([len(x) for x in self.data])
		yield from torch.argsort(sizes).tolist()


class AccedingSequenceLengthBatchSampler(tu_data.Sampler[List[int]]):
	def __init__(self, data: List[str], batch_size: int) -> None:
		super().__init__()
		self.data = data
		self.batch_size = batch_size

	def __len__(self) -> int:
		return (len(self.data) + self.batch_size - 1) // self.batch_size

	def __iter__(self) -> Iterator[List[int]]:
		sizes = torch.tensor([len(x) for x in self.data])
		for batch in torch.chunk(torch.argsort(sizes), len(self)):  # 按块遍历
			yield batch.tolist()

Batch

batch_sampler 提供一批下标,取得一批数据后由 collate_fn 将这批数据整合:

py 复制代码
if collate_fn is None:
	if self._auto_collation:
		collate_fn = _utils.collate.default_collate
	else:  # self.batch_sampler is None: (batch_size is None) and (batch_sampler is None)
		collate_fn = _utils.collate.default_convert

分两种情况:

  • automatic batching is disabled :调用 default_convert 函数简单地将 NumPy arrays 转化为 PyTorch Tensor;
  • automatic batching is enabled :调用 default_collate 函数,转化会变得复杂一点:
py 复制代码
from torch.utils import data as tu_data
import collections

# %% Example with a batch of `int`s:
tu_data.default_collate([0, 1, 2, 3])
# tensor([0, 1, 2, 3])

# %% Example with a batch of `str`s:
tu_data.default_collate(['a', 'b', 'c'])
# ['a', 'b', 'c']

# %% Example with `Map` inside the batch:
tu_data.default_collate([
	{'A': 0, 'B': 1},
	{'A': 100, 'B': 100}
])
# {'A': tensor([0, 100]), 'B': tensor([1, 100])}, 同 key 的合并了

# %% Example with `NamedTuple` inside the batch:
Point = collections.namedtuple('Point', ['x', 'y'])
tu_data.default_collate([Point(0, 0), Point(1, 1)])
# Point(x=tensor([0, 1]), y=tensor([0, 1])), 同 name 的合并了, 大概和 dict 一样吧

# %% Example with `Tuple` inside the batch:
tu_data.default_collate([(0, 1), (2, 3)])
# [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate

# %% Example with `List` inside the batch:
tu_data.default_collate([[0, 1], [2, 3]])  # [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate, 并没有变成二维 tensor

Multi-process Data Loading

dataset, collate_fn, and worker_init_fn are passed to each worker ,大概能说明 batch 是在子进程内部 合成的。

有一个需要注意的地方是内存增长问题,当 __get_item__(self, key) 访问数据时,由于 Python 对象的 refcount 机制,数据会不断地复制 ,从而内存爆炸。但这里说解决 number of workers * size of parent process 问题,就不追究了,反正尽量用 numpy 或 pytorch tensor 吧。
iterable-style datasets 的随机性

相关推荐
jieshenai6 分钟前
torch 高维矩阵乘法分析,一文说透
pytorch·深度学习·矩阵
苏苏susuus16 小时前
深度学习:PyTorch张量基本运算、形状改变、索引操作、升维降维、维度转置、张量拼接
人工智能·pytorch·深度学习
凡人的AI工具箱18 小时前
PyTorch深度学习框架60天进阶学习计划 - 第58天端到端对话系统(一):打造你的专属AI语音助手
人工智能·pytorch·python·深度学习·mcp·a2a
知舟不叙20 小时前
深度学习——基于PyTorch的MNIST手写数字识别详解
人工智能·pytorch·深度学习·手写数字识别
Crabfishhhhh1 天前
神经网络学习-神经网络简介【Transformer、pytorch、Attention介绍与区别】
pytorch·python·神经网络·学习·transformer
whyeekkk1 天前
python打卡第52天
pytorch·python·深度学习
猎嘤一号1 天前
使用 PyTorch 和 SwanLab 实时可视化模型训练
人工智能·pytorch·深度学习
福大大架构师每日一题1 天前
pytorch v2.7.1 发布!全面修复关键BUG,性能与稳定性再升级,2025年深度学习利器必备!
pytorch·深度学习·bug
凡人的AI工具箱1 天前
PyTorch深度学习框架60天进阶学习计划-第57天:因果推理模型(二)- 高级算法与深度学习融合
人工智能·pytorch·深度学习·学习·mcp·a2a
四川兔兔1 天前
pytorch 之 nn 库与调试
人工智能·pytorch·python