【推荐系统】深度学习训练框架(十):PyTorch Dataset—PyTorch数据基石

在PyTorch中,Dataset类是用来表示数据集的抽象类,需要继承它并实现几个关键方法。以下是详细的介绍和示例:

1. 基本Dataset类结构

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        """
        初始化Dataset
        Args:
            data: 数据(可以是列表、数组等)
            labels: 标签
            transform: 数据预处理/增强
        """
        self.data = data
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        """返回数据集大小"""
        return len(self.data)
    
    def __getitem__(self, idx):
        """
        根据索引返回一个样本
        Args:
            idx: 索引
        Returns:
            sample: 一个样本(数据+标签)
        """
        sample = self.data[idx]
        label = self.labels[idx]
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample, label

2. 实际应用示例

示例1:图像分类数据集

python 复制代码
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        
        # 默认转换
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

示例2:文本分类数据集

python 复制代码
from torch.nn.utils.rnn import pad_sequence
import torch

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # 分词
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # 展平,因为tokenizer返回的是批次维度
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(label, dtype=torch.long)
        }

# 用于DataLoader的collate函数
def collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]
    
    # 填充
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = torch.stack(labels)
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

3. 使用DataLoader加载数据

python 复制代码
# 创建数据集实例
dataset = CustomDataset(data, labels, transform=transform)

# 使用DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,  # 训练时打乱数据
    num_workers=4,  # 多进程加载
    pin_memory=True  # 加速GPU传输
)

# 遍历数据
for batch_idx, (data, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx}:")
    print(f"  Data shape: {data.shape}")
    print(f"  Labels shape: {labels.shape}")
    # 训练代码...

4. 高级Dataset类

示例3:带数据增强的数据集

python 复制代码
class AugmentedDataset(Dataset):
    def __init__(self, dataset, augmentations):
        self.dataset = dataset
        self.augmentations = augmentations
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        
        # 应用数据增强
        if self.augmentations:
            image = self.augmentations(image)
        
        return image, label

示例4:多模态数据集

python 复制代码
class MultiModalDataset(Dataset):
    def __init__(self, images, texts, labels, image_transform, tokenizer):
        self.images = images
        self.texts = texts
        self.labels = labels
        self.image_transform = image_transform
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert('RGB')
        text = self.texts[idx]
        label = self.labels[idx]
        
        # 图像处理
        if self.image_transform:
            image = self.image_transform(image)
        
        # 文本处理
        text_encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'input_ids': text_encoding['input_ids'].squeeze(0),
            'attention_mask': text_encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label)
        }

5. 实用技巧和建议

python 复制代码
# 1. 缓存数据以加速加载
class CachedDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.cache = {}
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        if idx not in self.cache:
            self.cache[idx] = self.dataset[idx]
        return self.cache[idx]

# 2. 子集划分
from torch.utils.data import Subset

# 划分训练集和验证集
indices = list(range(len(dataset)))
split = int(0.8 * len(dataset))
train_indices = indices[:split]
val_indices = indices[split:]

train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# 3. 数据采样器
from torch.utils.data import WeightedRandomSampler

# 处理类别不平衡
class_counts = [100, 20, 300]  # 每个类别的样本数
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
sample_weights = weights[labels]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

6. 使用内置数据集

PyTorch也提供了许多内置数据集:

python 复制代码
import torchvision.datasets as datasets
from torchvision import transforms

# MNIST数据集
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

# CIFAR-10数据集
cifar_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
)

7. Dataset与Tokenizer的关系理解

Dataset和Tokenizer在NLP任务中是分工协作、相互配合的关系。让我用一个比喻来解释:

🏭 工厂流水线比喻

想象一个食品加工厂:

  • Dataset = 仓库管理员 + 物流系统

    • 知道数据在哪里
    • 如何按顺序取出原料
    • 如何批量组织运输
  • Tokenizer = 食品加工机

    • 将原料(文本)切割成标准大小
    • 添加调味料(特殊标记)
    • 包装成统一格式

🔄 工作流程详解

1. 角色分工
python 复制代码
# Dataset: 负责数据管理
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = texts      # 原始文本
        self.labels = labels    # 原始标签
        self.tokenizer = tokenizer  # 工具
        
    def __getitem__(self, idx):
        raw_text = self.texts[idx]     # 获取原始文本
        label = self.labels[idx]       # 获取原始标签
        
        # Tokenizer: 负责文本处理
        encoded = self.tokenizer(raw_text)  # 文本 → 数字ID
        
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': label
        }
2. 处理过程对比
方面 Dataset Tokenizer
职责 数据管理、索引、批处理 文本预处理、编码
输入 原始数据文件/路径 原始文本字符串
输出 组织好的样本批次 数值化/向量化表示
关注点 数据流、内存效率 语言学特征、模型兼容性

📝 实际示例分析

示例1:标准用法
python 复制代码
from transformers import AutoTokenizer
from torch.utils.data import Dataset

class NLP_dataset(Dataset):
    def __init__(self, file_path, tokenizer_name, max_length=128):
        # Dataset: 加载原始数据
        self.data = self._load_data(file_path)  # 返回 [(text1, label1), ...]
        
        # Tokenizer: 初始化工具
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length
    
    def __getitem__(self, idx):
        text, label = self.data[idx]
        
        # Tokenizer在此被Dataset调用
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # Dataset返回处理后的数据
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label)
        }
示例2:可视化处理流程
python 复制代码
# 原始数据状态
raw_data = [
    ("I love machine learning", 1),
    ("Deep learning is amazing", 1),
    ("This is boring", 0)
]

# Dataset的作用:
# 1. 存储这些数据对
# 2. 提供按索引访问的接口

# 当访问dataset[0]时:
text = "I love machine learning"
label = 1

# Tokenizer的作用:
tokens = tokenizer.tokenize(text)
# ["I", "love", "machine", "learning"]

input_ids = tokenizer.convert_tokens_to_ids(tokens)
# [1045, 2293, 4083, 4084]  # 实际BERT的ID

# 最终Dataset返回:
{
    'input_ids': [1045, 2293, 4083, 4084],
    'attention_mask': [1, 1, 1, 1],
    'labels': 1
}

🔧 不同场景下的协作模式

模式1:Dataset内部集成Tokenizer(最常见)
python 复制代码
class DatasetWithTokenizer(Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer  # Tokenizer作为参数传入
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        
        # Dataset调用Tokenizer的服务
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=128,
            padding='max_length',
            return_tensors='pt'
        )
        
        return encoded, self.labels[idx]
模式2:Dataset外部处理(预处理模式)
python 复制代码
# 先统一预处理所有数据
all_encoded = []
for text in all_texts:
    encoded = tokenizer(text)  # 提前处理
    all_encoded.append(encoded)

# Dataset只负责存储和返回
class PreprocessedDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings  # 已经tokenized的数据
        self.labels = labels
    
    def __getitem__(self, idx):
        return self.encodings[idx], self.labels[idx]
模式3:动态处理(更灵活)
python 复制代码
class DynamicDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, augment_fn=None):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.augment_fn = augment_fn  # 数据增强函数
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        
        # 动态数据增强
        if self.augment_fn:
            text = self.augment_fn(text)
        
        # 动态分词
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=512
        )
        
        return {k: torch.tensor(v) for k, v in encoded.items()}, self.labels[idx]

🎯 关键理解点

1. 数据流方向
复制代码
原始文本 → Dataset获取 → Tokenizer处理 → 模型输入
      ↓           ↓            ↓           ↓
    文件/DB     管理组织     文本转数字   张量格式
2. 生命周期管理
  • Tokenizer通常只创建一次(模型相关)
  • Dataset可能创建多个(训练集、验证集、测试集)
python 复制代码
# 一次性创建
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 多个数据集共用同一个tokenizer
train_dataset = MyDataset(train_texts, train_labels, tokenizer)
val_dataset = MyDataset(val_texts, val_labels, tokenizer)  # 同一个tokenizer
3. 依赖关系
python 复制代码
# Tokenizer不依赖Dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# 可以单独使用
result = tokenizer("Hello world!")

# Dataset依赖Tokenizer
class MyDataset(Dataset):
    def __init__(self, tokenizer):  # 依赖注入
        self.tokenizer = tokenizer

💡 最佳实践建议

1. Tokenizer选择策略
python 复制代码
# 根据模型选择对应的Tokenizer
MODEL_TO_TOKENIZER = {
    'bert': BertTokenizer,
    'roberta': RobertaTokenizer,
    'gpt2': GPT2Tokenizer,
    't5': T5Tokenizer
}

def create_dataset(model_type, data, labels):
    tokenizer_class = MODEL_TO_TOKENIZER[model_type]
    tokenizer = tokenizer_class.from_pretrained(model_type)
    return CustomDataset(data, labels, tokenizer)
2. 缓存优化
python 复制代码
class CachedTokenizationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.cache = {}  # 缓存tokenization结果
    
    def __getitem__(self, idx):
        if idx not in self.cache:
            text = self.texts[idx]
            encoded = self.tokenizer(text, padding='max_length', truncation=True)
            self.cache[idx] = (encoded, self.labels[idx])
        return self.cache[idx]
3. 错误处理
python 复制代码
class RobustDataset(Dataset):
    def __getitem__(self, idx):
        try:
            text = self.texts[idx]
            # 处理可能的文本异常
            if not isinstance(text, str) or not text.strip():
                text = "[EMPTY]"
            
            encoded = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.max_len,
                return_tensors='pt'
            )
            
            return encoded, self.labels[idx]
        except Exception as e:
            # 返回一个安全的默认值
            return self._get_default_item()

🎓 总结比喻

Dataset 就像餐厅的服务员

  • 知道菜单(数据索引)
  • 按顺序上菜(数据顺序)
  • 组织多人用餐(批处理)

Tokenizer 就像厨房的厨师

  • 处理食材(文本分词)
  • 按菜谱烹饪(添加特殊标记)
  • 摆盘(padding/truncation)

两者配合:服务员(Dataset)从顾客(数据源)那里拿到点单(原始文本),交给厨师(Tokenizer)烹饪处理,最后上菜给食客(模型)。

这种分工使得:

  1. Tokenizer专注:文本处理逻辑,与模型架构对齐
  2. Dataset专注:数据管理逻辑,与训练流程对齐
  3. 松耦合设计:可以更换不同的Tokenizer而不影响Dataset结构
相关推荐
oak隔壁找我24 分钟前
Python + Langchain + Streamlit + DashScope 实现一个网页版聊天机器人
人工智能
是Dream呀29 分钟前
昇腾实战|算子模板库Catlass与CANN生态适配
开发语言·人工智能·python·华为
tanxiaomi30 分钟前
Redisson分布式锁 和 乐观锁的使用场景
java·分布式·mysql·面试
曦云沐31 分钟前
第二篇:LangChain 1.0 模块化架构与依赖管理
人工智能·langchain·智能体
长桥夜波33 分钟前
机器学习日报23
人工智能·机器学习
roman_日积跬步-终至千里35 分钟前
【模式识别与机器学习(9)】数据预处理-第一部分:数据基础认知
人工智能·机器学习
FL162386312941 分钟前
自动驾驶场景驾驶员注意力安全行为睡驾分心驾驶疲劳驾驶检测数据集VOC+YOLO格式5370张6类别
人工智能·yolo·自动驾驶
Java中文社群43 分钟前
找到漏洞了!抓紧薅~N8N调用即梦全免费
人工智能
培根芝士1 小时前
使用llm-compressor 对 Qwen3-14B 做 AWQ + INT4 量化
人工智能·python