TIME - MoE 模型代码 3.2——Time-MoE-main/time_moe/datasets/time_moe_dataset.py

源码:GitHub - Time-MoE/Time-MoE: [ICLR 2025 Spotlight] Official implementation of "Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts"

这段代码定义了一个用于时间序列数据处理的 TimeMoEDataset 类,支持多种数据格式和归一化方法。


1. 类定义与初始化 (__init__)

python 复制代码
class TimeMoEDataset(TimeSeriesDataset):
    def __init__(self, data_folder, normalization_method=None):
        self.data_folder = data_folder
        self.normalization_method = normalization_method
        self.datasets = []  # 存储子数据集(BinaryDataset/GeneralDataset)
        self.num_tokens = None  # 总时间点数量

        # 处理归一化方法
        if normalization_method is None:
            self.normalization_method = None
        elif isinstance(normalization_method, str):
            if normalization_method.lower() == 'max':
                self.normalization_method = max_scaler  # 最大值归一化
            elif normalization_method.lower() == 'zero':
                self.normalization_method = zero_scaler  # 标准化(Z-score)
            else:
                raise ValueError(f'未知归一化方法: {normalization_method}')
        else:
            self.normalization_method = normalization_method  # 自定义归一化函数

        # 加载数据:支持二进制文件或普通文件/文件夹
        if BinaryDataset.is_valid_path(self.data_folder):
            ds = BinaryDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        elif GeneralDataset.is_valid_path(self.data_folder):
            ds = GeneralDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        else:
            # 递归遍历文件夹,加载所有有效文件
            for root, dirs, files in os.walk(self.data_folder):
                for file in files:
                    fn_path = os.path.join(root, file)
                    # 跳过二进制元数据文件,加载普通数据集
                    if file != BinaryDataset.meta_file_name and GeneralDataset.is_valid_path(fn_path):
                        ds = GeneralDataset(fn_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)
                for sub_folder in dirs:
                    folder_path = os.path.join(root, sub_folder)
                    # 检查子文件夹是否为二进制数据集
                    if BinaryDataset.is_valid_path(folder_path):
                        ds = BinaryDataset(folder_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)

        # 计算累计长度数组,用于快速定位子数据集
        self.cumsum_lengths = [0]
        for ds in self.datasets:
            self.cumsum_lengths.append(self.cumsum_lengths[-1] + len(ds))
        self.num_sequences = self.cumsum_lengths[-1]  # 总序列数
  • 数据加载
    • 支持两种数据集类型:BinaryDataset(二进制格式)和 GeneralDataset(普通文本 / CSV 等)。
    • 通过 os.walk 递归遍历文件夹,自动识别有效数据文件,避免手动指定每个文件路径。
  • 归一化处理
    • 内置两种归一化方法:max_scaler(最大值归一化)和 zero_scaler(Z-score 标准化)。
    • 支持自定义归一化函数,通过 normalization_method 参数传入。
  • 子数据集管理
    • cumsum_lengths 数组记录每个子数据集的起始索引(类似前缀和),例如 cumsum_lengths = [0, 100, 300] 表示第一个子数据集有 100 条序列,第二个有 200 条。

2. 序列索引与获取 (__getitem__)

python 复制代码
def __getitem__(self, seq_idx):
    if seq_idx >= self.cumsum_lengths[-1] or seq_idx < 0:
        raise ValueError(f'索引越界: {seq_idx}')

    # 二分查找确定子数据集索引和偏移量
    dataset_idx = binary_search(self.cumsum_lengths, seq_idx)
    dataset_offset = seq_idx - self.cumsum_lengths[dataset_idx]
    seq = self.datasets[dataset_idx][dataset_offset]  # 获取原始序列

    # 应用归一化
    if self.normalization_method is not None:
        seq = self.normalization_method(seq)
    return seq
  • 二分查找 :通过 binary_search 函数在 cumsum_lengths 中快速定位序列所属的子数据集(时间复杂度为 \(O(\log N)\))。
  • 归一化应用 :对获取的序列调用 normalization_method 函数,返回归一化后的数据(numpy 数组)。

3. 辅助方法

4. 归一化函数实现

5. 二分查找函数 (binary_search)

python 复制代码
def binary_search(sorted_list, value):
    low, high = 0, len(sorted_list) - 1
    best_index = -1
    while low <= high:
        mid = (low + high) // 2
        if sorted_list[mid] <= value:  # 寻找最大的不超过value的索引
            best_index = mid
            low = mid + 1
        else:
            high = mid - 1
    return best_index
  • 在有序数组 sorted_list(即 cumsum_lengths)中查找 value 所属的区间,返回子数据集索引。例如,若 cumsum_lengths = [0, 100, 300]value=150 会被定位到索引 1(第二个子数据集,偏移量 50)。

6. 总结

TimeMoEDataset 是一个高效、鲁棒的时间序列数据集加载器,核心功能包括:

  • 多格式数据加载:自动识别二进制和普通文件,支持递归遍历文件夹。
  • 灵活归一化:内置两种常用归一化方法,支持自定义函数,处理边界情况。
  • 高效索引:通过前缀和数组和二分查找,快速定位子数据集,适合大规模数据。

该类为后续模型训练(如 TimeMoeTrainer)提供了统一的数据接口,确保数据预处理的标准化和高效性。

相关推荐
B站计算机毕业设计之家1 小时前
智慧交通项目:Python+PySide6 车辆检测系统 YOLOv8+OpenCV 自定义视频 自定义检测区域 (源码+文档)✅
大数据·python·opencv·yolo·智慧交通·交通·车流量
ting_zh1 小时前
PyTorch、TensorFlow、JAX 简介
人工智能·pytorch·tensorflow
java1234_小锋1 小时前
TensorFlow2 Python深度学习 - 深度学习概述
python·深度学习·tensorflow·tensorflow2·python深度学习
数据与人工智能律师2 小时前
AI的法治迷宫:技术层、模型层、应用层的法律痛点
大数据·网络·人工智能·云计算·区块链
椒颜皮皮虾྅2 小时前
【DeploySharp 】基于DeploySharp 的深度学习模型部署测试平台:安装和使用流程
人工智能·深度学习·开源·c#·openvino
迈火2 小时前
PuLID_ComfyUI:ComfyUI中的图像生成强化插件
开发语言·人工智能·python·深度学习·计算机视觉·stable diffusion·语音识别
AI新兵4 小时前
AI大事记10:从对抗到创造——生成对抗网络 (GANs)
人工智能·神经网络·生成对抗网络
却道天凉_好个秋4 小时前
深度学习(十五):Dropout
人工智能·深度学习·dropout
你好~每一天4 小时前
2025 中小企业 AI 转型:核心岗技能 “怎么证、怎么用”?
人工智能·百度·数据挖掘·数据分析·职业·转行
浔川python社5 小时前
《网络爬虫技术规范与应用指南系列》(xc—5)完
爬虫·python