通常,数据集通过__getitem__
方法返回单个样本,而DataLoader负责将这些样本批量组合。以下是常见的返回值类型:
张量(Tensor):最常见的情况,返回一个或多个张量。DataLoader会自动将多个样本的张量堆叠成批次。
列表(List):可以返回一个列表,其中包含多个张量或其他类型。DataLoader会尝试将列表中的每个元素分别批量处理。
字典(Dictionary):返回一个字典,键是数据字段名,值是对应的张量或数据。DataLoader会按字段名分别批量处理。
元组(Tuple):返回一个元组,其中包含多个张量或其他类型。DataLoader会分别对元组中的每个元素进行批量处理。
命名元组(NamedTuple):类似于元组,但可以通过字段名访问,DataLoader处理方式与元组类似。
自定义数据类型:如果返回的是自定义类型,需要确保DataLoader知道如何组合这些类型。通常,需要为自定义类型编写相应的collate_fn拼接函数,或者确保默认的default_collate函数能够直接处理它。
示例1:返回张量
import torch
from torch.utils.data import DataLoader, Dataset, default_collate
class TensorDataset(Dataset):
    """Toy dataset: each sample is a single 1-D tensor [index, 2*index]."""

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # DataLoader stacks these per-sample tensors along a new batch dim.
        return torch.tensor([index, index * 2])


dataset = TensorDataset()
dataloader = DataLoader(dataset, batch_size=2)
for batch in dataloader:
    print(batch)
# Output: one stacked tensor per batch, shape [2, 2].
示例2:返回元组(多个张量)
class TupleDataset(Dataset):
    """Toy dataset: each sample is a (tensor, tensor) pair."""

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # Each tuple position is collated independently by DataLoader.
        return torch.tensor(index), torch.tensor(index * 2)


dataset = TupleDataset()
dataloader = DataLoader(dataset, batch_size=2)
for batch in dataloader:
    print(batch)
# Output: a tuple of two tensors, each of shape [2].
示例3:返回字典
class DictDataset(Dataset):
    """Toy dataset: each sample is a dict of named tensors."""

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # DataLoader collates each dict key separately, keeping the keys.
        return {'data': torch.tensor(index), 'label': torch.tensor(index * 2)}


dataset = DictDataset()
dataloader = DataLoader(dataset, batch_size=2)
for batch in dataloader:
    print(batch)
# Output: a dict with two keys, each mapping to a tensor of shape [2].
示例4:返回列表
class ListDataset(Dataset):
    """Toy dataset: each sample is a list of two tensors."""

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # Each list position is collated independently, like a tuple.
        return [torch.tensor(index), torch.tensor(index * 2)]


dataset = ListDataset()
dataloader = DataLoader(dataset, batch_size=2)
for batch in dataloader:
    print(batch)
# Output: a list of two tensors, each of shape [2].
示例5:自定义collate_fn
def custom_collate_fn(batch):
    """Collate (sequence, label) pairs whose 1-D sequences differ in length.

    Sequences are right-padded with zeros up to the longest sequence in the
    batch; labels are gathered into one tensor.
    """
    data, labels = zip(*batch)
    lengths = [len(seq) for seq in data]
    max_len = max(lengths)
    # Zero-filled buffer, then copy each sequence into its own row.
    padded_data = torch.zeros(len(batch), max_len)
    for row, (seq, seq_len) in enumerate(zip(data, lengths)):
        padded_data[row, :seq_len] = seq
    return padded_data, torch.tensor(labels)


class VariableLengthDataset(Dataset):
    """Toy dataset producing 1-D sequences of random length in [1, 4]."""

    def __len__(self):
        return 10

    def __getitem__(self, index):
        length = torch.randint(1, 5, (1,)).item()
        return torch.randn(length), index % 2


dataset = VariableLengthDataset()
dataloader = DataLoader(dataset, batch_size=2, collate_fn=custom_collate_fn)
for batch in dataloader:
    print(batch)
    break
1. 基本数据类型
单个张量
# Dataset returning one tensor per sample.
class SingleTensorDataset:
    def __len__(self):
        return 100

    def __getitem__(self, index):
        return torch.tensor([index, index * 2, index * 3])


# DataLoader automatically stacks the samples into batches.
dataloader = DataLoader(SingleTensorDataset(), batch_size=4)
for batch in dataloader:
    print(batch.shape)  # torch.Size([4, 3])
元组 (最常用)
# Return an (input, target) tuple -- the most common pattern.
class TupleDataset:
    def __getitem__(self, index):
        x = torch.randn(10)          # features
        y = torch.tensor(index % 3)  # label
        return x, y                  # return a tuple

    def __len__(self):
        # Required by DataLoader's default (sequential) sampler; without it,
        # iterating the loader raises TypeError.
        return 10


dataloader = DataLoader(TupleDataset(), batch_size=4)
for inputs, targets in dataloader:
    print(inputs.shape, targets.shape)  # torch.Size([4, 10]) torch.Size([4])
字典
# Return a dictionary of named fields.
class DictDataset:
    def __getitem__(self, index):
        return {
            'input_ids': torch.randint(0, 100, (10,)),
            'attention_mask': torch.ones(10),
            'labels': torch.tensor(index % 2)
        }

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


dataloader = DataLoader(DictDataset(), batch_size=4)
for batch in dataloader:
    print(batch.keys())  # dict_keys(['input_ids', 'attention_mask', 'labels'])
    print(batch['input_ids'].shape)  # torch.Size([4, 10])
列表
# Return a list of heterogeneous elements.
class ListDataset:
    def __getitem__(self, index):
        return [torch.tensor(index), torch.randn(5), f"sample_{index}"]

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


dataloader = DataLoader(ListDataset(), batch_size=4)
for batch in dataloader:
    print(batch)  # [tensor, tensor, list_of_strings]
2. 复杂嵌套结构
嵌套字典
class NestedDictDataset:
    """Sample with nested dicts; default_collate recurses into each level."""

    def __getitem__(self, index):
        return {
            'image': torch.randn(3, 224, 224),
            'metadata': {
                'filename': f'img_{index}.jpg',
                'size': (224, 224),
                'timestamp': index * 1000
            },
            'labels': {
                'class': torch.tensor(index % 10),
                'bbox': torch.tensor([0.1, 0.2, 0.8, 0.9])
            }
        }

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


dataloader = DataLoader(NestedDictDataset(), batch_size=4)
for batch in dataloader:
    print(batch['image'].shape)  # torch.Size([4, 3, 224, 224])
    print(batch['metadata']['filename'])  # list: ['img_0.jpg', ...]
命名元组
from collections import namedtuple

Sample = namedtuple('Sample', ['features', 'label', 'id'])


class NamedTupleDataset:
    """Dataset whose samples are namedtuples; default_collate preserves
    the namedtuple type, so batches remain `Sample` instances."""

    def __getitem__(self, index):
        return Sample(
            features=torch.randn(10),
            label=torch.tensor(index % 3),
            id=f"sample_{index}"
        )

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


dataloader = DataLoader(NamedTupleDataset(), batch_size=4)
for batch in dataloader:
    print(type(batch))  # <class '__main__.Sample'>
    print(batch.features.shape)  # torch.Size([4, 10])
3. 自定义数据类型
自定义类实例
class DataSample:
    """Simple container for one sample's data, target and metadata."""

    def __init__(self, data, target, meta):
        self.data = data
        self.target = target
        self.meta = meta


class CustomClassDataset:
    def __getitem__(self, index):
        return DataSample(
            data=torch.randn(10),
            target=torch.tensor(index % 2),
            meta={'index': index, 'name': f'sample_{index}'}
        )

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


# default_collate does not know this class, so a custom collate_fn is needed.
def custom_collate(batch):
    """Stack the tensor fields and gather the metadata dicts into a list."""
    return DataSample(
        data=torch.stack([sample.data for sample in batch]),
        target=torch.stack([sample.target for sample in batch]),
        meta=[sample.meta for sample in batch]
    )


dataloader = DataLoader(CustomClassDataset(), batch_size=4, collate_fn=custom_collate)
4. 混合数据类型
张量 + Python 基本类型
class MixedDataset:
    """Sample mixing tensors with plain Python values; default_collate
    converts numbers to tensors and keeps strings as lists."""

    def __getitem__(self, index):
        return (
            torch.randn(10),       # tensor
            index % 3,             # Python int
            f"sample_{index}",     # string
            [index, index * 2],    # list
            {'idx': index}         # dict
        )

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


dataloader = DataLoader(MixedDataset(), batch_size=4)
for tensor_data, int_data, str_data, list_data, dict_data in dataloader:
    print(tensor_data.shape)  # torch.Size([4, 10]) - tensors are stacked
    # NOTE: index % 3 over indices 0..3 gives [0, 1, 2, 0], not [0, 1, 2, 3].
    print(int_data)           # tensor([0, 1, 2, 0]) - ints become a tensor
    print(str_data)           # ['sample_0', ...] - strings stay in a list
5. 特殊返回值处理
None 值处理
class DatasetWithNone:
    """Dataset that occasionally yields None (e.g. a corrupted sample)."""

    def __getitem__(self, index):
        if index % 5 == 0:  # every 5th sample is None
            return None
        return torch.randn(10), torch.tensor(index % 3)

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


# None values must be filtered out before default collation.
# NOTE: `default_collate` comes from torch.utils.data and must be imported.
def filter_none_collate(batch):
    """Drop None samples, then fall back to the default collation.

    Returns None when the entire batch was filtered out; the training
    loop must skip such batches.
    """
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch) if batch else None


dataloader = DataLoader(DatasetWithNone(), batch_size=4, collate_fn=filter_none_collate)
可变长度序列
# Use pad_sequence to collate variable-length sequences.
from torch.nn.utils.rnn import pad_sequence


class VariableLengthDataset:
    """Dataset producing (length, 10) sequences with random length in [5, 14]."""

    def __getitem__(self, index):
        length = torch.randint(5, 15, (1,)).item()
        sequence = torch.randn(length, 10)  # variable-length sequence
        return sequence, torch.tensor(index % 3)

    def __len__(self):
        # Required by DataLoader's default sampler; without it, iterating
        # the loader raises TypeError.
        return 10


def pad_collate(batch):
    """Zero-pad sequences to the batch maximum length (batch-first)."""
    sequences, labels = zip(*batch)
    padded_sequences = pad_sequence(sequences, batch_first=True)
    labels = torch.stack(labels)
    return padded_sequences, labels


dataloader = DataLoader(VariableLengthDataset(), batch_size=4, collate_fn=pad_collate)
6. 实际应用示例
计算机视觉任务
class VisionDataset:
    """Typical computer-vision sample: image, label, boxes, mask, id."""

    def __getitem__(self, index):
        return {
            'image': torch.randn(3, 224, 224),           # image
            'label': torch.tensor(index % 1000),         # class label
            'bbox': torch.tensor([[10, 20, 100, 150]]),  # detection box
            'mask': torch.randn(224, 224) > 0.5,         # segmentation mask (bool)
            'image_id': index                            # image id
        }

    def __len__(self):
        # Required by DataLoader's default sampler; without it, any
        # DataLoader built on this dataset fails on iteration.
        return 10
NLP 任务
class NLPDataset:
    """Typical transformer-style NLP sample with a fixed length of 128."""

    def __getitem__(self, index):
        return {
            'input_ids': torch.randint(0, 1000, (128,)),
            'attention_mask': torch.ones(128),
            'token_type_ids': torch.zeros(128),
            'labels': torch.randint(0, 2, (1,)),
            'text': f"这是第{index}个样本"
        }

    def __len__(self):
        # Required by DataLoader's default sampler; without it, any
        # DataLoader built on this dataset fails on iteration.
        return 10