TIME - MoE 模型代码 3.2——Time-MoE-main/time_moe/datasets/time_moe_dataset.py

源码:GitHub - Time-MoE/Time-MoE: [ICLR 2025 Spotlight] Official implementation of "Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts"

这段代码定义了一个用于时间序列数据处理的 TimeMoEDataset 类,支持多种数据格式和归一化方法。


1. 类定义与初始化 (__init__)

python 复制代码
class TimeMoEDataset(TimeSeriesDataset):
    def __init__(self, data_folder, normalization_method=None):
        self.data_folder = data_folder
        self.normalization_method = normalization_method
        self.datasets = []  # 存储子数据集(BinaryDataset/GeneralDataset)
        self.num_tokens = None  # 总时间点数量

        # 处理归一化方法
        if normalization_method is None:
            self.normalization_method = None
        elif isinstance(normalization_method, str):
            if normalization_method.lower() == 'max':
                self.normalization_method = max_scaler  # 最大值归一化
            elif normalization_method.lower() == 'zero':
                self.normalization_method = zero_scaler  # 标准化(Z-score)
            else:
                raise ValueError(f'未知归一化方法: {normalization_method}')
        else:
            self.normalization_method = normalization_method  # 自定义归一化函数

        # 加载数据:支持二进制文件或普通文件/文件夹
        if BinaryDataset.is_valid_path(self.data_folder):
            ds = BinaryDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        elif GeneralDataset.is_valid_path(self.data_folder):
            ds = GeneralDataset(self.data_folder)
            if len(ds) > 0:
                self.datasets.append(ds)
        else:
            # 递归遍历文件夹,加载所有有效文件
            for root, dirs, files in os.walk(self.data_folder):
                for file in files:
                    fn_path = os.path.join(root, file)
                    # 跳过二进制元数据文件,加载普通数据集
                    if file != BinaryDataset.meta_file_name and GeneralDataset.is_valid_path(fn_path):
                        ds = GeneralDataset(fn_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)
                for sub_folder in dirs:
                    folder_path = os.path.join(root, sub_folder)
                    # 检查子文件夹是否为二进制数据集
                    if BinaryDataset.is_valid_path(folder_path):
                        ds = BinaryDataset(folder_path)
                        if len(ds) > 0:
                            self.datasets.append(ds)

        # 计算累计长度数组,用于快速定位子数据集
        self.cumsum_lengths = [0]
        for ds in self.datasets:
            self.cumsum_lengths.append(self.cumsum_lengths[-1] + len(ds))
        self.num_sequences = self.cumsum_lengths[-1]  # 总序列数
  • 数据加载
    • 支持两种数据集类型:BinaryDataset(二进制格式)和 GeneralDataset(普通文本 / CSV 等)。
    • 通过 os.walk 递归遍历文件夹,自动识别有效数据文件,避免手动指定每个文件路径。
  • 归一化处理
    • 内置两种归一化方法:max_scaler(最大值归一化)和 zero_scaler(Z-score 标准化)。
    • 支持自定义归一化函数,通过 normalization_method 参数传入。
  • 子数据集管理
    • cumsum_lengths 数组记录每个子数据集的起始索引(类似前缀和),例如 cumsum_lengths = [0, 100, 300] 表示第一个子数据集有 100 条序列,第二个有 200 条。

2. 序列索引与获取 (__getitem__)

python 复制代码
def __getitem__(self, seq_idx):
    if seq_idx >= self.cumsum_lengths[-1] or seq_idx < 0:
        raise ValueError(f'索引越界: {seq_idx}')

    # 二分查找确定子数据集索引和偏移量
    dataset_idx = binary_search(self.cumsum_lengths, seq_idx)
    dataset_offset = seq_idx - self.cumsum_lengths[dataset_idx]
    seq = self.datasets[dataset_idx][dataset_offset]  # 获取原始序列

    # 应用归一化
    if self.normalization_method is not None:
        seq = self.normalization_method(seq)
    return seq
  • 二分查找 :通过 binary_search 函数在 cumsum_lengths 中快速定位序列所属的子数据集(时间复杂度为 \(O(\log N)\))。
  • 归一化应用 :对获取的序列调用 normalization_method 函数,返回归一化后的数据(numpy 数组)。

3. 辅助方法

4. 归一化函数实现

5. 二分查找函数 (binary_search)

python 复制代码
def binary_search(sorted_list, value):
    low, high = 0, len(sorted_list) - 1
    best_index = -1
    while low <= high:
        mid = (low + high) // 2
        if sorted_list[mid] <= value:  # 寻找最大的不超过value的索引
            best_index = mid
            low = mid + 1
        else:
            high = mid - 1
    return best_index
  • 在有序数组 sorted_list(即 cumsum_lengths)中查找 value 所属的区间,返回子数据集索引。例如,若 cumsum_lengths = [0, 100, 300]value=150 会被定位到索引 1(第二个子数据集,偏移量 50)。

6. 总结

TimeMoEDataset 是一个高效、鲁棒的时间序列数据集加载器,核心功能包括:

  • 多格式数据加载:自动识别二进制和普通文件,支持递归遍历文件夹。
  • 灵活归一化:内置两种常用归一化方法,支持自定义函数,处理边界情况。
  • 高效索引:通过前缀和数组和二分查找,快速定位子数据集,适合大规模数据。

该类为后续模型训练(如 TimeMoeTrainer)提供了统一的数据接口,确保数据预处理的标准化和高效性。

相关推荐
中杯可乐多加冰16 分钟前
【AI落地应用实战】AIGC赋能职场PPT汇报:从效率工具到辅助优化
人工智能·深度学习·神经网络·aigc·powerpoint·ai赋能
东临碣石8226 分钟前
【AI论文】BlenderFusion:基于三维场景的视觉编辑与生成式合成
人工智能
正在走向自律27 分钟前
第二章-AIGC入门-开启AIGC音频探索之旅:从入门到实践(6/36)
人工智能·aigc·音视频·语音识别·ai音乐·ai 音频·智能语音助手
Trent198533 分钟前
影楼精修-智能修图Agent
图像处理·人工智能·计算机视觉·aigc
烟锁池塘柳034 分钟前
【大模型】解码策略:Greedy Search、Beam Search、Top-k/Top-p、Temperature Sampling等
人工智能·深度学习·机器学习
风逸hhh1 小时前
python打卡day58@浙大疏锦行
开发语言·python
盼小辉丶1 小时前
PyTorch实战(14)——条件生成对抗网络(conditional GAN,cGAN)
人工智能·pytorch·生成对抗网络
Allen_LVyingbo2 小时前
数智读书笔记系列035《未来医疗:医疗4.0引领第四次医疗产业变革》
人工智能·经验分享·笔记·健康医疗
zzc9212 小时前
时频图数据集更正程序,去除坐标轴白边及调整对应的标签值
人工智能·深度学习·数据集·标签·时频图·更正·白边
isNotNullX2 小时前
什么是数据分析?常见方法全解析
大数据·数据库·数据仓库·人工智能·数据分析