DataSet-深度学习中的常见类

深度学习中Dataset类通用的架构思路

Dataset 类设计的必备部分

1. 初始化 __init__

  • 配置和路径管理 :保存 config,区分 train/val/test 路径。
  • 加载原始数据:CSV、JSON、Numpy、Parquet 等。
  • 预处理器/归一化器 :如 StandardScaler,或者 Tokenizer(在 NLP 任务里)。
  • 准备辅助信息:比如 meta 特征、文本 embedding。
  • 构造样本列表 (self.samples):保证后面取样时直接 O(1) 访问。

2. 数据预处理

  • normalize / inverse_transform:数值数据标准化和反变换。
  • tokenize / pad:文本分词、对齐。
  • feature engineering:特征拼接、缺失值处理。

3. 核心接口

  • __len__: 返回数据集样本数。
  • __getitem__: 返回一个样本(通常是 (features, label) 的 tuple 或 dict)。

4. 可选接口

  • get_scaler(): 返回归一化器。
  • get_vocab(): NLP 任务里返回词表。
  • collate_fn: 定义 batch 内如何拼接(特别是变长序列)。
  • save_cache / load_cache: 大数据集可以存缓存,避免每次都重新处理。

5. 继承关系

  • BaseDataset:负责

    • 通用逻辑(加载文件、归一化、拼装 sample)。
    • 提供钩子函数,比如 load_paths(flag)process_sample(sample)
  • 子类 :只需要实现 路径差异样本加工方式差异


通用代码结构示意

python 复制代码
class BaseDataset(Dataset):
    def __init__(self, config, flag="train", scaler=None):
        self.config = config
        self.flag = flag
        self.scaler = scaler or StandardScaler()
        self.samples = []
        self._load_data()
        self._build_samples()

    def _load_data(self):
        """子类可重写,加载原始数据"""
        raise NotImplementedError

    def _build_samples(self):
        """子类可重写,拼装每个样本的x, y, feats"""
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def get_scaler(self):
        return self.scaler

    def inverse_transform(self, x):
        return x * self.std + self.mean

子类只管:

python 复制代码
class ElectricityDataset(BaseDataset):
    def _load_data(self):
        # 只写路径和文件加载逻辑
        pass

    def _build_samples(self):
        # 根据任务需要定义样本结构
        pass

调用示例

python 复制代码
data_config = {
    "root": "data/electricity/",
    "train_file": "train.json",
    "train_meta_file": "train_meta.npy",
    "train_news_file": "train_news.npy"
}

train_config = {
    "batch_size": 64,
    "learning_rate": 1e-3,
    "epochs": 20
}

train_ds = ElectricityDataset(data_config, flag="train")

train_loader = DataLoader(
    train_ds,
    batch_size=train_config["batch_size"],
    shuffle=True,
    collate_fn=custom_collate_fn
)

)
相关推荐
墨染天姬13 小时前
【AI】端侧AIBOX可以部署哪些智能体
人工智能
AI成长日志13 小时前
【Agentic RL】1.1 什么是Agentic RL:从传统RL到智能体学习
人工智能·学习·算法
2501_9481142413 小时前
2026年大模型API聚合平台技术评测:企业级接入层的治理演进与星链4SAPI架构观察
大数据·人工智能·gpt·架构·claude
小小工匠13 小时前
LLM - awesome-design-md 从 DESIGN.md 到“可对话的设计系统”:用纯文本驱动 AI 生成一致 UI 的新范式
人工智能·ui
黎阳之光13 小时前
黎阳之光:视频孪生领跑者,铸就中国数字科技全球竞争力
大数据·人工智能·算法·安全·数字孪生
小超同学你好13 小时前
面向 LLM 的程序设计 6:Tool Calling 的完整生命周期——从定义、决策、执行到观测回注
人工智能·语言模型
智星云算力14 小时前
本地GPU与租用GPU混合部署:混合算力架构搭建指南
人工智能·架构·gpu算力·智星云·gpu租用
jinanwuhuaguo14 小时前
截止到4月8日,OpenClaw 2026年4月更新深度解读剖析:从“能力回归”到“信任内建”的范式跃迁
android·开发语言·人工智能·深度学习·kotlin
xiaozhazha_14 小时前
效率提升80%:2026年AI CRM与ERP深度集成的架构设计与实现
人工智能
枫叶林FYL14 小时前
【自然语言处理 NLP】7.2.2 安全性评估与Constitutional AI
人工智能·自然语言处理