
在PyTorch中,Dataset类是用来表示数据集的抽象类,需要继承它并实现几个关键方法。以下是详细的介绍和示例:
1. 基本Dataset类结构
python
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
class CustomDataset(Dataset):
def __init__(self, data, labels, transform=None):
"""
初始化Dataset
Args:
data: 数据(可以是列表、数组等)
labels: 标签
transform: 数据预处理/增强
"""
self.data = data
self.labels = labels
self.transform = transform
def __len__(self):
"""返回数据集大小"""
return len(self.data)
def __getitem__(self, idx):
"""
根据索引返回一个样本
Args:
idx: 索引
Returns:
sample: 一个样本(数据+标签)
"""
sample = self.data[idx]
label = self.labels[idx]
if self.transform:
sample = self.transform(sample)
return sample, label
2. 实际应用示例
示例1:图像分类数据集
python
import torchvision.transforms as transforms
class ImageDataset(Dataset):
def __init__(self, image_paths, labels, transform=None):
self.image_paths = image_paths
self.labels = labels
self.transform = transform
# 默认转换
if transform is None:
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image_path = self.image_paths[idx]
image = Image.open(image_path).convert('RGB')
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
示例2:文本分类数据集
python
from torch.nn.utils.rnn import pad_sequence
import torch
class TextDataset(Dataset):
def __init__(self, texts, labels, tokenizer, max_length=128):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
# 分词
encoding = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
# 展平,因为tokenizer返回的是批次维度
input_ids = encoding['input_ids'].squeeze(0)
attention_mask = encoding['attention_mask'].squeeze(0)
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': torch.tensor(label, dtype=torch.long)
}
# 用于DataLoader的collate函数
def collate_fn(batch):
input_ids = [item['input_ids'] for item in batch]
attention_mask = [item['attention_mask'] for item in batch]
labels = [item['labels'] for item in batch]
# 填充
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
labels = torch.stack(labels)
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels
}
3. 使用DataLoader加载数据
python
# 创建数据集实例
dataset = CustomDataset(data, labels, transform=transform)
# 使用DataLoader
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True, # 训练时打乱数据
num_workers=4, # 多进程加载
pin_memory=True # 加速GPU传输
)
# 遍历数据
for batch_idx, (data, labels) in enumerate(dataloader):
print(f"Batch {batch_idx}:")
print(f" Data shape: {data.shape}")
print(f" Labels shape: {labels.shape}")
# 训练代码...
4. 高级Dataset类
示例3:带数据增强的数据集
python
class AugmentedDataset(Dataset):
def __init__(self, dataset, augmentations):
self.dataset = dataset
self.augmentations = augmentations
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
image, label = self.dataset[idx]
# 应用数据增强
if self.augmentations:
image = self.augmentations(image)
return image, label
示例4:多模态数据集
python
class MultiModalDataset(Dataset):
def __init__(self, images, texts, labels, image_transform, tokenizer):
self.images = images
self.texts = texts
self.labels = labels
self.image_transform = image_transform
self.tokenizer = tokenizer
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = Image.open(self.images[idx]).convert('RGB')
text = self.texts[idx]
label = self.labels[idx]
# 图像处理
if self.image_transform:
image = self.image_transform(image)
# 文本处理
text_encoding = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=128,
return_tensors='pt'
)
return {
'image': image,
'input_ids': text_encoding['input_ids'].squeeze(0),
'attention_mask': text_encoding['attention_mask'].squeeze(0),
'label': torch.tensor(label)
}
5. 实用技巧和建议
python
# 1. 缓存数据以加速加载
class CachedDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
self.cache = {}
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
if idx not in self.cache:
self.cache[idx] = self.dataset[idx]
return self.cache[idx]
# 2. 子集划分
from torch.utils.data import Subset
# 划分训练集和验证集
indices = list(range(len(dataset)))
split = int(0.8 * len(dataset))
train_indices = indices[:split]
val_indices = indices[split:]
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
# 3. 数据采样器
from torch.utils.data import WeightedRandomSampler
# 处理类别不平衡
class_counts = [100, 20, 300] # 每个类别的样本数
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = weights[labels]
sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
6. 使用内置数据集
PyTorch也提供了许多内置数据集:
python
import torchvision.datasets as datasets
from torchvision import transforms
# MNIST数据集
mnist_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
# CIFAR-10数据集
cifar_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
7. Dataset与Tokenizer的关系理解
Dataset和Tokenizer在NLP任务中是分工协作、相互配合的关系。让我用一个比喻来解释:
🏭 工厂流水线比喻
想象一个食品加工厂:
-
Dataset = 仓库管理员 + 物流系统
- 知道数据在哪里
- 如何按顺序取出原料
- 如何批量组织运输
-
Tokenizer = 食品加工机
- 将原料(文本)切割成标准大小
- 添加调味料(特殊标记)
- 包装成统一格式
🔄 工作流程详解
1. 角色分工
python
# Dataset: 负责数据管理
class TextDataset(Dataset):
def __init__(self, texts, labels, tokenizer):
self.texts = texts # 原始文本
self.labels = labels # 原始标签
self.tokenizer = tokenizer # 工具
def __getitem__(self, idx):
raw_text = self.texts[idx] # 获取原始文本
label = self.labels[idx] # 获取原始标签
# Tokenizer: 负责文本处理
encoded = self.tokenizer(raw_text) # 文本 → 数字ID
return {
'input_ids': encoded['input_ids'],
'attention_mask': encoded['attention_mask'],
'labels': label
}
2. 处理过程对比
| 方面 | Dataset | Tokenizer |
|---|---|---|
| 职责 | 数据管理、索引、批处理 | 文本预处理、编码 |
| 输入 | 原始数据文件/路径 | 原始文本字符串 |
| 输出 | 组织好的样本批次 | 数值化/向量化表示 |
| 关注点 | 数据流、内存效率 | 语言学特征、模型兼容性 |
📝 实际示例分析
示例1:标准用法
python
from transformers import AutoTokenizer
from torch.utils.data import Dataset
class NLP_dataset(Dataset):
def __init__(self, file_path, tokenizer_name, max_length=128):
# Dataset: 加载原始数据
self.data = self._load_data(file_path) # 返回 [(text1, label1), ...]
# Tokenizer: 初始化工具
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.max_length = max_length
def __getitem__(self, idx):
text, label = self.data[idx]
# Tokenizer在此被Dataset调用
encoding = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
# Dataset返回处理后的数据
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'labels': torch.tensor(label)
}
示例2:可视化处理流程
python
# 原始数据状态
raw_data = [
("I love machine learning", 1),
("Deep learning is amazing", 1),
("This is boring", 0)
]
# Dataset的作用:
# 1. 存储这些数据对
# 2. 提供按索引访问的接口
# 当访问dataset[0]时:
text = "I love machine learning"
label = 1
# Tokenizer的作用:
tokens = tokenizer.tokenize(text)
# ["I", "love", "machine", "learning"]
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# [1045, 2293, 4083, 4084] # 实际BERT的ID
# 最终Dataset返回:
{
'input_ids': [1045, 2293, 4083, 4084],
'attention_mask': [1, 1, 1, 1],
'labels': 1
}
🔧 不同场景下的协作模式
模式1:Dataset内部集成Tokenizer(最常见)
python
class DatasetWithTokenizer(Dataset):
def __init__(self, texts, labels, tokenizer):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer # Tokenizer作为参数传入
def __getitem__(self, idx):
text = self.texts[idx]
# Dataset调用Tokenizer的服务
encoded = self.tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=128,
padding='max_length',
return_tensors='pt'
)
return encoded, self.labels[idx]
模式2:Dataset外部处理(预处理模式)
python
# 先统一预处理所有数据
all_encoded = []
for text in all_texts:
encoded = tokenizer(text) # 提前处理
all_encoded.append(encoded)
# Dataset只负责存储和返回
class PreprocessedDataset(Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings # 已经tokenized的数据
self.labels = labels
def __getitem__(self, idx):
return self.encodings[idx], self.labels[idx]
模式3:动态处理(更灵活)
python
class DynamicDataset(Dataset):
def __init__(self, texts, labels, tokenizer, augment_fn=None):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.augment_fn = augment_fn # 数据增强函数
def __getitem__(self, idx):
text = self.texts[idx]
# 动态数据增强
if self.augment_fn:
text = self.augment_fn(text)
# 动态分词
encoded = self.tokenizer(
text,
padding='max_length',
truncation=True,
max_length=512
)
return {k: torch.tensor(v) for k, v in encoded.items()}, self.labels[idx]
🎯 关键理解点
1. 数据流方向
原始文本 → Dataset获取 → Tokenizer处理 → 模型输入
↓ ↓ ↓ ↓
文件/DB 管理组织 文本转数字 张量格式
2. 生命周期管理
- Tokenizer通常只创建一次(模型相关)
- Dataset可能创建多个(训练集、验证集、测试集)
python
# 一次性创建
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 多个数据集共用同一个tokenizer
train_dataset = MyDataset(train_texts, train_labels, tokenizer)
val_dataset = MyDataset(val_texts, val_labels, tokenizer) # 同一个tokenizer
3. 依赖关系
python
# Tokenizer不依赖Dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# 可以单独使用
result = tokenizer("Hello world!")
# Dataset依赖Tokenizer
class MyDataset(Dataset):
def __init__(self, tokenizer): # 依赖注入
self.tokenizer = tokenizer
💡 最佳实践建议
1. Tokenizer选择策略
python
# 根据模型选择对应的Tokenizer
MODEL_TO_TOKENIZER = {
'bert': BertTokenizer,
'roberta': RobertaTokenizer,
'gpt2': GPT2Tokenizer,
't5': T5Tokenizer
}
def create_dataset(model_type, data, labels):
tokenizer_class = MODEL_TO_TOKENIZER[model_type]
tokenizer = tokenizer_class.from_pretrained(model_type)
return CustomDataset(data, labels, tokenizer)
2. 缓存优化
python
class CachedTokenizationDataset(Dataset):
def __init__(self, texts, labels, tokenizer):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.cache = {} # 缓存tokenization结果
def __getitem__(self, idx):
if idx not in self.cache:
text = self.texts[idx]
encoded = self.tokenizer(text, padding='max_length', truncation=True)
self.cache[idx] = (encoded, self.labels[idx])
return self.cache[idx]
3. 错误处理
python
class RobustDataset(Dataset):
def __getitem__(self, idx):
try:
text = self.texts[idx]
# 处理可能的文本异常
if not isinstance(text, str) or not text.strip():
text = "[EMPTY]"
encoded = self.tokenizer(
text,
padding='max_length',
truncation=True,
max_length=self.max_len,
return_tensors='pt'
)
return encoded, self.labels[idx]
except Exception as e:
# 返回一个安全的默认值
return self._get_default_item()
🎓 总结比喻
Dataset 就像餐厅的服务员:
- 知道菜单(数据索引)
- 按顺序上菜(数据顺序)
- 组织多人用餐(批处理)
Tokenizer 就像厨房的厨师:
- 处理食材(文本分词)
- 按菜谱烹饪(添加特殊标记)
- 摆盘(padding/truncation)
两者配合:服务员(Dataset)从顾客(数据源)那里拿到点单(原始文本),交给厨师(Tokenizer)烹饪处理,最后上菜给食客(模型)。
这种分工使得:
- Tokenizer专注:文本处理逻辑,与模型架构对齐
- Dataset专注:数据管理逻辑,与训练流程对齐
- 松耦合设计:可以更换不同的Tokenizer而不影响Dataset结构