TIME - MoE 模型代码 3.2——Time-MoE-main/time_moe/datasets/time_moe_dataset.py

源码:GitHub - Time-MoE/Time-MoE: [ICLR 2025 Spotlight] Official implementation of "Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts"

这段代码定义了一个用于时间序列数据处理的 TimeMoEDataset 类,支持多种数据格式和归一化方法。


1. 类定义与初始化 (__init__)

python 复制代码
class TimeMoEDataset(TimeSeriesDataset):
    def __init__(self, data_folder, normalization_method=None):
        self.data_folder = data_folder
        self.normalization_method = normalization_method
        self.datasets = []  # 存储子数据集(BinaryDataset/GeneralDataset)
        self.num_tokens = None  # 总时间点数量

        # 处理归一化方法
        if normalization_method is None:
            self.normalization_method = None
        elif isinstance(normalization_method, str):
            if normalization_method.lower() == 'max':
                self.normalization_method = max_scaler  # 最大值归一化
            elif normalization_method.lower() == 'zero':
                self.normalization_method = zero_scaler  # 标准化(Z-score)
            else:
                raise ValueError(f'未知归一化方法: {normalization_method}')
        else:
            self.normalization_method = normalization_method  # 自定义归一化函数

        # 加载数据:支持二进制文件或普通文件/文件夹
        if BinaryDataset.is_valid_path(self.data_folder):
            ds = BinaryDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        elif GeneralDataset.is_valid_path(self.data_folder):
            ds = GeneralDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        else:
            # 递归遍历文件夹,加载所有有效文件
            for root, dirs, files in os.walk(self.data_folder):
                for file in files:
                    fn_path = os.path.join(root, file)
                    # 跳过二进制元数据文件,加载普通数据集
                    if file != BinaryDataset.meta_file_name and GeneralDataset.is_valid_path(fn_path):
                        ds = GeneralDataset(fn_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)
                for sub_folder in dirs:
                    folder_path = os.path.join(root, sub_folder)
                    # 检查子文件夹是否为二进制数据集
                    if BinaryDataset.is_valid_path(folder_path):
                        ds = BinaryDataset(folder_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)

        # 计算累计长度数组,用于快速定位子数据集
        self.cumsum_lengths = [0]
        for ds in self.datasets:
            self.cumsum_lengths.append(self.cumsum_lengths[-1] + len(ds))
        self.num_sequences = self.cumsum_lengths[-1]  # 总序列数
  • 数据加载
    • 支持两种数据集类型:BinaryDataset(二进制格式)和 GeneralDataset(普通文本 / CSV 等)。
    • 通过 os.walk 递归遍历文件夹,自动识别有效数据文件,避免手动指定每个文件路径。
  • 归一化处理
    • 内置两种归一化方法:max_scaler(最大值归一化)和 zero_scaler(Z-score 标准化)。
    • 支持自定义归一化函数,通过 normalization_method 参数传入。
  • 子数据集管理
    • cumsum_lengths 数组记录每个子数据集的起始索引(类似前缀和),例如 cumsum_lengths = [0, 100, 300] 表示第一个子数据集有 100 条序列,第二个有 200 条。

2. 序列索引与获取 (__getitem__)

python 复制代码
def __getitem__(self, seq_idx):
    if seq_idx >= self.cumsum_lengths[-1] or seq_idx < 0:
        raise ValueError(f'索引越界: {seq_idx}')

    # 二分查找确定子数据集索引和偏移量
    dataset_idx = binary_search(self.cumsum_lengths, seq_idx)
    dataset_offset = seq_idx - self.cumsum_lengths[dataset_idx]
    seq = self.datasets[dataset_idx][dataset_offset]  # 获取原始序列

    # 应用归一化
    if self.normalization_method is not None:
        seq = self.normalization_method(seq)
    return seq
  • 二分查找 :通过 binary_search 函数在 cumsum_lengths 中快速定位序列所属的子数据集(时间复杂度为 \(O(\log N)\))。
  • 归一化应用 :对获取的序列调用 normalization_method 函数,返回归一化后的数据(numpy 数组)。

3. 辅助方法

4. 归一化函数实现

5. 二分查找函数 (binary_search)

python 复制代码
def binary_search(sorted_list, value):
    low, high = 0, len(sorted_list) - 1
    best_index = -1
    while low <= high:
        mid = (low + high) // 2
        if sorted_list[mid] <= value:  # 寻找最大的不超过value的索引
            best_index = mid
            low = mid + 1
        else:
            high = mid - 1
    return best_index
  • 在有序数组 sorted_list(即 cumsum_lengths)中查找 value 所属的区间,返回子数据集索引。例如,若 cumsum_lengths = [0, 100, 300]value=150 会被定位到索引 1(第二个子数据集,偏移量 50)。

6. 总结

TimeMoEDataset 是一个高效、鲁棒的时间序列数据集加载器,核心功能包括:

  • 多格式数据加载:自动识别二进制和普通文件,支持递归遍历文件夹。
  • 灵活归一化:内置两种常用归一化方法,支持自定义函数,处理边界情况。
  • 高效索引:通过前缀和数组和二分查找,快速定位子数据集,适合大规模数据。

该类为后续模型训练(如 TimeMoeTrainer)提供了统一的数据接口,确保数据预处理的标准化和高效性。

相关推荐
大龄程序员狗哥15 小时前
第47篇:使用Speech-to-Text API快速构建语音应用(操作教程)
人工智能
KKKlucifer15 小时前
数据安全合规自动化:策略落地、审计追溯与风险闭环技术解析
人工智能·安全
m0_7485548115 小时前
golang如何实现用户订阅偏好管理_golang用户订阅偏好管理实现总结
jvm·数据库·python
RWKV元始智能15 小时前
RWKV超并发项目教程,RWKV-LM训练提速40%
人工智能·rnn·深度学习·自然语言处理·开源
dyj09516 小时前
Dify - (一)、本地部署Dify+聊天助手/Agent
人工智能·docker·容器
墨染天姬16 小时前
【AI】Hermes的GEPA算法
人工智能·算法
小超同学你好16 小时前
OpenClaw 深度解析系列 · 第8篇:Learning & Adaptation(学习与自适应)
人工智能·语言模型·chatgpt
紫微AI16 小时前
前端文本测量成了卡死一切创新的最后瓶颈,pretext实现突破了
前端·人工智能·typescript
码途漫谈16 小时前
Easy-Vibe开发篇阅读笔记(四)——前端开发之结合 Agent Skills 美化界面
人工智能·笔记·ai·开源·ai编程
smj2302_7968265216 小时前
解决leetcode第3911题.移除子数组元素后第k小偶数
数据结构·python·算法·leetcode