TIME - MoE 模型代码 3.2——Time-MoE-main/time_moe/datasets/time_moe_dataset.py

源码:GitHub - Time-MoE/Time-MoE: [ICLR 2025 Spotlight] Official implementation of "Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts"

这段代码定义了一个用于时间序列数据处理的 TimeMoEDataset 类,支持多种数据格式和归一化方法。


1. 类定义与初始化 (__init__)

python 复制代码
class TimeMoEDataset(TimeSeriesDataset):
    def __init__(self, data_folder, normalization_method=None):
        self.data_folder = data_folder
        self.normalization_method = normalization_method
        self.datasets = []  # 存储子数据集(BinaryDataset/GeneralDataset)
        self.num_tokens = None  # 总时间点数量

        # 处理归一化方法
        if normalization_method is None:
            self.normalization_method = None
        elif isinstance(normalization_method, str):
            if normalization_method.lower() == 'max':
                self.normalization_method = max_scaler  # 最大值归一化
            elif normalization_method.lower() == 'zero':
                self.normalization_method = zero_scaler  # 标准化(Z-score)
            else:
                raise ValueError(f'未知归一化方法: {normalization_method}')
        else:
            self.normalization_method = normalization_method  # 自定义归一化函数

        # 加载数据:支持二进制文件或普通文件/文件夹
        if BinaryDataset.is_valid_path(self.data_folder):
            ds = BinaryDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        elif GeneralDataset.is_valid_path(self.data_folder):
            ds = GeneralDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        else:
            # 递归遍历文件夹,加载所有有效文件
            for root, dirs, files in os.walk(self.data_folder):
                for file in files:
                    fn_path = os.path.join(root, file)
                    # 跳过二进制元数据文件,加载普通数据集
                    if file != BinaryDataset.meta_file_name and GeneralDataset.is_valid_path(fn_path):
                        ds = GeneralDataset(fn_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)
                for sub_folder in dirs:
                    folder_path = os.path.join(root, sub_folder)
                    # 检查子文件夹是否为二进制数据集
                    if BinaryDataset.is_valid_path(folder_path):
                        ds = BinaryDataset(folder_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)

        # 计算累计长度数组,用于快速定位子数据集
        self.cumsum_lengths = [0]
        for ds in self.datasets:
            self.cumsum_lengths.append(self.cumsum_lengths[-1] + len(ds))
        self.num_sequences = self.cumsum_lengths[-1]  # 总序列数
  • 数据加载
    • 支持两种数据集类型:BinaryDataset(二进制格式)和 GeneralDataset(普通文本 / CSV 等)。
    • 通过 os.walk 递归遍历文件夹,自动识别有效数据文件,避免手动指定每个文件路径。
  • 归一化处理
    • 内置两种归一化方法:max_scaler(最大值归一化)和 zero_scaler(Z-score 标准化)。
    • 支持自定义归一化函数,通过 normalization_method 参数传入。
  • 子数据集管理
    • cumsum_lengths 数组记录每个子数据集的起始索引(类似前缀和),例如 cumsum_lengths = [0, 100, 300] 表示第一个子数据集有 100 条序列,第二个有 200 条。

2. 序列索引与获取 (__getitem__)

python 复制代码
def __getitem__(self, seq_idx):
    if seq_idx >= self.cumsum_lengths[-1] or seq_idx < 0:
        raise ValueError(f'索引越界: {seq_idx}')

    # 二分查找确定子数据集索引和偏移量
    dataset_idx = binary_search(self.cumsum_lengths, seq_idx)
    dataset_offset = seq_idx - self.cumsum_lengths[dataset_idx]
    seq = self.datasets[dataset_idx][dataset_offset]  # 获取原始序列

    # 应用归一化
    if self.normalization_method is not None:
        seq = self.normalization_method(seq)
    return seq
  • 二分查找 :通过 binary_search 函数在 cumsum_lengths 中快速定位序列所属的子数据集(时间复杂度为 \(O(\log N)\))。
  • 归一化应用 :对获取的序列调用 normalization_method 函数,返回归一化后的数据(numpy 数组)。

3. 辅助方法

4. 归一化函数实现

5. 二分查找函数 (binary_search)

python 复制代码
def binary_search(sorted_list, value):
    low, high = 0, len(sorted_list) - 1
    best_index = -1
    while low <= high:
        mid = (low + high) // 2
        if sorted_list[mid] <= value:  # 寻找最大的不超过value的索引
            best_index = mid
            low = mid + 1
        else:
            high = mid - 1
    return best_index
  • 在有序数组 sorted_list(即 cumsum_lengths)中查找 value 所属的区间,返回子数据集索引。例如,若 cumsum_lengths = [0, 100, 300]value=150 会被定位到索引 1(第二个子数据集,偏移量 50)。

6. 总结

TimeMoEDataset 是一个高效、鲁棒的时间序列数据集加载器,核心功能包括:

  • 多格式数据加载:自动识别二进制和普通文件,支持递归遍历文件夹。
  • 灵活归一化:内置两种常用归一化方法,支持自定义函数,处理边界情况。
  • 高效索引:通过前缀和数组和二分查找,快速定位子数据集,适合大规模数据。

该类为后续模型训练(如 TimeMoeTrainer)提供了统一的数据接口,确保数据预处理的标准化和高效性。

相关推荐
go546315846513 分钟前
基于深度学习的食管癌右喉返神经旁淋巴结预测系统研究
图像处理·人工智能·深度学习·神经网络·算法
Blossom.11815 分钟前
基于深度学习的图像分类:使用Capsule Networks实现高效分类
人工智能·python·深度学习·神经网络·机器学习·分类·数据挖掘
CodeCraft Studio23 分钟前
借助Aspose.HTML控件,在 Python 中将 HTML 转换为 Markdown
开发语言·python·html·markdown·aspose·html转markdown·asposel.html
悠哉悠哉愿意42 分钟前
【电赛学习笔记】MaxiCAM 项目实践——与单片机的串口通信
笔记·python·单片机·嵌入式硬件·学习·视觉检测
封奚泽优1 小时前
使用Python实现单词记忆软件
开发语言·python·random·qpushbutton·qtwidgets·qtcore·qtgui
想变成树袋熊1 小时前
【自用】NLP算法面经(6)
人工智能·算法·自然语言处理
Goona_1 小时前
拒绝SQL恐惧:用Python+pyqt打造任意Excel数据库查询系统
数据库·python·sql·excel·pyqt
格林威1 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现沙滩小人检测识别(C#代码UI界面版)
人工智能·深度学习·数码相机·yolo·计算机视觉
checkcheckck2 小时前
spring ai 适配 流式回答、mcp、milvus向量数据库、rag、聊天会话记忆
人工智能
Microvision维视智造2 小时前
从“人工眼”到‘智能眼’:EZ-Vision视觉系统如何重构生产线视觉检测精度?
图像处理·人工智能·重构·视觉检测