PyTorch DataLoader 接受的返回值类型

通常,数据集通过__getitem__方法返回单个样本,而DataLoader负责将这些样本批量组合。以下是常见的返回值类型:

  1. 张量(Tensor):最常见的情况,返回一个或多个张量。DataLoader会自动将多个样本的张量堆叠成批次。

  2. 列表(List):可以返回一个列表,其中包含多个张量或其他类型。DataLoader会尝试将列表中的每个元素分别批量处理。

  3. 字典(Dictionary):返回一个字典,键是数据字段名,值是对应的张量或数据。DataLoader会按字段名分别批量处理。

  4. 元组(Tuple):返回一个元组,其中包含多个张量或其他类型。DataLoader会分别对元组中的每个元素进行批量处理。

  5. 命名元组(NamedTuple):类似于元组,但可以通过字段名访问,DataLoader处理方式与元组类似。

  6. 自定义数据类型 :如果返回的是自定义类型,需要确保DataLoader知道如何组合这些类型。通常,自定义类型需要实现相应的拼接方法,或者使用default_collate函数能够处理。

示例1:返回张量

复制代码
import torch
from torch.utils.data import Dataset, DataLoader

class TensorDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return torch.tensor([index, index*2])

dataset = TensorDataset()
dataloader = DataLoader(dataset, batch_size=2)

for batch in dataloader:
    print(batch)
    # 输出:一个批次的张量,形状为[2, 2]

示例2:返回元组(多个张量)

复制代码
class TupleDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return torch.tensor(index), torch.tensor(index*2)

dataset = TupleDataset()
dataloader = DataLoader(dataset, batch_size=2)

for batch in dataloader:
    print(batch)
    # 输出:一个元组,包含两个张量,每个张量的形状为[2]

示例3:返回字典

复制代码
class DictDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return {'data': torch.tensor(index), 'label': torch.tensor(index*2)}

dataset = DictDataset()
dataloader = DataLoader(dataset, batch_size=2)

for batch in dataloader:
    print(batch)
    # 输出:一个字典,包含两个键,每个键对应一个形状为[2]的张量

示例4:返回列表

复制代码
class ListDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        return [torch.tensor(index), torch.tensor(index*2)]

dataset = ListDataset()
dataloader = DataLoader(dataset, batch_size=2)

for batch in dataloader:
    print(batch)
    # 输出:一个列表,包含两个张量,每个张量的形状为[2]

示例5:自定义collate_fn

复制代码
def custom_collate_fn(batch):
    # 假设batch是多个样本的列表,每个样本是一个张量,但张量长度不同
    # 这里我们使用填充0到最大长度
    data = [item[0] for item in batch]  # 假设每个样本是一个元组,第一个元素是张量
    labels = [item[1] for item in batch]
    # 填充数据
    lengths = [len(d) for d in data]
    max_len = max(lengths)
    padded_data = torch.zeros(len(batch), max_len)
    for i, d in enumerate(data):
        padded_data[i, :lengths[i]] = d
    return padded_data, torch.tensor(labels)

class VariableLengthDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, index):
        length = torch.randint(1, 5, (1,)).item()
        data = torch.randn(length)
        label = index % 2
        return data, label

dataset = VariableLengthDataset()
dataloader = DataLoader(dataset, batch_size=2, collate_fn=custom_collate_fn)

for batch in dataloader:
    print(batch)
    break

1. 基本数据类型

单个张量

复制代码
# Dataset 返回单个张量
class SingleTensorDataset:
    def __getitem__(self, index):
        return torch.tensor([index, index*2, index*3])
    
    def __len__(self):
        return 100

# DataLoader 会自动堆叠成批次
dataloader = DataLoader(SingleTensorDataset(), batch_size=4)
for batch in dataloader:
    print(batch.shape)  # torch.Size([4, 3])

元组 (最常用)

复制代码
# 返回 (input, target) 元组
class TupleDataset:
    def __getitem__(self, index):
        x = torch.randn(10)  # 特征
        y = torch.tensor(index % 3)  # 标签
        return x, y  # 返回元组

dataloader = DataLoader(TupleDataset(), batch_size=4)
for inputs, targets in dataloader:
    print(inputs.shape, targets.shape)  # torch.Size([4, 10]), torch.Size([4])

字典

复制代码
# 返回字典,键值对形式
class DictDataset:
    def __getitem__(self, index):
        return {
            'input_ids': torch.randint(0, 100, (10,)),
            'attention_mask': torch.ones(10),
            'labels': torch.tensor(index % 2)
        }
    
dataloader = DataLoader(DictDataset(), batch_size=4)
for batch in dataloader:
    print(batch.keys())  # dict_keys(['input_ids', 'attention_mask', 'labels'])
    print(batch['input_ids'].shape)  # torch.Size([4, 10])

列表

复制代码
# 返回列表
class ListDataset:
    def __getitem__(self, index):
        return [torch.tensor(index), torch.randn(5), f"sample_{index}"]

dataloader = DataLoader(ListDataset(), batch_size=4)
for batch in dataloader:
    print(batch)  # [tensor, tensor, list_of_strings]

2. 复杂嵌套结构

嵌套字典

复制代码
class NestedDictDataset:
    def __getitem__(self, index):
        return {
            'image': torch.randn(3, 224, 224),
            'metadata': {
                'filename': f'img_{index}.jpg',
                'size': (224, 224),
                'timestamp': index * 1000
            },
            'labels': {
                'class': torch.tensor(index % 10),
                'bbox': torch.tensor([0.1, 0.2, 0.8, 0.9])
            }
        }

dataloader = DataLoader(NestedDictDataset(), batch_size=4)
for batch in dataloader:
    print(batch['image'].shape)  # torch.Size([4, 3, 224, 224])
    print(batch['metadata']['filename'])  # 列表: ['img_0.jpg', ...]

命名元组

复制代码
from collections import namedtuple

Sample = namedtuple('Sample', ['features', 'label', 'id'])

class NamedTupleDataset:
    def __getitem__(self, index):
        return Sample(
            features=torch.randn(10),
            label=torch.tensor(index % 3),
            id=f"sample_{index}"
        )

dataloader = DataLoader(NamedTupleDataset(), batch_size=4)
for batch in dataloader:
    print(type(batch))  # <class '__main__.Sample'>
    print(batch.features.shape)  # torch.Size([4, 10])

3. 自定义数据类型

自定义类实例

复制代码
class DataSample:
    def __init__(self, data, target, meta):
        self.data = data
        self.target = target
        self.meta = meta

class CustomClassDataset:
    def __getitem__(self, index):
        return DataSample(
            data=torch.randn(10),
            target=torch.tensor(index % 2),
            meta={'index': index, 'name': f'sample_{index}'}
        )

# 需要自定义 collate_fn
def custom_collate(batch):
    return DataSample(
        data=torch.stack([sample.data for sample in batch]),
        target=torch.stack([sample.target for sample in batch]),
        meta=[sample.meta for sample in batch]
    )

dataloader = DataLoader(CustomClassDataset(), batch_size=4, collate_fn=custom_collate)

4. 混合数据类型

张量 + Python 基本类型

复制代码
class MixedDataset:
    def __getitem__(self, index):
        return (
            torch.randn(10),           # 张量
            index % 3,                 # Python整数
            f"sample_{index}",         # 字符串
            [index, index*2],          # 列表
            {'idx': index}             # 字典
        )

dataloader = DataLoader(MixedDataset(), batch_size=4)
for tensor_data, int_data, str_data, list_data, dict_data in dataloader:
    print(tensor_data.shape)  # torch.Size([4, 10]) - 张量被堆叠
    print(int_data)           # tensor([0, 1, 2, 3]) - 数字被转换为张量
    print(str_data)           # ['sample_0', ...] - 字符串保持为列表

5. 特殊返回值处理

None 值处理

复制代码
class DatasetWithNone:
    def __getitem__(self, index):
        if index % 5 == 0:  # 每5个样本返回None
            return None
        return torch.randn(10), torch.tensor(index % 3)

# 需要过滤 None 值
def filter_none_collate(batch):
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch) if batch else None

dataloader = DataLoader(DatasetWithNone(), batch_size=4, collate_fn=filter_none_collate)

可变长度序列

复制代码
class VariableLengthDataset:
    def __getitem__(self, index):
        length = torch.randint(5, 15, (1,)).item()
        sequence = torch.randn(length, 10)  # 可变长度序列
        return sequence, torch.tensor(index % 3)

# 使用 pad_sequence 处理可变长度
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    sequences, labels = zip(*batch)
    padded_sequences = pad_sequence(sequences, batch_first=True)
    labels = torch.stack(labels)
    return padded_sequences, labels

dataloader = DataLoader(VariableLengthDataset(), batch_size=4, collate_fn=pad_collate)

6. 实际应用示例

计算机视觉任务

复制代码
class VisionDataset:
    def __getitem__(self, index):
        return {
            'image': torch.randn(3, 224, 224),      # 图像
            'label': torch.tensor(index % 1000),    # 分类标签
            'bbox': torch.tensor([[10, 20, 100, 150]]),  # 检测框
            'mask': torch.randn(224, 224) > 0.5,    # 分割掩码
            'image_id': index                       # 图像ID
        }

NLP 任务

复制代码
class NLPDataset:
    def __getitem__(self, index):
        return {
            'input_ids': torch.randint(0, 1000, (128,)),
            'attention_mask': torch.ones(128),
            'token_type_ids': torch.zeros(128),
            'labels': torch.randint(0, 2, (1,)),
            'text': f"这是第{index}个样本"
        }
相关推荐
无风听海2 小时前
神经网络之几个简单的激活函数足够表达世界的复杂性吗
人工智能·深度学习·神经网络
铮铭2 小时前
【论文阅读】GR-2:用于机器人操作的生成式视频-语言-动作模型
人工智能
Sugar_pp3 小时前
【论文阅读】Railway rutting defects detection based on improved RT‑DETR
论文阅读·深度学习·目标检测·transformer
DisonTangor3 小时前
百度开源 Qianfan-VL: 领域增强的通用视觉语言模型
人工智能·百度·语言模型
黎燃3 小时前
从“解题”到“证明”——OpenAI 通用大模型如何摘取 IMO 2025 金牌
人工智能
CodeCraft Studio3 小时前
Visual Studio 2026 Insiders 重磅发布:AI 深度集成、性能飞跃、全新设计
ide·人工智能·microsoft·visual studio
Charles豪3 小时前
MR、AR、VR:技术浪潮下安卓应用的未来走向
android·java·人工智能·xr·mr
起个名字费劲死了3 小时前
Pytorch Yolov11 OBB 旋转框检测+window部署+推理封装 留贴记录
c++·人工智能·pytorch·python·深度学习·yolo·机器人
Tadas-Gao3 小时前
华为OmniPlacement技术深度解析:突破超大规模MoE模型推理瓶颈的创新设计
人工智能·架构·大模型·llm