读代码3:OLMo3全详解 - layer2--Data (上)

arxiv:《2 OLMo 2 Furious》、《Olmo 3》

github:GitHub - allenai/OLMo-core: PyTorch building blocks for the OLMo ecosystem

读代码0:OLMo3全详解 - 从OLMo 3 Tech Report开始

读代码1:OLMo3全详解 - 安装、准备与项目架构

读代码2:OLMo3全详解 - layer1--Foundation

读代码3:OLMo3全详解 - layer2--Data (下)

0.OLMo3架构树

架构树中对文件的重要性采用如下标记:# 表示当前主线必读;x 表示基础定义且复杂度较低,可快速略读;- 表示辅助模块、扩展实现或后续阅读部分;A* 表示按需进入的专题分支。

因文档篇幅限制 Layer 2分为了上下两篇,源代码中用到的工具类函数,逻辑没那么显然的我都放在了下篇的附录读代码3:OLMo3全详解 - layer2--Data (下)中,可以按需查阅。

OLMo-core 的 data/ 目录内部实际上并存两种组织方式:一条是当前 OLMo3 预训练主要依赖的官方数据主线(mixes/__init__.py → source_mixture.py → numpy_dataset.py → data_loader.py → collator.py),另一条是位于 data/composable/ 下的可组合式数据框架。前者更接近"官方默认 pipeline",后者则更像为灵活实验和自定义数据流准备的并行抽象层。

olmo_core/
├─ data/                      [D2]  # 数据层:官方预训练数据主线 + composable 可组合数据框架
│  ├─ __init__.py             [D2]
│  │
│  ├─ types.py                [D2]  x 数据层基础类型定义(dtype、long-doc strategy 等)
│  ├─ tokenizer.py            [D2]  x tokenizer 配置抽象(vocab / special tokens / identifier)
│  │
│  ├─ mixes/                  [D2]  # 预定义 mix 配方资源:mix 名称 → (label, relative path template)
│  │  ├─ __init__.py          [D2]  # DataMix / DataMixBase;负责 mix 名称 → paths + labels
│  │  ├─ dolma17.txt
│  │  ├─ OLMo-longmino-mix-0625.txt
│  │  ├─ OLMo-longmino-mix-0925.txt
│  │  ├─ OLMo-midtraining-mix-0625-100B.txt
│  │  ├─ OLMo-midtraining-mix-0925-ingredient1-100B.txt
│  │  ├─ OLMo-midtraining-mix-0925-ingredient2-100B.txt
│  │  ├─ OLMo-mix-0625.txt
│  │  ├─ OLMo-mix-0625-150Bsample.txt
│  │  ├─ OLMo-mix-0625-700Bsample.txt
│  │  ├─ OLMo-mix-0625-official.txt
│  │  ├─ OLMo-mix-0925.txt
│  │  ├─ OLMo-mix-0925-official.txt
│  │  ├─ OLMoE-mix-0824.txt
│  │  └─ v3-small-ppl-validation.txt
│  │
│  ├─ source_mixtures/        [D2]  # 结构化 source mixture 配置(YAML;source 配比/过滤规则)
│  │  └─ OLMo3-32B-midtraining-modelnamefilter.yaml
│  │
│  ├─ source_mixture.py       [D2]  # 官方主线:paths + labels → source/path-level token budget
│  ├─ numpy_dataset.py        [D2]  # 官方主线:token arrays / token budget → instance
│  ├─ data_loader.py          [D2]  # 官方主线:dataset/instance → dataloader iteration
│  ├─ collator.py             [D2]  # 官方主线:sample list → batch(padding、mask、字段组织)
│  ├─ utils.py                [D2]  - 数据工具函数(segment / pack / bucket / slice 等)
│  │
│  └─ composable/             [D2]  - 并行存在的可组合数据管线框架(不是当前官方预训练主线的必经部分)
│     ├─ __init__.py
│     ├─ source_abc.py
│     ├─ instance_source.py
│     ├─ token_source.py
│     ├─ numpy_document_source.py
│     ├─ random_instance_source.py
│     ├─ sampling_document_source.py
│     ├─ sampling_instance_source.py
│     ├─ sampling_token_source.py
│     ├─ sliced_instance_source.py
│     ├─ sliced_token_source.py
│     ├─ packing_instance_source.py
│     ├─ concat_and_chunk_instance_source.py
│     ├─ mixing_document_source.py
│     ├─ mixing_instance_source.py
│     ├─ mixing_token_source.py
│     ├─ data_loader.py
│     ├─ utils.py
│     └─ visualize.py

olmo_core/  # OLMo-core:基础设施 → 数据 → 模型(nn) → 分布式(distributed) → 训练(train) → 生成(generate) → 评测(eval)
├─ __init__.py                [F1]  x 包入口/对外导出,此处为空
├─ aliases.py                 [F1]  x 类型别名;目前主要是 PathOrStr = Union[Path, PathLike, str]
├─ config.py                  [F1]  # 配置系统基类 Config;支持 dataclass ↔ dict/YAML/JSON;并定义 DType / StrEnum
├─ doc_utils.py               [F1]  x beta_feature 等文档/标记工具(很短,扫一眼)
├─ exceptions.py              [F1]  x 统一异常定义(很短,扫一眼)
├─ fs_cache.py                [F1]  # 文件系统缓存装饰器
├─ io.py                      [F1]  # 统一 I/O 抽象层(本地/S3/R2/Weka 等)
├─ py.typed                   [F1]  x 类型标记文件
├─ script_utils.py            [F1]  - CLI/脚本辅助;按需再读
├─ utils.py                   [F1]  - 通用工具;按需再读
├─ version.py                 [F1]  x 版本号
│
├─ data/                      [D2]  # 数据层:官方预训练数据主线 + composable 可组合数据框架
│  ├─ __init__.py             [D2]
│  │
│  ├─ types.py                [D2]  x 数据层基础类型定义(dtype、long-doc strategy 等)
│  ├─ tokenizer.py            [D2]  x tokenizer 配置抽象(vocab / special tokens / identifier)
│  │
│  ├─ mixes/                  [D2]  # 预定义 mix 配方资源:mix 名称 → (label, relative path template)
│  │  ├─ __init__.py          [D2]  # DataMix / DataMixBase;负责 mix 名称 → paths + labels
│  │  ├─ dolma17.txt
│  │  ├─ OLMo-longmino-mix-0625.txt
│  │  ├─ OLMo-longmino-mix-0925.txt
│  │  ├─ OLMo-midtraining-mix-0625-100B.txt
│  │  ├─ OLMo-midtraining-mix-0925-ingredient1-100B.txt
│  │  ├─ OLMo-midtraining-mix-0925-ingredient2-100B.txt
│  │  ├─ OLMo-mix-0625.txt
│  │  ├─ OLMo-mix-0625-150Bsample.txt
│  │  ├─ OLMo-mix-0625-700Bsample.txt
│  │  ├─ OLMo-mix-0625-official.txt
│  │  ├─ OLMo-mix-0925.txt
│  │  ├─ OLMo-mix-0925-official.txt
│  │  ├─ OLMoE-mix-0824.txt
│  │  └─ v3-small-ppl-validation.txt
│  │
│  ├─ source_mixtures/        [D2]  # 结构化 source mixture 配置(YAML;source 配比/过滤规则)
│  │  └─ OLMo3-32B-midtraining-modelnamefilter.yaml
│  │
│  ├─ source_mixture.py       [D2]  # 官方主线:paths + labels → source/path-level token budget
│  ├─ numpy_dataset.py        [D2]  # 官方主线:token arrays / token budget → instance
│  ├─ data_loader.py          [D2]  # 官方主线:dataset/instance → dataloader iteration
│  ├─ collator.py             [D2]  # 官方主线:sample list → batch(padding、mask、字段组织)
│  ├─ utils.py                [D2]  - 数据工具函数(segment / pack / bucket / slice 等)
│  │
│  └─ composable/             [D2]  - 并行存在的可组合数据管线框架(不是当前官方预训练主线的必经部分)
│     ├─ __init__.py
│     ├─ source_abc.py
│     ├─ instance_source.py
│     ├─ token_source.py
│     ├─ numpy_document_source.py
│     ├─ random_instance_source.py
│     ├─ sampling_document_source.py
│     ├─ sampling_instance_source.py
│     ├─ sampling_token_source.py
│     ├─ sliced_instance_source.py
│     ├─ sliced_token_source.py
│     ├─ packing_instance_source.py
│     ├─ concat_and_chunk_instance_source.py
│     ├─ mixing_document_source.py
│     ├─ mixing_instance_source.py
│     ├─ mixing_token_source.py
│     ├─ data_loader.py
│     ├─ utils.py
│     └─ visualize.py
│
├─ nn/                        [M3]  # 模型层:神经网络组件、attention、transformer、MoE、HF 兼容等
│  ├─ __init__.py             [M3]
│  ├─ config.py               [M3]  # nn 配置聚合层
│  ├─ buffer_cache.py         [M3]  # buffer/KV/临时张量缓存工具
│  ├─ feed_forward.py         [M3]  # dense FFN / MLP
│  ├─ layer_norm.py           [M3]  # LayerNorm / RMSNorm
│  ├─ lm_head.py              [M3]  # LM head / 输出投影
│  ├─ residual_stream.py      [M3]  # residual stream 抽象/工具
│  ├─ rope.py                 [M3]  # RoPE
│  ├─ utils.py                [M3]  # nn 工具
│  ├─ convolution.py          [A*]  # 卷积模块(特定架构/实验)
│  ├─ cross_entropy_loss.py   [A*]  # CE loss 模块封装
│  ├─ attention/              [M3]  # 注意力子系统:后端、KV-cache、实现适配
│  │  ├─ __init__.py          [M3]
│  │  ├─ backend.py           [M3]  # attention 后端选择/统一接口
│  │  ├─ base.py              [M3]  # attention 基类/通用逻辑
│  │  ├─ flash_attn_api.py    [M3]  # flash-attn 适配
│  │  ├─ flash_linear_attn_api.py [M3]  # flash-linear-attn 适配
│  │  ├─ te_attn_api.py       [M3]  # transformer-engine attention 适配
│  │  ├─ kv_cache.py          [M3]  # KV-cache 数据结构与更新逻辑
│  │  ├─ recurrent.py         [A*]  # recurrent/streaming attention
│  │  └─ ring.py              [A*]  # ring attention / ring-based communication
│  ├─ moe/                    [M3]  # MoE 子系统
│  │  ├─ __init__.py          [M3]
│  │  ├─ loss.py              [M3]
│  │  ├─ mlp.py               [M3]
│  │  ├─ moe.py               [M3]
│  │  ├─ parallel_mlp.py      [M3]
│  │  └─ router.py            [M3]
│  ├─ transformer/            [M3]  # Transformer 主体
│  │  ├─ __init__.py          [M3]
│  │  ├─ config.py            [M3]
│  │  ├─ block.py             [M3]  # Transformer block(attention+MLP+residual)
│  │  ├─ init.py              [M3]
│  │  └─ model.py             [M3]  # Transformer 总装:embedding、blocks、lm_head、forward
│  ├─ conversion/             [A*]  # state dict / checkpoint 格式转换
│  │  ├─ __init__.py          [A*]
│  │  ├─ state_converter.py   [A*]
│  │  └─ state_mapping.py     [A*]
│  ├─ functional/             [A*]  # 函数式实现集合
│  │  ├─ __init__.py          [A*]
│  │  └─ cross_entropy_loss.py[A*]
│  └─ hf/                     [A*]  # HuggingFace 兼容层
│     ├─ __init__.py          [A*]
│     ├─ checkpoint.py        [A*]
│     ├─ config.py            [A*]
│     └─ convert.py           [A*]
│
├─ kernels/                   [M3]  # 自定义/加速 kernels
│  ├─ __init__.py             [M3]
│  └─ moe.py                  [M3]  # MoE kernel/接口
│
├─ ops/                       [M3]  # 自定义 op 封装(高于 kernels 一层)
│  ├─ __init__.py             [M3]
│  └─ moe.py                  [M3]
│
├─ distributed/               [X4]  # 分布式基础设施:rank/world、并行策略、checkpoint 后端
│  ├─ __init__.py             [X4]
│  ├─ utils.py                [X4]  # rank/world/group/barrier/collective helper
│  ├─ nn.py                   [X4]  # 分布式场景下的 nn 辅助封装
│  ├─ checkpoint/
│  │  ├─ __init__.py          [X4]
│  │  └─ filesystem.py        [X4]  # checkpoint 文件系统后端/适配
│  └─ parallel/
│     ├─ __init__.py          [X4]
│     ├─ data_parallel.py     [X4]  # DP 封装
│     ├─ tensor_parallel.py   [X4]  # TP 封装
│     ├─ pipeline_parallel.py [X4]  # PP 封装
│     ├─ expert_parallel.py   [X4]  # EP / MoE 并行
│     └─ context_parallel.py  [X4]  # CP / 长上下文并行
│
├─ train/                     [T5]  # 训练系统:Trainer、callbacks、train_module
│  ├─ __init__.py             [T5]
│  ├─ common.py               [T5]  # train 公共类型/辅助
│  ├─ config.py               [T5]  # train 总配置
│  ├─ checkpoint.py           [T5]  # train 侧 checkpoint 逻辑
│  ├─ trainer.py              [T5]  # Trainer 主循环(fit/step/callback 调度/分布式)
│  ├─ utils.py                [T5]
│  ├─ callbacks/
│  │  ├─ __init__.py          [T5]
│  │  ├─ callback.py          [T5]  # callback 基类
│  │  ├─ batch_size_scheduler.py      [T5]
│  │  ├─ beaker.py                      [T5]
│  │  ├─ checkpointer.py                [T5]
│  │  ├─ comet.py                       [T5]
│  │  ├─ config_saver.py                [T5]
│  │  ├─ console_logger.py              [T5]
│  │  ├─ evaluator_callback.py          [T5]
│  │  ├─ gap_monitor.py                 [T5]
│  │  ├─ garbage_collector.py           [T5]
│  │  ├─ gpu_memory_monitor.py          [T5]
│  │  ├─ list_checkpointer.py           [T5]
│  │  ├─ metric_saver.py                [T5]
│  │  ├─ monkey_patcher.py              [T5]
│  │  ├─ profiler.py                    [T5]
│  │  ├─ sequence_length_scheduler.py   [T5]
│  │  ├─ slack_notifier.py              [T5]
│  │  ├─ speed_monitor.py               [T5]
│  │  └─ wandb.py                       [T5]
│  └─ train_module/
│     ├─ __init__.py          [T5]
│     ├─ config.py            [T5]
│     ├─ train_module.py      [T5]  # train_module 抽象:把 model/loss/metrics 打包成可训练单元
│     └─ transformer/
│        ├─ __init__.py       [T5]
│        ├─ config.py         [T5]
│        └─ train_module.py   [T5]
│
├─ generate/                  [G6]  # 生成/采样/聊天接口;更偏 inference/runtime
│  ├─ __init__.py             [G6]
│  ├─ sampling.py             [G6]  # top-k/top-p/temp/greedy 等采样
│  ├─ utils.py                [G6]  # logits 处理、停止条件等
│  ├─ chat.py                 [G6]  # chat 封装(prompt/role/模板等)
│  └─ generation_module/
│     ├─ __init__.py          [G6]
│     ├─ config.py            [G6]
│     ├─ generation_module.py [G6]  # generation module 基类/接口
│     └─ transformer/
│        ├─ __init__.py       [G6]
│        ├─ config.py         [G6]
│        └─ generation_module.py [G6]
│
├─ eval/                      [E7]  # 评测与 metrics;通常在训练/生成主线后再进入
│  ├─ __init__.py             [E7]
│  ├─ evaluator.py            [E7]  # evaluator:组织评测流程、运行与汇总
│  ├─ lm_evaluator.py         [E7]  # LM 评测适配
│  ├─ metrics.py              [E7]  # 指标定义与计算
│  └─ task_groups.py          [E7]  # 任务组/基准集合
│
├─ launch/                    [A*]  # 启动/编排(Beaker/GCP 等);按需读
│  ├─ __init__.py             [A*]
│  ├─ beaker.py               [A*]
│  ├─ reorder_ranks_in_gcp.py [A*]
│  ├─ select_beaker_hosts.py  [A*]
│  └─ utils.py                [A*]
│
├─ optim/                     [A*]  # 优化器与 scheduler;读 train 时按需进入
│  ├─ __init__.py             [A*]
│  ├─ config.py               [A*]
│  ├─ scheduler.py            [A*]
│  ├─ adam.py                 [A*]
│  ├─ adamw.py                [A*]
│  ├─ dion.py                 [A*]
│  ├─ lion.py                 [A*]
│  ├─ muon.py                 [A*]
│  ├─ noop.py                 [A*]
│  └─ skip_step_optimizer.py  [A*]
│
├─ float8/                    [A*]  # FP8 / torchao 相关封装;按需
│  ├─ __init__.py             [A*]
│  └─ ao.py                   [A*]
│
├─ internal/                  [A*]  # 内部实验/工作流 glue;外部阅读优先级低
│  ├─ __init__.py             [A*]
│  ├─ common.py               [A*]
│  ├─ cookbook.py             [A*]
│  ├─ experiment.py           [A*]
│  └─ ladder.py               [A*]
│
├─ model_ladder/              [A*]  # 配置生成/多规模 run 组织;按需
│  ├─ __init__.py             [A*]
│  ├─ base.py                 [A*]
│  ├─ transformer_model_configurator.py   [A*]
│  ├─ wsds_chinchilla_run_configurator.py [A*]
│  └─ utils.py                [A*]
│
└─ testing/                   [A*]  # 库内测试辅助;按需
   ├─ __init__.py             [A*]
   ├─ distributed.py          [A*]
   └─ utils.py                [A*]

1 预训练阶段的数据层到底在做什么?

在讨论大模型预训练时,我们很容易把注意力放在模型结构、并行策略、优化器这些位置上。但如果从训练系统真正跑起来的角度看,数据层其实同样关键。因为不管模型多大、并行多复杂,训练循环最终都要不断地回答一个非常朴素的问题:

这一轮 forward/backward,要吃什么数据?这些数据又是怎么被组织出来的?

如果再把问题说得具体一点,在 OLMo 这套代码里,预训练阶段的数据层要解决的,并不是"如何把原始文本分词"这个问题。因为在进入训练主路径之前,原始文档通常已经完成了清洗、切分、tokenization,并落成了按路径组织的 token id 文件,也就是 .npy 文件所保存的token arrays。也就是说,训练阶段面对的,不再是自然语言文本,而是已经准备好的 token 序列。

所以,预训练 data layer 真正要解决的问题其实是:

在已经得到大量 token id arrays 之后,如何把这些分散在不同 source、不同 path 下的 token 数据,持续组织成训练循环可以稳定消费的 batch?

这个问题如果继续往下拆,其实可以拆成四个更具体的子问题:

  1. 到底有哪些数据可以读? 也就是配置里的 mix 名称,最终对应真实存在的数据路径。
  2. 每个 source / path 这一轮应该取多少 token? 也就是 token budget planning:总共要消费多少 token,这些 token 在不同来源之间怎么分配。
  3. 这些 token 最终如何构造成训练样本? 也就是 dataset 层要回答的问题:什么算一个 instance,如何把底层 token 序列组织成 sample。
  4. 这些 sample 又如何进一步组成 batch? 也就是 DataLoader 和 DataCollator 要解决的问题:如何把 instance 变成模型前向传播真正接收的 batch 输入。

换句话说,预训练数据层并不是一个简单的"读文件模块",它更像是一条组织链路:上游面对的是静态存储好的 token 数据,中游要做预算和样本物化,下游则要把样本整理成训练循环能直接消费的 batch。如果用一条最直观的业务流程来表示,这条链路大概是这样的:

raw document
↓
data cleaning
↓
tokenizer
↓
token ids      ← 这里才是Data Layer实际接触到的,构成了Data Layer一切处理的基础
↓
source mixture planning
↓
sample construction / dataset materialization
↓
DataLoader
↓
DataCollator
↓
batch
↓
model forward

这条流程里,前半段更接近离线数据生产:把原始文档整理成 token ids,并落到存储系统中。后半段才是训练循环真正依赖的 runtime data pipeline:训练时并不是重新处理原始文本,而是从已经准备好的 token 数据出发,按配置组织出一轮又一轮 batch。

这里有一个很重要的认识:训练系统真正关心的不是"有多少文件",而是"这一轮需要多少 token,这些 token 能否被稳定地组织成 batch"。举个最直观的例子。假设当前训练配置中:

  • sequence_length = 2048
  • micro_batch_size = 2
  • data_parallel_size = 2
  • gradient_accumulation_steps = 4

那么一次真正的参数更新(optimizer.step()),覆盖的 global instance 数就是:

2 × 2 × 4 = 16

若每个 instance 的长度都是 2048,则对应的 global token 数为:

16 × 2048 = 32768

在后续 OLMo 的代码语境中,global_batch_size 更常按 token 数 来理解,因此这里更准确地说,是"每步覆盖 16 个 instance,对应 32768 个 token"。
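
上面的算术可以用几行 Python 验证一下(变量名为本文示例所用,并非 OLMo 源码中的 API):

```python
# 本文示例中的口径换算:一次 optimizer.step() 覆盖多少 instance / token
sequence_length = 2048
micro_batch_size = 2
data_parallel_size = 2
gradient_accumulation_steps = 4

# 一次参数更新覆盖的全局 instance 数
global_instances_per_step = micro_batch_size * data_parallel_size * gradient_accumulation_steps
# 按 token 计的全局规模(OLMo 代码语境下更常用的口径)
global_tokens_per_step = global_instances_per_step * sequence_length
# 单个 DP rank 在一个 step 内处理的 instance 数
local_instances_per_step = micro_batch_size * gradient_accumulation_steps

assert global_instances_per_step == 16
assert global_tokens_per_step == 32768
assert local_instances_per_step == 8
```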

这样一来,问题就变得很清楚了。训练系统关心的不是某个 .npy 文件里存了多少 token,而是:

  • 这一轮总共要消费多少 token;
  • 这些 token 应该由哪些 source 提供;
  • 它们要如何被组织成固定长度的 instance;
  • 每个 rank 要如何拿到自己那一份 local batch;
  • 整个过程是否足够规整,以保证预训练阶段的吞吐率。

这也是为什么 OLMo 的 data layer 并不是围绕"文件读取"来设计的,而是围绕"如何把静态 token 数据逐层组织成训练输入"来设计的。

2 从业务流到抽象对象再到项目文件

从上述业务流出发,可以抽象出以下几类核心对象,它们共同构成了数据层的骨架:

mix
→ source
→ path
→ token budget
→ instance
→ sample list
→ batch

2.1 mix:训练想读哪一批数据

最上游的对象不是 path,也不是 sample,而是 mix。

在训练配置里,我们通常不会直接手写一长串 .npy 文件路径,而是使用一个已经定义好的 mixture 名称,比如某个 pretraining mix、midtraining mix 或 validation mix。这个 mix 名称本身不是数据,它更像一个数据配方名:它规定了当前训练打算读取哪一类数据组合。

所以,mix 的作用并不是直接提供样本,而是给训练配置提供一个稳定的数据入口。配置里写的是一个抽象名称,系统再根据这个名称去解析出真实的数据来源和路径。

2.2 source:mix 的基本单元

当 mix 被展开之后,下一层对象就是 source。source 可以理解为 mix 内部的"数据来源单元"。它比 path 更高一层,也比 mix 更具体。比如说,一个 mix 里可能同时包含数学数据、代码数据、网页文本数据,那么这些"数学""代码""网页"就可以看成不同的 source。

source 这一层非常关键,因为后面的很多规划,尤其是 token budget planning,基本都是围绕 source 来做的。source 至少承担三种作用:

  • 它是配比单位。当我们说"数学数据占 75%,代码占 25%"时,配比不是在 path 级别做的,而是在 source 级别做的。
  • 它是约束单位。像 max_source_fraction、max_repetition_ratio 这样的约束,限制的也不是单个文件,而是整个 source 最多能贡献多少 token。
  • 它是统计单位。在真正分配预算之前,系统首先会统计每个 source 的 token population,也就是这一类来源总体可提供多少 token,然后再根据目标比例决定它这一轮该贡献多少。

总体而言,source 是 mix 内部用于做配比、约束和统计的基本单元。

2.3 path:source 在存储层的具体载体

source 再往下,才是 path。

path 这一层相对直观,因为它已经接近真实的文件系统或对象存储了。一个 source 通常不会只对应一个文件,而会对应一组 tokenized data files。这些文件就是 source 在存储层的具体载体。

但 path 这一层也不能简单理解成"只是文件名"。在这条 data pipeline 里,path 更准确地说,是 source 的文件级承载形式。它回答的是:

  • 这个 source 具体由哪些 token files 构成;
  • 这些 files 在哪里;
  • 每个 file 大概有多少 token;
  • 后续 budget 该如何分配到这些 file 上。

所以,如果 mix 是"数据配方名",source 是"来源单元",那 path 就是真正落到存储系统上的文件级对象。而到了 path 这一层,我们不再说"数学数据占多少比例",而会开始说"这一轮从哪些具体文件里各取多少 token"。

2.4 token budget:训练时的关键约束

从 source_mixture.py 开始,data layer 的重点就逐渐从"数据来自哪里",转向了另一个更核心的对象:token budget。

token budget 揭示了预训练 data layer 的核心不在于"有多少数据文件",而是当前这轮训练要消费多少 token,这些 token 如何按 step 结构被稳定地组织出来。也就是说,训练系统面对的关键问题不是:

这批数据一共有几个文件?

而是

这轮训练总共要消费多少 token?这些 token 在不同 source 之间怎么分?最后又如何对齐成整数个 step 所需的 fixed-length instances?

在这个层面上讲,token budget 本质上是一种面向训练循环的资源配额对象。如果再拆细一点,token budget 至少有三层:

  1. global token budget:这一轮总共想消费多少 token,比如 requested_tokens = 1,000,000。
  2. source-level token budget:在总预算下,每个 source 分到多少 token,比如 source A 拿 750k,source B 拿 250k。
  3. path-level token budget:在 source 已定的前提下,再把预算分配到 source 下的不同 path 上,比如 a0.npy 分多少,a1.npy 分多少。

需要特别注意的是,这里的 budget 还不是最终样本,它只是"资源配额"。真正的训练样本还要等到更下游。而 source_mixture.py 的核心职责也不是构造 sample,而是决定"取多少"。
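
可以用一个极简的示意代码来表达这三层预算的"逐层下沉"(仅演示思路,配比和文件名均为本文假设,与 source_mixture.py 的真实实现无关):

```python
# "global → source → path" 三层预算下沉的极简示意(非 source_mixture.py 真实实现)
requested_tokens = 1_000_000

# source 级:按目标配比切分全局预算
source_ratios = {"math": 0.75, "code": 0.25}           # 假设的配比
source_budget = {s: int(requested_tokens * r) for s, r in source_ratios.items()}

# path 级:在 source 内按各文件可提供的 token 量等比例分摊
path_tokens = {"a0.npy": 600_000, "a1.npy": 200_000}   # 假设 math source 下两个文件的 token population
total = sum(path_tokens.values())
path_budget = {p: source_budget["math"] * n // total for p, n in path_tokens.items()}

assert source_budget == {"math": 750_000, "code": 250_000}
assert path_budget == {"a0.npy": 562_500, "a1.npy": 187_500}
```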

2.5 instance:训练时的基本单位

instance 可以理解为 dataset 层定义出来的最小训练样本单位。因为预训练阶段更注重效率,所以一个 instance 在预训练时往往有固定长度,也就是通常讲的 Fixed Sequence Length ,比如 sequence_length = 2048。在 FSL 模式下,所有的 instance 都是这个长度,也就是一个长度为 2048 的 token 序列,或者带着其他辅助字段的样本结构。

这里还需要注意另一个问题:原始文本往往以 document 为单位进行组织。tokenize 落盘之后,每个 path 指向的文件内部可能包含多个 document,因此 instance 与 document 之间存在多种对应关系,这是 numpy_dataset.py 最核心的地方之一。因为 dataset 层做的事情,并不是"把 document 原样拿出来",而是"定义什么才算一个训练样本"。比如:

  • FSL 模式中,instance 往往是从 token stream 中按固定长度直接切出来的一段,这个过程是无视 document 边界的;
  • padded 模式中,instance 更接近"一个 document 对应一个样本,不够长就 pad 到 sequence length,过长就截断";
  • packed 模式中,一个 instance 可能由多个 document 紧凑拼起来;
  • interleaved 模式中,一个 instance 甚至可能来自多个文档 chunk 的交错组合。
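
其中 FSL 模式"无视 document 边界"这一点,可以用一个极简示意来体会(仅为概念演示,并非 numpy_dataset.py 的真实实现):

```python
# FSL 切分的极简示意:把拼接后的 token stream 按固定长度切块,无视 document 边界
def fsl_instances(token_stream, sequence_length):
    # 不足一个完整 instance 的尾部直接丢弃
    n = len(token_stream) // sequence_length
    return [token_stream[i * sequence_length:(i + 1) * sequence_length] for i in range(n)]

# 三个"document"拼接成一条 stream,共 20 个 token
stream = list(range(10)) + list(range(100, 107)) + list(range(200, 203))
instances = fsl_instances(stream, 8)

assert len(instances) == 2                  # 20 // 8 = 2,尾部 4 个 token 被丢弃
assert instances[0] == [0, 1, 2, 3, 4, 5, 6, 7]
# 第二个 instance 横跨了第一、第二个 document 的边界
assert instances[1] == [8, 9, 100, 101, 102, 103, 104, 105]
```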

2.6 sample list:从 dataset 到 batch 之间的中间态

当 dataset 已经能够按索引返回 instance 之后,DataLoader 在 collator 之前通常会先持有一组 instance 列表。为了叙述方便,本文将这一中间态称为 sample list。因为从对象流的角度来看,DataLoader 首先做的,并不是把样本直接变成张量 batch,而是从 dataset 中取出一组 instance,把它们放到同一个列表里。这个列表里的每个元素仍然是独立 sample(此时 instance 在 DataLoader 上下文中被称为 sample),它们还没有被 pad、stack、张量化。真正负责把这组样本整理成 batch 的,是 collator:

instance / sample
→ List[sample]
→ collator
→ batch
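
这一步可以用一个极简的 collate 函数示意(纯 Python 演示 padding 与 mask 的组织方式;真实的 collator.py 会进一步把结果张量化,并处理更多字段,此处仅为概念演示):

```python
# 把一组变长 sample 整理成对齐后 batch 的极简示意(非 collator.py 真实实现)
def collate(samples, pad_token_id=0):
    max_len = max(len(s) for s in samples)
    # 右侧 pad 到同一长度,并记录哪些位置是真实 token
    input_ids = [s + [pad_token_id] * (max_len - len(s)) for s in samples]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in samples]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = collate([[5, 6, 7], [8, 9]])
assert batch["input_ids"] == [[5, 6, 7], [8, 9, 0]]
assert batch["attention_mask"][1] == [1, 1, 0]
```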

2.7 batch:当前 rank 真正送入模型前向的输入

再往下,才终于到了 batch。这一步是 data layer 的终点,也是训练时真正使用的对象。到了 batch 这一层,数据已经不再是"某个样本对象"或者"某个 sample 列表",而是一个张量化后的结构化输入。下游通过 batch["input_ids"]batch["labels"] 之类的字段访问的,就是这一层的数据。

同时,batch 还有一个很重要但很容易被忽略的特点:这里说的 batch 往往是当前 rank 上的 local batch,也就是当前 DP rank 在当前 step 中真正送进 model forward 的那一份输入,而不是"整个全局训练逻辑上的抽象 batch"。

所以,batch 至少同时具有两个属性:

  • 它是张量化后的输入结构。不只是 input_ids,还可能包含 attention_mask、labels、metadata、boundary 信息等。
  • 它是 rank-local 的训练输入单元。在当前 rank 上做前向传播的真正输入。

2.8 从抽象对象到项目文件

因为这个专栏的目的主要还是看代码,所以还需要从抽象对象再落回到项目代码中:

  • mixes/__init__.py 主要处理的是 mix → paths / labels
  • source_mixture.py 主要处理的是 source / path → token budget
  • numpy_dataset.py 主要处理的是 token arrays / budget → instance
  • data_loader.py 主要处理的是 instance → sample list
  • collator.py 主要处理的是 sample list → batch

2.9 两个简单的基础文件

2.9.1 types.py

文件很小,主要是定义 Data Layer 的基础类型(type definitions),包括两类:dataset token dtype(uint8 / uint16 / uint32 / uint64)和长文档处理策略(truncate / fragment)。如果是从上一篇看过来的话,就会觉得很简单了,唯一要注意的就是 Union 部分,其余部分不做详细注释:

python
from typing import Type, Union

import numpy as np

from olmo_core.config import StrEnum

NumpyUIntTypes = Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]
# Numpy dataset中通常使用 np.array(..., dtype=np.uint16)。
# 在typing中,Union[np.uint8, np.uint16]表示"变量是这些类型(np.uint8/np.uint16)的实例";
# 但这里dataset代码实际传递的是dtype类对象本身(如np.uint16),所以要写成Type[np.uint16]。

class LongDocStrategy(StrEnum):
    """
    Specifies how to handle documents that are longer than the max sequence length when packing.
    """

    truncate = "truncate" # 超过最大长度的文档会被截断,超出的 token 被丢弃
    """
    Long docs are truncated and the excess tokens are discarded.
    """

    fragment = "fragment" # 长文档会被切成多个较小的文档,不丢弃任何 token,但文档会碎片化
    """
    Long docs are split into smaller docs so that no tokens are discarded, but you end up with
    fragmented docs.
    """


class NumpyDatasetDType(StrEnum):
    """
    Supported numpy unsigned integer data types for datasets.
    """

    uint8 = "uint8"
    uint16 = "uint16"
    uint32 = "uint32"
    uint64 = "uint64"

    def as_np_dtype(self) -> NumpyUIntTypes:
        """
        Convert the enum value to its corresponding numpy dtype.

        Returns:
            The numpy unsigned integer dtype corresponding to this enum value.
        """
        return getattr(np, str(self))
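
顺带可以验证一下 dtype 的取值范围:token id 必须能被所选 uint 类型容纳。下面的 smallest_uint_dtype 是本文为演示而写的假设函数,types.py 中并没有它:

```python
import numpy as np

# 根据 vocab_size 选出能容纳所有 token id 的最小无符号整型(演示用,非 types.py 中的函数)
def smallest_uint_dtype(vocab_size):
    for dt in (np.uint8, np.uint16, np.uint32, np.uint64):
        if vocab_size - 1 <= np.iinfo(dt).max:
            return dt
    raise ValueError("vocab too large")

assert smallest_uint_dtype(50257) is np.uint16    # gpt2:最大 id 50256 ≤ 65535
assert smallest_uint_dtype(100278) is np.uint32   # dolma2:最大 id 100277 > 65535,uint16 放不下
```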

2.9.2 tokenizer.py

这个文件负责定义 tokenizer 的配置抽象,而不是 tokenizer 的具体分词逻辑。整体也很简单,包括 TokenizerName 和 TokenizerConfig 两部分,值得关注的部分在于 padded_vocab_size:它把 vocab_size 向上对齐到某个数值的整数倍,从而减少 shape 不规则带来的性能损失。

python
from dataclasses import dataclass
from typing import Optional

from ..config import Config, StrEnum

__all__ = [
    "TokenizerConfig",
    "TokenizerName",
]


class TokenizerName(StrEnum): # 实现分词器与name之间映射的枚举类
    """
    An enumeration of tokenizer identifiers commonly used OLMo researchers.
    """

    dolma2 = "allenai/dolma2-tokenizer"
    """
    The dolma2 tokenizer.
    """

    dolma2_sigdig = "allenai/dolma2-tokenizer-sigdig"
    """
    The R2L dolma2 tokenizer.
    """

    gpt_neox_olmo_dolma_v1_5 = "allenai/gpt-neox-olmo-dolma-v1_5"
    """
    A modified GPT NeoX tokenizer.
    """

    gpt2 = "gpt2"
    """
    The base GPT2 tokenizer.
    """


@dataclass
class TokenizerConfig(Config): # 继承自..config中的Config,支持as_dict(),from_dict(),merge()等
    """
    A configuration class that represents a tokenizer.
    """

    vocab_size: int
    """
    The vocab size.
    """

    eos_token_id: int
    """
    The end-of-sentence token ID.
    """

    pad_token_id: int
    """
    The padding token ID.
    """

    bos_token_id: Optional[int] = None
    """
    The begin-of-sentence token ID.
    """

    identifier: Optional[str] = None
    """
    The identifier of the tokenizer. Could be a path or HuggingFace identifier.
    """

    def padded_vocab_size(self, pad_multiple: int = 128) -> int: # 把vocab_size向上补齐到pad_multiple的整数倍
        """
        Returns the vocab size padded to be a multiple of ``pad_multiple``.
        This is useful to set model embeddings to this number to increase throughput.
        减少某些不规则 shape 带来的性能损失
        """
        return pad_multiple * ((self.vocab_size + pad_multiple - 1) // pad_multiple)

    @classmethod
    def dolma2(cls) -> "TokenizerConfig":
        """
        Get a :data:`~TokenizerName.dolma2` tokenizer config.
        """
        return cls(
            vocab_size=100278,
            eos_token_id=100257,
            pad_token_id=100277,
            identifier=TokenizerName.dolma2,
        )

    @classmethod
    def dolma2_sigdig(cls) -> "TokenizerConfig":
        """
        Get a :data:`~TokenizerName.dolma2_sigdig` tokenizer config.
        """
        return cls(
            vocab_size=100278,
            eos_token_id=100257,
            pad_token_id=100277,
            bos_token_id=100257,
            identifier=TokenizerName.dolma2_sigdig,
        )

    @classmethod
    def gpt_neox_olmo_dolma_v1_5(cls) -> "TokenizerConfig":
        """
        Get a :data:`~TokenizerName.gpt_neox_olmo_dolma_v1_5` tokenizer config.
        """
        return cls(
            vocab_size=50280,
            eos_token_id=50279,
            pad_token_id=1,
            identifier=TokenizerName.gpt_neox_olmo_dolma_v1_5,
        )

    @classmethod
    def gpt2(cls) -> "TokenizerConfig":
        """
        Get a :data:`~TokenizerName.gpt2` tokenizer config.
        """
        return cls(
            vocab_size=50257,
            eos_token_id=50256,
            bos_token_id=50256,
            pad_token_id=50256,
            identifier=TokenizerName.gpt2,
        )

    @classmethod
    def from_hf(cls, identifier: str) -> "TokenizerConfig":
        """
        Initialize a tokenizer config from a model on HuggingFace.
        从 HuggingFace 上的模型/ tokenizer 配置自动构造一个 TokenizerConfig
        :param identifier: The HF model identifier, e.g. "meta-llama/Llama-3.2-1B".
        """
        import json

        from cached_path import cached_path

        try: # 不同hugging face仓库中的配置不一致,先尝试一个路径失败则回退到另一个路径
            config_path = cached_path(f"hf://{identifier}/config.json")
        except FileNotFoundError:
            config_path = cached_path(f"hf://{identifier}/tokenizer_config.json")

        with config_path.open() as f:
            config = json.load(f)

        return cls(
            vocab_size=config["vocab_size"],
            eos_token_id=config["eos_token_id"],
            pad_token_id=config.get("pad_token_id", config["eos_token_id"]), # 存在pad_token_id则获取,否则回退到eos_token_id
            bos_token_id=config.get("bos_token_id"),
            identifier=identifier,
        )
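
padded_vocab_size 的取整逻辑可以脱离类单独验证一下(下面的独立函数复刻了源码中同样的表达式):

```python
# 向上取整到 pad_multiple 的整数倍,与 TokenizerConfig.padded_vocab_size 的表达式一致
def padded_vocab_size(vocab_size, pad_multiple=128):
    return pad_multiple * ((vocab_size + pad_multiple - 1) // pad_multiple)

assert padded_vocab_size(100278) == 100352   # dolma2 的 vocab_size 被补齐到 128 的整数倍
assert padded_vocab_size(50257) == 50304     # gpt2
assert padded_vocab_size(50304) == 50304     # 已经对齐时保持不变
```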

3. mixes/__init__.py:从 mix 名称到 paths + labels

  1. 这个文件不处理 sample,也不处理 batch。它做的事情很前置,也很简单:把训练配置里的 mix 名称,解析成一组真实的数据路径以及对应的标签。
  2. 这一层解决的是"去哪里找数据"。如果说后面的 source_mixture.py 解决的是"这一轮取多少 token",那 mixes/__init__.py解决的就是更前面的问题:训练配置里写下的这个 mix 名称,最终对应哪些真实存在的 tokenized data files?
  3. 其核心输出是两样东西:pathslabels,最终会形成一个列表,里面每个元素是 (label, path) 这样的二元组。
  4. mix 文件里每一行本质上是一个二元组,比如 longmino, preprocessed/dolma3_longmino_0925/{TOKENIZER}/000000.npy。它由两部分组成:label 表示 provenance / source label,relative_path_template 则表示一个带有 {TOKENIZER} 占位符的相对路径模板。

所以这一步本质上是在做:

mix 名称 → 读取 mix txt → 模板替换 → 得到 paths + labels。

整体不难,主要包含了:

  1. DataMixBase 这个基类。这里要关注的是接口本身,即以 base_dir、tokenizer 为输入,以 paths、labels 为输出。它是一个抽象基类,具体的解析逻辑交由子类实现。
  2. DataMix:基类的一个具体实现。它本质上是一个枚举类,维护着"官方支持的数据配方名",并通过 _missing_() 方法兼容部分历史命名,同时实现了 build() 这个核心逻辑。

注释我放到源码中:

python
import os
from abc import abstractmethod # 抽象方法,需要子类具体实现
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, List, Tuple

from olmo_core.config import StrEnum

from ..tokenizer import TokenizerName

__all__ = ["DataMixBase", "DataMix"]


class DataMixBase(StrEnum):
    """
    Base class for enumeration of data mixes.
    """

    @abstractmethod
    def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]:
        """
        Construct the data mix.
        输入 base_dir、tokenizer,输出 paths 和 labels。这里的 label 指的是 source label 或者说 provenance label。
        :param base_dir: Where the mix is stored, e.g. "s3://ai2-llm" or "/weka/oe-training-default/ai2-llm".
        :param tokenizer: The tokenizer identifier.

        :returns: A list of paths/URLs to the tokenized numpy data files in the mix and list
            of corresponding labels.
        """
        raise NotImplementedError


class DataMix(DataMixBase): # 使用枚举类解析可以避免在 config 中塞入大量路径
    """
    An enumeration of data mix names.
    """

    # Pretraining mixes
    OLMoE_mix_0824 = "OLMoE-mix-0824"
    dolma17 = "dolma17"
    OLMo_mix_0625 = "OLMo-mix-0625"
    OLMo_mix_0625_150Bsample = "OLMo-mix-0625-150Bsample"
    OLMo_mix_0625_700Bsample = "OLMo-mix-0625-700Bsample"
    OLMo_mix_0625_official = "OLMo-mix-0625-official"
    OLMo_mix_0925 = "OLMo-mix-0925"
    OLMo_mix_0925_official = "OLMo-mix-0925-official"

    # Midtraining mixes
    OLMo_midtraining_mix_0625_100B = "OLMo-midtraining-mix-0725-100B"
    OLMo_midtraining_mix_0925_ingredient1_100B = "OLMo-midtraining-mix-0925-ingredient1-100B"
    OLMo_midtraining_mix_0925_ingredient2_100B = "OLMo-midtraining-mix-0925-ingredient2-100B"

    # Long-context extension mixes
    OLMo_longmino_mix_0625 = "OLMo-longmino-mix-0625"
    OLMo_longmino_mix_0925 = "OLMo-longmino-mix-0925"

    # Validation mixes
    v3_small_ppl_validation = "v3-small-ppl-validation"

    @classmethod
    def _missing_(cls, value: object) -> "DataMix | None":
        # 这里的 _missing_ 是 Enum 类内置的一个 hook,专门用于处理枚举值查找失败时的自定义逻辑。
        # 当尝试创建一个枚举时,比如 DataMix("OLMo-mix-0625"),会首先匹配已有 member,失败时会调用 cls._missing_(value) 再次匹配。
        """Handle alias lookups."""
        # Aliases mapping
        aliases = { # 别名映射
            "dolma3-0625-6T-mix": "OLMo-mix-0625",
            "dolma3-0925-6T-mix": "OLMo-mix-0925",
            "dolma3-0925-150B-mix": "OLMo-mix-0625-150Bsample",
        }

        # Check if the value is an alias
        if isinstance(value, str) and value in aliases: # 根据别名(或者是老名字)查找对应新枚举
            # Look up the real value and return the corresponding enum member
            real_value = aliases[value]
            for member in cls: # 从类中遍历
                if member.value == real_value:
                    return member
        return None

    def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]: # mix → concrete paths
        if not base_dir.endswith("/"): # 规范 base_dir
            base_dir = base_dir + "/"

        tokenizer_id: str = tokenizer # 类型注解 + 赋值
        # 根据不同 mix 和 tokenizer 做映射,将标识名映射到路径中的 id
        if self == DataMix.v3_small_ppl_validation:
            if tokenizer == TokenizerName.gpt_neox_olmo_dolma_v1_5:
                tokenizer_id = "gptneox20b"
            elif tokenizer == TokenizerName.dolma2:
                tokenizer_id = "dolma2-tokenizer"
        elif self == DataMix.OLMo_mix_0625:
            if tokenizer == TokenizerName.dolma2_sigdig:
                tokenizer_id = "dolma2-tokenizer-sigdig"
        elif self in [
            # Mixes used for OLMo3 training are saved with "dolma3-tokenizer" tokenizer,
            # which is exactly the same as "dolma2-tokenizer" but with a different name.
            DataMix.OLMo_mix_0625_official,
            DataMix.OLMo_mix_0925_official,
            DataMix.OLMo_midtraining_mix_0625_100B,
            DataMix.OLMo_midtraining_mix_0925_ingredient1_100B,
            DataMix.OLMo_midtraining_mix_0925_ingredient2_100B,
            DataMix.OLMo_longmino_mix_0625,
            DataMix.OLMo_longmino_mix_0925,
        ]:
            if tokenizer == TokenizerName.dolma2:
                tokenizer_id = "allenai/dolma3-tokenizer"
        elif tokenizer == TokenizerName.gpt_neox_olmo_dolma_v1_5:
            tokenizer_id = "gpt-neox-olmo-dolma-v1_5"

        paths = []
        labels = []
        with _get_data_mix_path(self) as mix_path:
            with mix_path.open() as f: # with 上下文管理打开文件
                # 这里用 OLMo-mix-0925.txt 中的一个例子
                # longmino,preprocessed/dolma3_longmino_0925/{TOKENIZER}/000000.npy
                # 本质上是一个 (label, relative_path_template) 的二元组,是一个带占位符的相对路径模板
                for line_num, line in enumerate(f):
                    line = line.strip() # 去掉首尾空白和换行
                    # line = "longmino,preprocessed/dolma3_longmino_0925/{TOKENIZER}/000000.npy"
                    if not line or line.startswith("#"): # 跳过空行和注释
                        continue
                    label, path = line.split(",") # 按逗号切成两个部分
                    # longmino, preprocessed/dolma3_longmino_0925/{TOKENIZER}/000000.npy
                    if "{TOKENIZER}" not in path: # 必须要有 {TOKENIZER}
                        raise ValueError(f"line {line_num + 1} in data mix '{self}' is invalid")
                    path = path.replace("{TOKENIZER}", tokenizer_id) # 将 {TOKENIZER} 替换成真实 tokenizer_id
                    # 假设 tokenizer_id = "allenai/dolma3-tokenizer"
                    # 则会得到 path = "preprocessed/dolma3_longmino_0925/allenai/dolma3-tokenizer/000000.npy"
                    # 这样就实现了向不同 tokenizer 的动态解析
                    paths.append(f"{base_dir}{path}") # 再拼上 base_dir,添加到 path 列表,比如 "s3://ai2-llm/"
                    labels.append(label)
        return paths, labels


@contextmanager # 使用 contextmanager 装饰,使 yield 前面的代码在进入 with 时执行,yield 返回的值作为 as 的对象,yield 后面的代码在离开 with 时执行
def _get_data_mix_path(name: str) -> Generator[Path, None, None]:
    # 把某个 mix 名称对应到包内资源文件 olmo_core/data/mixes/<name>.txt,
    # 比如 DataMix.OLMo_mix_0925 → olmo_core/data/mixes/OLMo-mix-0925.txt
    import importlib_resources

    try: # importlib_resources.as_file 可能会把 package 内资源临时解压到临时文件,所以是通过 yield 维持 path 外层生命周期而不是直接 return
        with importlib_resources.as_file( 
            importlib_resources.files("olmo_core").joinpath(
                f"data/mixes/{os.path.basename(name)}.txt"
            )
        ) as path:
            yield path
    finally:
        pass

4. source_mixture.py:从 paths + labels 到 token budget

上一章 mixes/__init__.py 把抽象的 mix 名称解析成了一组带标签的路径,即 paths + labels。但到了这一步,问题就不再是"有哪些数据",而变成了另一件更贴近训练循环的事情:

这一轮总共要消费多少 token?这些 token 在不同 source 之间怎么分?最后又怎样落实到各个 path 上?

source_mixture.py 解决的正是这个问题。它虽然仍属于 data layer,但已经不再是"数据入口解析层",而更像一个预算规划层。它站在 dataset 之前,先把 source 级别的 token 配额算清楚,再将这些配额下沉到 path,最后整理成一个可以交给下游 dataset 的 mixture 结果。

4.1 训练视角下的数据需求

在深入代码之前,有必要先理清训练过程中几个与数据相关的核心概念。这些概念直接决定了 source_mixture.py 如何计算 token 配额。

  • instance (或 sample):训练数据的最小单位,通常长度固定为 sequence_length(例如 2048 个 token)。

  • micro-batch :单个设备上一次前向/反向传播实际处理的 instance 数量,称为 micro_batch_size

  • data parallel size:数据并行组的大小,即同时参与计算的设备数量。

  • gradient accumulation steps:梯度累积步数,即在一次参数更新前累积多少次梯度。

  • global instance count per step :一次参数更新所覆盖的全局 instance 总数:micro_batch_size × data_parallel_size × gradient_accumulation_steps

  • global batch size :若采用 OLMo 当前代码中的口径,则更常指一次参数更新所覆盖的全局 token 总数,即 global_batch_size = micro_batch_size × data_parallel_size × gradient_accumulation_steps × sequence_length。

例如,当 micro_batch_size = 2、data_parallel_size = 2、gradient_accumulation_steps = 4、sequence_length = 2048 时:

  • 一个 optimizer step 覆盖的全局 instance 数为 2 × 2 × 4 = 16
  • 对应的全局 token 数为 16 × 2048 = 32768
  • 每个 DP rank 在一个 optimizer step 内处理 2 × 4 = 8 个 instance,即 8 × 2048 = 16384 tokens
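
上面这几个数可以用一小段独立的 Python 做个对照(变量名仅为示意,并非训练配置里的真实字段名):

```python
# 示意:由并行参数推算一个 optimizer step 覆盖的数据量
micro_batch_size = 2             # 单设备一次前向/反向处理的 instance 数
data_parallel_size = 2           # 数据并行组大小
gradient_accumulation_steps = 4  # 一次参数更新前累积的梯度次数
sequence_length = 2048           # 每个 instance 的固定 token 长度

# 一次参数更新覆盖的全局 instance 数
global_instances = micro_batch_size * data_parallel_size * gradient_accumulation_steps
# 对应的全局 token 数
global_tokens = global_instances * sequence_length
# 每个 DP rank 在一个 optimizer step 内处理的 token 数
tokens_per_rank = micro_batch_size * gradient_accumulation_steps * sequence_length

print(global_instances, global_tokens, tokens_per_rank)  # 16 32768 16384
```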

有了这些基础,我们再来看数据层如何响应训练循环的请求。假设训练配置中设定了:

  • requested_tokens = 1,000,000(本轮总共想消费的 token 数)
  • global_batch_size = 8192(每步的全局 token 数)
  • sequence_length = 2048

那么,每个 step 需要的 instance 数为:

python 复制代码
num_instances_per_step = global_batch_size // sequence_length   # 8192 // 2048 = 4

这意味着,数据 mixture 的构建必须确保最终能组成整数个这样的 step,而不仅仅是 token 总数大致正确。
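
从 token 预算换算到 instance 需求的这一步,对应后文 build() 源码里 training_steps / requested_instances 的计算,可以抽出来单独演示:

```python
import math

requested_tokens = 1_000_000   # 本轮想消费的 token 总数
global_batch_size = 8_192      # 每步的全局 token 数
sequence_length = 2_048

# 至少要跑这么多个 optimizer step 才能覆盖请求的 token 数(向上取整)
training_steps = math.ceil(requested_tokens / global_batch_size)
assert global_batch_size % sequence_length == 0  # 源码中同样的整除约束
num_instances_per_batch = global_batch_size // sequence_length
requested_instances = training_steps * num_instances_per_batch

print(training_steps, num_instances_per_batch, requested_instances)  # 123 4 492
```

注意换算的结果是整数个 step、整数个 instance,总 token 数因此会略大于 requested_tokens。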

4.2 source mixture 的配置与约束

从训练配置的角度看,一个 source mixture 通常会提供以下几类信息:

  • 有哪些 source(例如"数学"、"代码"、"网页");
  • 每个 source 对应的具体路径;
  • 每个 source 的目标比例(例如数学占 75%,代码占 25%);
  • 每个 source 最多允许使用自身总 token 的多少比例(max_source_fraction);
  • 每个 source 最多允许重复使用多少次(max_repetition_ratio);
  • 本轮训练总共想消费多少 token(requested_tokens)。

在这些约束下,最终要解决的问题是:

在满足比例要求和上限约束的前提下,这一轮到底应该从每个 source 中选出多少 token?

这个问题并不是简单的"乘比例"就能解决的,因为中间至少会遇到三类约束:

  1. source 总量约束:某个 source 的目标比例再高,它本身也只有那么多 token,不能无限提供。
  2. 重复率约束:如果这一轮想从某个 source 中拿到超过它原始 token 总量的配额,就意味着要重复使用数据,而这又会受到 max_repetition_ratio 的限制。
  3. 训练步长对齐约束:就算 source-level 的 token 配额已经算出来了,最后它仍然要能整理成整数个 fixed-length instances,并且与 global_batch_size、sequence_length 这些训练参数对齐。
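
前两类约束可以合成一个上限公式,对应源码 build() 中 max_for_source 的检查;下面是一个独立的示意(数字为假设):

```python
def max_contribution(population: int, max_source_fraction: float, max_repetition_ratio: float) -> int:
    """某个 source 本轮最多能贡献的 token 数(示意,对应 build() 中的 max_for_source)。"""
    return int(population * max_source_fraction * max_repetition_ratio)

# 假设:某 source 共有 1 亿 token,允许全部使用,最多重复 2 次
cap = max_contribution(100_000_000, 1.0, 2.0)

# 假设:本轮总预算 10 亿 token,该 source 目标占比 12.5%
requested_tokens, target_ratio = 1_000_000_000, 0.125
needed = int(requested_tokens * target_ratio)
if cap < needed:  # 超过上限时源码会抛出 OLMoConfigurationError
    raise ValueError(f"Insufficient tokens: {cap} < {needed}")
print(cap, needed)  # 200000000 125000000
```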

因此,source_mixture.py 虽然会接触到具体路径,但它真正关心的并不是"路径字符串",而是这些路径背后承载了多少 token,以及本轮要从中拿走多少 token。

4.3 核心逻辑链路

这个文件的主线非常清晰:先在 source level 记账,再把账继续细化到 path level。系统首先不会立刻去操作每个 path,而是先把每个 source 当成一个整体,统计它的 token population,然后根据目标比例和总预算,给它分配一个"本轮应贡献多少 token"的数值。之后,才继续把这个 source-level 配额分摊到它内部的不同 path 上。

这个设计是很自然的,因为比例、上限、重复率这些约束,本来就是 source 级别的约束,而不是 path 级别的约束。只有先在 source level 把账算清楚,path-level 的分配才有依据。

整个逻辑链可以归纳为:

复制代码
统计每个 source 的 token 总量
↓
按目标比例计算这一轮每个 source 应拿多少 token
↓
检查 source 上限和重复率约束(若不足则按允许重复次数调整)
↓
将 source-level 配额分摊到各 path(按 path 的 token 总量比例)
↓
根据 sequence_length 和 global_batch_size 对 path-level 配额做对齐修正,确保最终能组成整数个 fixed-length instances 和整数个 training steps
↓
输出可交给下游 dataset 的 mixture 结果
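
其中"对齐修正"一步的核心是最大余数法(Hamilton's method):先取每个 path 份额的整数部分,再按小数余量从大到小补齐缺口。忽略源码中的容量上限后,其骨架大致如下(独立示意,非源码):

```python
import math

def hamilton_round(fractions, total):
    """把带小数的份额取整,使整数之和恰好为 total(最大余数法,忽略容量上限的简化版)。"""
    base = [math.floor(f) for f in fractions]
    # 按小数余量从大到小排列下标
    by_remainder = sorted(range(len(fractions)), key=lambda i: fractions[i] - base[i], reverse=True)
    leftover = total - sum(base)  # 取整后还缺多少名额
    for i in by_remainder[:leftover]:
        base[i] += 1  # 余数最大的路径优先补齐
    return base

# path1 -> 3.8 个 instance, path2 -> 5.2, path3 -> 1.6;目标共 11 个
print(hamilton_round([3.8, 5.2, 1.6], 11))  # [4, 5, 2]
```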

4.4 上下游边界

如果把 source_mixture.py 放回整条对象流里看,它的位置非常明确:

在它之前,mixes/__init__.py 做的是:

复制代码
mix → paths + labels

在它之后,numpy_dataset.py 要做的是:

复制代码
token budget / token arrays → instance

那么 source_mixture.py 正好夹在中间,承担的是这一段:

复制代码
paths + labels → source/path-level token budget

所以,它和上下游的边界可以概括为:

  1. mixes/__init__.py 解决"去哪里找数据"
  2. source_mixture.py 解决"这一轮取多少 token"
  3. numpy_dataset.py 解决"取出来的 token 如何变成 instance"

4.5 核心对象结构

这个文件中的核心对象可以划分为三层:

4.5.1 配置层

这些 dataclass 构成了对外 API,用于描述一个或多个 source 的配置,以及构建 mixture dataset 的全局参数。

  • SourceMixtureConfig:单个 source 的配置,包括名称、路径、目标比例、max_source_fraction、max_repetition_ratio 等。
  • SourceMixtureList:一组 SourceMixtureConfig 的集合,代表整个 mixture。
  • SourceMixtureDatasetConfig:全局配置,包含 requested_tokens、sequence_length、global_batch_size、seed 等训练相关参数,以及一个 SourceMixtureList。

4.5.2 中间结果层

这些对象在构建过程中被逐步填充,用于记录中间状态。

  • SourceTokenDetails:记录某个 source 的 token 总量、可用的最大 token 数(考虑 max_source_fraction 和 max_repetition_ratio)等统计信息。
  • SourcePathTokens:记录某个 path 的路径、实际 token 数量、本轮分配的 token 配额。
  • SourceMixtureOutcome:记录一个 source 的最终分配结果,包含名称和一组 SourcePathTokens

4.5.3 结果层

SourceMixtureDataset 是最终被下游 numpy_dataset.py 消费的对象。它内部持有 sources 列表,每个元素是一个 SourceMixtureOutcome。构建完成后,其结构形如:

python 复制代码
SourceMixtureDataset
  sources = [
      SourceMixtureOutcome(
          name="longmino",
          path_tokens=[
              SourcePathTokens(
                  path="s3://ai2-llm/preprocessed/longmino/.../000000.npy",
                  tokens=262144,           # 该 path 本轮分配的 token 数
                  max_tokens=524288        # 该 path 总共可提供的 token 数
              ),
              SourcePathTokens(
                  path="s3://ai2-llm/preprocessed/longmino/.../000001.npy",
                  tokens=262144,
                  max_tokens=262144
              ),
              ...
          ]
      ),
      SourceMixtureOutcome(
          name="code",
          path_tokens=[...]
      ),
      ...
  ]

随后,SourceMixtureDataset 会通过 to_paths()、to_index() 等方法,将这个结构转换为 dataset 的索引,供 numpy_dataset.py 使用。
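
"path 可能重复出现,所以 key 要带 index"这一点可以用一个简化版复现(下面的 PathTokens 是示意用的精简结构,非源码):

```python
from dataclasses import dataclass
from itertools import chain
from typing import Dict, Tuple

@dataclass
class PathTokens:  # 对应源码 SourcePathTokens 的精简版(示意)
    path: str
    tokens: int

# 同一个 path 因重复采样(take_ratio > 1)出现两次
groups = [
    [PathTokens("s3://bucket/a.npy", 500), PathTokens("s3://bucket/b.npy", 300)],
    [PathTokens("s3://bucket/a.npy", 200)],  # a.npy 的第二次采样
]

# 与 to_index() 相同的展平 + 编号方式:(path, 全局序号) 作为 key 区分重复 path
index: Dict[Tuple[str, int], int] = {
    (item.path, idx): item.tokens
    for idx, item in enumerate(chain.from_iterable(groups))
}
print(index)
# {('s3://bucket/a.npy', 0): 500, ('s3://bucket/b.npy', 1): 300, ('s3://bucket/a.npy', 2): 200}
```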

4.6 小结

source_mixture.py 是数据层从"数据描述"向"训练消费"过渡的关键环节。它不再仅仅关心数据在哪里,而是结合训练参数(sequence_length、global_batch_size、requested_tokens)和 source 级约束(比例、重复率、总量),将抽象的 mix 转化为精确的 path-level token 配额。这个配额最终被下游的 dataset 用来切分出一个个 instance,为训练循环准备好稳定的数据流。

4.7 详细注释

python 复制代码
import logging
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from itertools import chain # itertools用于对可迭代对象做处理,chain则是把多个可迭代对象首尾连接起来,像一个连续序列一样遍历(迭代器的列表)
from typing import Dict, List, Optional, Tuple, cast

import numpy as np
from rich.console import Console
from rich.progress import Progress
from rich.table import Table
from rich.text import Text

from olmo_core.aliases import PathOrStr
from olmo_core.config import Config
from olmo_core.data.types import NumpyUIntTypes
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.io import deterministic_glob_directory, file_exists, get_file_size

__all__ = [
    "SourceMixtureConfig",
    "SourceMixtureList",
    "SourceMixtureDatasetConfig",
]

log = logging.getLogger(__name__)


@dataclass
class SourceMixtureConfig(Config): # 用于指定对单个source的约束
    """
    Configuration for a single data source within a mixture.

    This class defines how a data source should be sampled and weighted when
    creating a training dataset from multiple sources. It allows control over
    the target proportion, repetition limits, and maximum usage fraction of
    the source data.
    """

    source_name: str # source的名字
    """
    The name of the source.
    """
    target_ratio: float # 当前source占全局的比例,打个比方target_ratio = 0.2 表示当前这个source占整个mixture的20%
    """
    The target ratio of the source in the mixture.
    """
    paths: List[str] # 这里传入的应该是经过解析后的path,或是某种pattern,而不是mixes.__init__中使用的那种带占位符{TOKENIZER}的路径
    """
    A list of paths to the source data.
    这里使用的path更接近于"s3://ai2-llm/preprocessed/dolma3_longmino_0925/allenai/dolma3-tokenizer/*.npy"
    或是某种pattern比如["/weka/.../000000.npy","/weka/.../000001.npy"]
    """
    max_repetition_ratio: float = 1.0 # 当前source最多被允许复用几次,假设source只有100M,这个值设为2的话当前source就可以最多贡献200M
    """
    The maximum ratio of repetitions of the source data to include in the mixture.
    This can be used to upsample the source data by setting the repetition ratio > 1.
    与下面的max_source_fraction共同决定当前source的贡献上限。
    """
    max_source_fraction: float = 1.0 # 当前source下最多允许使用多少数据,假设source只有100M,这个值设为0.2,那么当前source最多贡献20M
    """
    The maximum ratio of the source data to include in the mixture.
    与上面的max_repetition_ratio共同决定当前source的贡献上限。
    """

    _resolved_paths: Optional[List[str]] = None # 尽管没有声明field(init=False),但是这里的作用主要还是未初始化缓存,不要直接传入
    # 更严谨一点的写法是 _resolved_paths: Optional[List[str]] = field(default=None, init=False, repr=False) 避免对外暴露

    def validate(self): # 验证字段的合法性,逻辑很简单,不详细注释
        if self.target_ratio:
            if not 0 < self.target_ratio <= 1:
                raise OLMoConfigurationError("target_ratio must be > 0 and <= 1")
            if not 0 < self.max_source_fraction <= 1:
                raise OLMoConfigurationError("max_source_fraction must > 0 and <= 1")

        if self.max_repetition_ratio < 1:
            raise OLMoConfigurationError("max_repetition_ratio must be >= 1")

        if not self.paths:
            raise OLMoConfigurationError("paths must not be empty")

        if not 0 <= self.max_source_fraction <= 1:
            raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]")

    @property # 标记为property,从而可以不使用(),因为resolved_paths更像是一个自然属性,所以向字段的使用方式统一
    def resolved_paths(self) -> List[str]:
        """
        Resolve the paths, expanding any globs and validating existence.
        Caches the result after the first access.
        """
        if self._resolved_paths is not None: # 判断是否曾解析并缓存过,解析过则直接返回
            return self._resolved_paths

        resolved: List[str] = [] # 带类型注解的赋值语句
        for path in self.paths:
            path_str = str(path)
            if "*" in path_str: # 如果是包含"*"的需要检索的pattern比如dolma3_longmino_0925/allenai/dolma3-tokenizer/*.npy
                matches = deterministic_glob_directory(path_str) # io.py中定义的类似glob.glob(pattern)的检索函数
                if not matches: # 未能匹配到则抛出异常信息
                    error_msg = f"Glob pattern '{path_str}' did not match any files"
                    # Add helpful hint for mix-0625 which has unavailable files
                    if "0625" in path_str:
                        error_msg += ( # OLMo-core v2.4.0 新增,0625的mix可能有些路径不再对外开放或是发生其他变化,尝试使用更新的0925
                            "\n\nNOTE: Some files in OLMo-mix-0625 are not available. "
                            "If you are resuming training from a checkpoint that used mix-0625, you will need to "
                            "switch to a newer mix such as OLMo-mix-0925. To continue training with a different "
                            "dataset mix, set 'ignore_fingerprint_mismatch=True' in your NumpyDataLoaderConfig "
                            "to bypass the fingerprint mismatch error. This will probably result in a different data order!"
                        )
                    raise FileNotFoundError(error_msg)
                resolved.extend(matches) # 匹配到的内容添加到resolved
            else: # 否则直接检查文件是否存在
                if not file_exists(path_str): # 不存在则抛出异常信息
                    error_msg = f"Path '{path_str}' does not exist"
                    # Add helpful hint for mix-0625 which has unavailable files
                    if "0625" in path_str:
                        error_msg += (
                            "\n\nNOTE: Some files in OLMo-mix-0625 are not available. "
                            "If you are resuming training from a checkpoint that used mix-0625, you will need to "
                            "switch to a newer mix such as OLMo-mix-0925. To continue training with a different "
                            "dataset mix, set 'ignore_fingerprint_mismatch=True' in your NumpyDataLoaderConfig "
                            "to bypass the fingerprint mismatch error. This will probably result in a different data order!"
                        )
                    raise FileNotFoundError(error_msg)
                resolved.append(path_str) # 存在则添加

        self._resolved_paths = resolved
        return resolved


@dataclass
class SourceMixtureList(Config): # 多个source的集合,包装一组source,并验证target_ratio总和为 1
    """
    A list of source configurations for building a mixture dataset.
    This class ensures that the target ratios of the sources sum to 1.0.

    The purpose of this class is to make managing sources independent from the details of
    materializing those sources with SourceMixtureDatasetConfig.build().

    With this separation, we can define a list of sources in a YAML file without also needing to
    specify parameters like requested_tokens, global_batch_size, or processes.

    SourceMixtureList只描述source及其目标比例,
    SourceMixtureDatasetConfig才描述requested_tokens/global_batch_size/seed等"具体构建参数"
    """

    sources: List[SourceMixtureConfig] # 用一套单个source的描述构成一个list来解析

    def validate(self): # 验证总和是否为1
        if not self.sources: #不允许为空
            raise OLMoConfigurationError("sources must not be empty")

        summed_weights = np.sum([source.target_ratio for source in self.sources]) # 校验总和

        if not np.allclose(summed_weights, 1.0):
            raise OLMoConfigurationError(f"target_ratios must sum to 1.0, got {summed_weights}")


@dataclass
class SourceTokenDetails: # 一个记账结构,保存source level信息
    """
    A class to hold intermediate selection details for a mixture source.
    """

    config: SourceMixtureConfig # 用于保存source的配置对象
    """
    The configuration object associated with the source.
    """
    population: int # 这个source实际拥有的token总数,值来自_count_tokens_for_paths()
    """
    The total number of tokens available for the source.
    """
    num_selected: int # 在最终dataset中,准备从这个source取多少token,计算方式是int(requested_tokens * target_ratio)
    """
    The number of tokens to select for the source.
    """

    def for_table(self, requested_tokens: int) -> Dict:
        return { # 实际上不参与算法,只用于打印信息
            "source_name": self.config.source_name,
            "source_population": f"{self.population}",
            "num_selected": f"{self.num_selected}",
            "target_ratio": str(self.config.target_ratio),
            "max_repetion_ratio": str(self.config.max_repetition_ratio),
            "max_source_fraction": str(self.config.max_source_fraction),
            "observed_source_ratio": f"{(self.num_selected / self.population):.4}", # 指出了使用了当前source多少比例的数据
            "observed_global_ratio": f"{(self.num_selected / requested_tokens):.4}", # 指出了当前source在最终dataset中的占比
        }


@dataclass
class SourcePathTokens: # 与上面相同,是记账结构用于保存信息,这里是path level信息
    path: str # 某个path下
    tokens: int # 最终允许使用多少token
    max_tokens: int # 最多能够提供多少token


@dataclass
class SourceMixtureOutcome: # 依然是一个记账结构,这里描述的source grouping信息
    name: str # 某个source
    """
    The name of the source.
    """
    path_tokens: List[SourcePathTokens] # 这个source对应哪些path,每个path使用多少token
    """
    A list of paths and the associated token counts.
    """


@dataclass
class SourceMixtureDataset:  # Note: "dataset" naming is a bit inconsistent with the rest of the codebase
    """
    A container for a fractionalized mixture of data sources. Do not construct directly,
    use :class:`SourceMixtureDatasetConfig` instead.

    See also :class:`~olmo_core.data.numpy_dataset.NumpyFSLDatasetMixture`, the downstream
    consumer of this dataset.

    这是最终给下游的数据结构,其本身主要提供两个结构:to_index()和to_paths()
    to_index()的输出是Dict[Tuple[str, int], int],形如:
    {
        (path1, 0): tokens1,
        (path2, 1): tokens2,
        ...
    }
    因为在允许重复采样的情况下(或是其他情况),一个path可能会出现多次,所以除了path以外还需要一个index来标识,
    所以是形如(path1, 0),因为可能path1出现了多次,所以需要区分(path1, 0)和(path1, 1)
    to_paths()的输出则是一个path列表List[PathOrStr],因为to_paths最后返回的是一个list,list本身就具有索引
    比如[path1, path2, path1, path3]这两个path1的索引一个是0一个是2,所以不会搞混。
    """

    sources: List[SourceMixtureOutcome]
    """
    A list of sources and their associated paths and token counts.
    """

    def to_index(self) -> Dict[Tuple[str, int], int]:
        """
        Convert the dataset to an indexed array of dict((int, path), int).
        这里最终的结果是形如:
        {
            ("pathA", 0): 500,
            ("pathB", 1): 300,
            ("pathC", 2): 200,
        }
        """
        return {
            (str(outcome.path), idx): outcome.tokens
            for idx, outcome in enumerate(
                list(chain.from_iterable([outcome.path_tokens for outcome in self.sources]))
            ) # 把path_tokens: List[SourcePathTokens]使用chain串起来再迭代
        }

    def to_paths(self) -> List[PathOrStr]:
        """
        Convert the dataset to a list of paths while maintaining stable ordering.
        """
        return [
            item.path
            for item in list(chain.from_iterable([outcome.path_tokens for outcome in self.sources]))
        ] # 与上面类似,只不过这里是list,并且这里只提供了path,没有提供该path能提供多少token,需要与上面一起用。


@dataclass
class SourceMixtureDatasetConfig(Config): # 核心在这里
    """
    Configuration for building a dataset from a fractionalized mixture of sources.

    This class manages the creation of training datasets by combining multiple data sources
    according to specified target ratios. It handles token counting, source selection,
    and ensures the final mixture meets the requested dataset size while maintaining
    the desired proportions across sources.

    The build process will:
    1. Count available tokens in each source
    2. Calculate token allocations based on target ratios
    3. Validate that sources have sufficient data
    4. Generate a mixture that respects repetition and fraction limits
    """

    source_list: SourceMixtureList # 输入多个source的集合,SourceMixtureList的资源是若干个SourceMixtureConfig组成的list
    """
    也就是说这里的组织方式是对于每个source都有一个SourceMixtureConfig描述
    若干个SourceMixtureConfig组成SourceMixtureList用来描述一组source
    在SourceMixtureDatasetConfig中构建数据
    A list of source configurations contained in a SourceMixtureList.
    """
    requested_tokens: int # 目标总token数
    """
    The desired dataset size, in tokens. This is used to determine the number of tokens to select from each source.
    The total dataset size will be greater than or equal to this value, depending on rounding.
    """
    global_batch_size: int # 按token计的global batch size
    """
    The global batch size for training, in tokens. Used to determine the total number of requested instances.
    """
    processes: int = 1 # 指定并行线程数
    """
    The number of processes to use for counting tokens in parallel.
    """
    seed: int = 42
    """
    The seed used to generate the dataset. Specifically this seed is used when sampling the actual
    instances to use from each source.
    """
    render_tables: bool = True
    """
    Whether to render tables of the mixture outcome.
    """
    quiet: bool = False

    def validate(self): # 验证参数
        if self.requested_tokens <= 0: # 总请求数不能小于0
            raise OLMoConfigurationError("requested_tokens must be > 0")
        self.source_list.validate() # 并且传入的SourceMixtureList验证其各source所占比例的总和也应该为1

    def build(self, *, npdtype: NumpyUIntTypes, sequence_length: int) -> SourceMixtureDataset: # 核心构建
        self.validate() # 上面的验证
        available_tokens_by_source: Dict[str, int] = {} # 每个source的可用token数
        # 打印日志
        log.info("---------------------------------------------------------")
        log.info("Generating a source mixture from configurations:")
        log.info(self.source_list.sources)

        # Count the number of tokens available for each source
        for source_config in self.source_list.sources: # 逐source,逐SourceMixtureConfig统计
            log.info(f"Counting tokens for source: {source_config.source_name}") # 打印当前source name
            available_tokens_by_source[source_config.source_name] = self._count_tokens_for_paths(
                # 计数当前Config下paths包含的tokens
                paths=cast(List[PathOrStr], source_config.resolved_paths),
                source=source_config.source_name,
                npdtype=npdtype,
            )
        # 接下来按target ratio,也就是每个source占全局的比例,计算每个source需要多少token
        tokens_details_by_source: List[SourceTokenDetails] = [] # 这里保存的是source level信息的记账结构
        max_tokens_cap_by_source: Dict[str, int] = {} # source config约束的所能提供的token的上限

        # Calculate the number of tokens available and to include for each source
        for source_config in self.source_list.sources: # 逐source,逐SourceMixtureConfig计算
            num_for_source = available_tokens_by_source[source_config.source_name] # 先把当前source的总数提出来
            needed_for_source = int(self.requested_tokens * source_config.target_ratio) # 计算需要在当前source取多少
            max_for_source = int( # 当前source的最大可贡献token数 = 总数 * config约定的所能提供的最大比例 * 最大重复次数
                (num_for_source * source_config.max_source_fraction)
                * source_config.max_repetition_ratio
            )

            # Ensure that the max tokens for a source meet the target ratio requirement
            if max_for_source < needed_for_source: # 检查当前source最大可贡献的token数大于在当前source取token的数量
                raise OLMoConfigurationError( # 不满足则raise
                    f"Insufficient tokens for source: {source_config.source_name} @ target global ratio: {source_config.target_ratio} :: {max_for_source} < {needed_for_source}"
                )

            max_tokens_cap_by_source[source_config.source_name] = max_for_source # 记录该source所能提供的token数的最大值

            tokens_details_by_source.append(
                SourceTokenDetails( # 在记账结构中记录当前source的config,当前source的总token数,需要在当前source取样的token数
                    config=source_config,
                    population=num_for_source,
                    num_selected=needed_for_source,
                )
            )
        # 接下来按照source-level token的配额分摊到该source下的每个path所指向的文件中
        completed: List[SourceMixtureOutcome] = [] # 记账结构,SourceMixtureOutcome保存的是SourcePathTokens信息
        # SourcePathTokens保存的是path所指向的文件最终使用的token数和最大token数
        tokens_per_path_per_source: Dict[str, List[SourcePathTokens]] = {} # 每个source下的每个path的token数
        for source in tokens_details_by_source: # 逐source处理
            source_path_tokens = self.get_paths_and_tokens_for_source(
                # 传的是config,记账结构和数据类型,返回的也是记账结构,每个结构包括path,使用的token数,最大token数,重复是通过重复记账结构表达的
                source_config=source.config, token_details=source, npdtype=npdtype
            )
            tokens_per_path_per_source[source.config.source_name] = source_path_tokens

        # We adjust the number of tokens per path so that we can complete the desired number
        # of training steps while still retaining the target ratios.
        # 接下来完成从token数到经过向sequence_length对齐的instance数的转换
        # ↓ 先计算训练总步数,这里表示要跑多少step才能至少覆盖请求token数
        training_steps = math.ceil(self.requested_tokens / self.global_batch_size)
        # 比如我们想训练总计1M tokens,global_batch_size设为了8192,那么就至少需要 1M/8192 向上取整个训练步
        assert ( # global_batch_size必须是sequence_length的整数倍,比如 sequence_length = 2048, global_batch_size = 8192
            self.global_batch_size % sequence_length == 0
        ), "global_batch_size must be multiple of sequence_length"
        num_instances_per_batch = self.global_batch_size // sequence_length # 接下来算每个batch需要多少个instance,比如 4 个
        requested_instances = training_steps * num_instances_per_batch # 再按照training steps算总的instance数量

        all_path_tokens: List[SourcePathTokens] = []
        for source_path_tokens in tokens_per_path_per_source.values(): # SourcePathTokens取出来放到list中
            all_path_tokens.extend(source_path_tokens)
        # 对每个path按照整数个instance做截断
        # Calculate base instances and remainders, respecting max_tokens constraint
        int_instances = []
        remainders = []
        for path_token in all_path_tokens: # 对每个path计算
            max_instances = path_token.max_tokens // sequence_length # 当前path最多能提供多少个instance,下取整
            desired_instances = path_token.tokens // sequence_length # 按之前的计算,当前path需要取多少个instance,下取整
            # Don't allocate more than max_instances
            base_instances = min(desired_instances, max_instances) # 取的instance的数量不应超过能提供的最大instance数量
            int_instances.append(base_instances) # 先取一次,把基准数量添加进去,用于下面检查还有哪些path能继续取instance

            # Only include remainder if we have capacity to add more
            if base_instances < max_instances: # 如果没取完,则计算remainder,这是后面Hamilton rounding要用的"小数余量"
                remainder = (path_token.tokens % sequence_length) / sequence_length
            else:
                remainder = 0.0
            remainders.append(remainder) # 这样每个path都持有了对应的instance数量的小数

        # Apply Hamilton's method for rounding, but respect capacity constraints
        # https://mathematics-democracy-institute.org/apportionment/
        # 本质上是在做一种按最大余数分配剩余名额的离散化方法
        # 因为之前分配的path-level token预算除以sequence_length大多不是整数
        # 比如path1 -> 3.8 instances path2 -> 5.2 instances path3 -> 1.6 instances
        # 其小数部分显然无法直接使用,最自然的做法是先取整数部分[path1:3, path 2:5, path 3:1]
        # 再按余数的大小把剩余的instance额度分配给余数更大的路径:0.8 > 0.6 > 0.2,就从path1开始分配instance额度

        additional_instances_needed = requested_instances - sum(int_instances) # 计算第一次取完后还需要取多少instance才满足需求
        if additional_instances_needed > 0: # 还需要继续取
            # Find indices that have capacity for more instances
            # 先找还能取的path
            eligible_indices = [
                idx
                for idx in range(len(int_instances)) # 逐path验证,如果已经取出来的部分小于path最大能提供(下取整)的就还能继续取
                if int_instances[idx] < all_path_tokens[idx].max_tokens // sequence_length
            ]

            if eligible_indices: # 非空的话则继续尝试取
                # Distribute base amount evenly among eligible paths
                # 先均匀的取,再按remainder分配leftover
                base, leftover = divmod(additional_instances_needed, len(eligible_indices)) # divmod计算商(下取整)和余数
                # 打个比方还需要取 12 个 instance,而 eligible_indices 有 5 个,那么 base = 2, leftover = 2
                if base: # 大于0
                    for idx in eligible_indices:
                        max_instances = all_path_tokens[idx].max_tokens // sequence_length # 先算最大能提供的
                        # Add base amount but don't exceed max capacity
                        can_add = min(base, max_instances - int_instances[idx]) # 因为一开始取了一次,最大能提供的要去掉基准值
                        # 取基准值和能提供值中更小的那个,添加到实际取的数量中
                        int_instances[idx] += can_add

                # Recalculate how many we still need after base distribution
                # 再算一下还需要取多少
                additional_instances_needed = requested_instances - sum(int_instances)

                # Distribute remaining by largest remainders (Hamilton's method)
                if additional_instances_needed > 0: # 如果还需要继续取
                    # Only consider paths that still have capacity
                    candidates_with_remainders = [ # 先检查还有哪些path能继续取
                        (remainders[idx], idx)
                        for idx in eligible_indices # 因为上面int_instances[idx]更新过了,这里继续检查就行
                        if int_instances[idx] < all_path_tokens[idx].max_tokens // sequence_length
                    ]
                    candidates_with_remainders.sort(reverse=True) # 从大到小排序

                    for _, idx in candidates_with_remainders[:additional_instances_needed]: # 按余数从大到小取前 additional_instances_needed 个
                        int_instances[idx] += 1 # 被选中的 path 各多取一个 instance

        final_tokens_per_path = [inst * sequence_length for inst in int_instances] # 对于每个path,重新转回token数量

        # Update the path_token objects with final token counts
        # 接下来生成最终结果并计算误差
        idx = 0
        for source_path_tokens in tokens_per_path_per_source.values():
            for path_token in source_path_tokens: # 更新每个path下使用的token数量
                path_token.tokens = final_tokens_per_path[idx]
                idx += 1

        final_token_distribution: Dict[str, float] = {}
        for source_name, source_path_tokens in tokens_per_path_per_source.items():
            completed.append( # 按source更新最终结果
                SourceMixtureOutcome(
                    name=source_name, # 每个source下面都对应若干source_path_tokens,source_path_tokens在上面被更新过了
                    path_tokens=source_path_tokens,
                )
            ) # ↓ 记录每个source使用的token数量
            final_token_distribution[source_name] = sum(path.tokens for path in source_path_tokens)

        total_tokens = sum(final_token_distribution.values()) # 对所有source求和就可以计算出总共使用了多少tokens
        final_token_distribution = { # 更新成source name : 比例 的样式,用来描述source的分布
            k: v / total_tokens for k, v in final_token_distribution.items()
        }

        if self.render_tables: # 按需绘制表格,绘制的是两张表,一张是source-level,另一张是global-level
            self.render_mixture_outcome_tables(tokens_details_by_source)

        for outcome in completed:
            for item in outcome.path_tokens: # 打印日志
                log.info(f"Selected {item.tokens} tokens from {outcome.name} at {item.path}")

        token_difference = total_tokens - self.requested_tokens # 计算多取了多少
        percent_difference = (token_difference / self.requested_tokens) * 100 # 计算多取的百分比
        log.info( # 打印日志
            f"Total tokens in mixture: {total_tokens} "
            f"(requested: {self.requested_tokens}, diff: {token_difference:+} tokens, "
            f"{percent_difference:+.2f}%)"
        )

        original_token_distribution = { # 最初的source分布,具体为一个 source name : 比例 的dict
            source_config.source_name: source_config.target_ratio
            for source_config in self.source_list.sources
        }
        for source_name, ratio in original_token_distribution.items():
            diff = np.abs(final_token_distribution.get(source_name, 0) - ratio) # 计算分布误差
            log.info(f"{source_name}: {diff:.4f} difference from target ratio {ratio:.4f}") # 打印日志

        return SourceMixtureDataset(sources=completed) # 之后就可以to_index和to_path喂给下游了,并且都向sequence length对齐过了

    def get_paths_and_tokens_for_source(
        self,
        source_config: SourceMixtureConfig,
        token_details: SourceTokenDetails,
        npdtype: NumpyUIntTypes,
    ) -> List[SourcePathTokens]:
        """
        Get the paths and resulting token count for a source.
        """
        take_ratio = token_details.num_selected / token_details.population # 当前source的取样数 / 当前source的总token数 = 取样比例
        path_tokens: List[SourcePathTokens] = [] # 保存结果的缓存

        resolved_paths = source_config.resolved_paths # 获取source下面对应的paths
        token_counts_by_path = { # 构建 path : 总token数 的字典
            path: self._count_tokens_for_file(path, npdtype) for path in resolved_paths
        }

        # When we need more than 1 repetition of the source data we have a take ration > 1
        if take_ratio > 1: # 取样比例大于1证明需要重复
            take_ratios = []
            remaining = take_ratio
            # 打个比方 take_ratio = 2.3,最后会变成 take_ratios = [1.0, 1.0, 0.3]
            while remaining > 0: # 如果还需要继续取
                chunk = min(1.0, remaining) # 切出1份
                take_ratios.append(chunk) # 添加到list中
                remaining -= chunk # 去除被切掉chunk

            for ratio in take_ratios: # 正式开始按比例取
                for path in resolved_paths: # 对每个path都使用相同的比例
                    available_tokens = token_counts_by_path[path] # 取出当前path的总token数(也是最大token数)
                    tokens_to_keep = int(math.ceil(available_tokens * ratio)) # 向上取整
                    path_tokens.append( # 添加记账结构,记录当前path,当前path最终使用的token数,当前path的总token数
                        SourcePathTokens(
                            path=path,
                            tokens=tokens_to_keep,
                            max_tokens=available_tokens,
                        )
                    )

            return path_tokens

        for path in resolved_paths: # 小于1的情况则不需要重复,直接按比例取就可以了
            available_tokens = token_counts_by_path[path]
            tokens_to_keep = int(math.ceil(available_tokens * take_ratio)) # 同样的按比例计算取多少向上取整
            path_tokens.append(
                SourcePathTokens(
                    path=path,
                    tokens=tokens_to_keep,
                    max_tokens=available_tokens,
                )
            )

        return path_tokens

    def _count_tokens_for_paths(
        self, paths: List[PathOrStr], source: Optional[str], npdtype: NumpyUIntTypes
    ) -> int:
        """
        Count the number of tokens for a set of source files in parallel.

        Args:
            paths: The source file paths to count tokens for.
            source: The source name, used only for progress reporting.
            npdtype: The numpy dtype of the source tokens.
        """

        with ThreadPoolExecutor(max_workers=self.processes) as executor: # 按指定线程数初始化executor
            futures = []
            for path in paths: # 一个source下有若干paths,每个path对应一个.npy格式的文件
                futures.append( # 每个executor都根据文件进行计数
                    executor.submit(self._count_tokens_for_file, path=path, npdtype=npdtype)
                )

            with Progress(disable=self.quiet) as progress: # rich的进度条
                results = []
                task = progress.add_task(
                    f"Counting available tokens for source: {source}", total=len(futures)
                )
                for future in as_completed(futures):
                    progress.update(task, advance=1)
                    results.append(future.result())

            return sum(results) # 都执行完后一求和就知道当前SourceMixtureConfig所指向的paths总共包含多少tokens了

    def _count_tokens_for_file(self, path: PathOrStr, npdtype: NumpyUIntTypes) -> int:
        return self._bytes_to_tokens(get_file_size(path), npdtype=npdtype) # 获取文件大小,按数据格式和文件大小计算token数

    def _bytes_to_tokens(self, num_bytes: int, npdtype: NumpyUIntTypes) -> int:
        """
        Convert bytes to tokens based on the dtype.
        """
        return num_bytes // npdtype(int(0)).itemsize #.npy存的就是定长整数token ID,所以直接一除就得到token数量了

    def render_mixture_outcome_tables(self, results: List[SourceTokenDetails]) -> None:
        """
        两张表,一张是source-level的表,一张是global-level的表
        Render tables enumerating the global and per-source mixture outcomes.
        """
        source_rows = [item.for_table(self.requested_tokens) for item in results] # for_table只需要传requested_tokens
        source_headers = source_rows[0].keys() # key取出来作为表头

        source_table = Table(title="Outcome by source") # rich的table
        for header in source_headers:
            source_table.add_column(header)

        for row in source_rows:
            source_table.add_row(*[row[header] for header in source_headers])

        log.info(self.table_to_text(source_table))

        total_tokens = sum([item.population for item in results])
        selected_tokens = sum([item.num_selected for item in results])
        observed_global_ratio = f"{(selected_tokens / total_tokens):.4}"

        global_table = Table(title="Global outcome")
        global_headers = [
            "total_tokens",
            "selected_tokens",
            "observed_global_ratio",
        ]

        for header in global_headers:
            global_table.add_column(header)

        global_table.add_row(f"{total_tokens:.2e}", f"{selected_tokens:.2e}", observed_global_ratio)
        log.info(self.table_to_text(global_table))

    def table_to_text(self, table: Table) -> Text:
        """Generate an ascii formatted presentation of a Rich table
        Eliminates column styling
        """
        console = Console(width=250)
        with console.capture() as capture:
            table.width = 250
            console.print(table)

        return Text.from_ansi(capture.get())

5.numpy_dataset.py:从 token 配额到训练实例

numpy_dataset.py 的核心任务,是将上一章 source_mixture.py 计算好的 token 配额转化为可直接供模型训练的样本。换句话说,它负责把分配好的 token 数据按训练要求切分成训练实例(instance),最终交给下游的模型进行训练。如果说 source_mixture.py 是"预算规划层",那么 numpy_dataset.py 就是"样本物化层":它真正地从磁盘读取 token 序列,并按照固定长度(FSL)或其他策略构造出一个个 instance。

5.1 训练实例的基本构造

在讨论 numpy_dataset.py 之前,先回顾几个与训练实例相关的基本概念,这些概念会直接影响到数据如何被切分和组织:

  1. Instance(样本/实例):训练数据的最小单位,通常长度固定为 sequence_length(例如 2048 个 token)。
  2. Micro-batch:单个设备上一次前向/反向传播实际处理的 instance 数量,称为 micro_batch_size
  3. Global batch size:一次参数更新(optimizer.step())所覆盖的全局 instance 总数,由 micro-batch、数据并行度(data_parallel_size)和梯度累积步数(gradient_accumulation_steps)共同决定:global_batch_size = micro_batch_size × data_parallel_size × gradient_accumulation_steps

从 token 视角看,global_batch_size × sequence_length 就是一次更新消费的 token 总数。这些超参数决定了训练循环对数据的需求粒度,而 numpy_dataset.py 的工作就是按照这个粒度,从底层的 token 文件中取出固定长度的 chunk,作为 instance 返回。
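上面的换算关系可以用一小段示意代码验证(其中各超参数取值仅为演示假设,并非 OLMo 3 的实际配置):

```python
# 示例超参数(仅为演示假设)
micro_batch_size = 4              # 单设备单次前向/反向处理的 instance 数
data_parallel_size = 8            # 数据并行度
gradient_accumulation_steps = 4   # 梯度累积步数
sequence_length = 2048            # 每个 instance 的固定 token 长度

# 一次参数更新(optimizer.step())覆盖的全局 instance 总数
global_batch_size = micro_batch_size * data_parallel_size * gradient_accumulation_steps
# 一次参数更新消费的 token 总数
tokens_per_step = global_batch_size * sequence_length

print(global_batch_size)  # 128
print(tokens_per_step)    # 262144
```

可以看到,调整任何一个因子(比如把梯度累积翻倍)都会等比例改变一次更新消费的 token 数,这正是数据层需要按粒度供给数据的原因。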

5.2 从 token 配额到实例的转换

source_mixture.py 里,系统已经知道了每个 source 和 path 应该贡献多少 token(即 path-level token budget)。然而,模型并不能直接消费这些配额,还需要将它们拆分成具体的 instance。

这个转换过程主要回答两个问题:

  • 每个 instance 长什么样:它由多少个 token 组成?在预训练中通常是固定的 sequence_length
  • 如何将 token 配额切分成实例:给定一个 path 的 token 配额(例如 262144 个 token),如何得到整数个固定长度的 instance(例如 262144 / 2048 = 128 个)?如果配额不是 sequence_length 的整数倍,余数部分如何处理(丢弃或保留)?

在FSL策略中,不同的 dataset 实现给出了不同的答案,但总体思路是一致的:将每个 path 视为一个逻辑上的 token 流,按固定长度切割,得到一系列 chunk,每个 chunk 就是一个 instance。对于 NumpyFSLDataset,这是通过文件真实大小和 sequence_length 直接计算的;对于 NumpyFSLDatasetMixture,则是用 source_mixture.py 提供的配额代替真实文件大小。
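这个换算可以写成一个极简示意函数(quota_to_instances 是本文虚构的演示函数,与源码实现无关;FSL 策略下余数 token 直接丢弃):

```python
def quota_to_instances(token_quota: int, sequence_length: int) -> int:
    """FSL 策略下,一个 path 的 token 配额能切出多少个定长 instance。
    不足 sequence_length 的余数部分被直接丢弃。"""
    return token_quota // sequence_length

print(quota_to_instances(262144, 2048))  # 128,恰好整除,没有浪费
print(quota_to_instances(262200, 2048))  # 仍是 128,余下的 56 个 token 被丢弃
```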

尽管FSL并不是预训练阶段的唯一选择,但是我们目前先着眼于FSL。

5.3 处理多个数据源

在大规模训练中,通常会混合来自不同 source 的数据,这些数据可能具有不同的大小、结构和分布。numpy_dataset.py 通过以下方式处理多源数据:

  • 多 source 合并:不同 source 的数据会被合并成一个大的逻辑数据池。对于 NumpyFSLDatasetMixture,它接收 SourceMixtureDataset.to_paths() 返回的路径列表(可能包含重复的 path),并通过 path_offset_index 指定每个 (path, idx) 条目在本轮可贡献的 token 配额。这个配额会影响该 path 在全局 FSL 索引空间中能提供多少个 fixed-length chunk。
  • 路径管理:每个 source 对应多个数据路径,这些路径中的数据会按配额进行读取和切分。对于 NumpyFSLDatasetMixture,它重写了 _get_file_size_and_length,用虚拟的文件大小(配额)替代真实文件大小,从而让父类的索引计算逻辑透明地适配混合预算。

5.4 numpy_dataset.py 的核心数据结构

numpy_dataset.py 提供了五种不同的 dataset 实现,它们分别对应五种不同的 document-to-instance 构造策略。所有实现都继承自 NumpyDatasetBase(FSL 系列经由中间基类 NumpyFSLDatasetBase),并实现了 __len__ 和 __getitem__ 方法。

5.4.1 NumpyFSLDataset

这是最基础、吞吐优先的实现。它将所有输入路径中的 document 视为一个连续的 token 流,不显式考虑文档边界,直接按照 sequence_length 切分为定长 chunk。每个 chunk 就是一个 instance。

  • 核心逻辑:为每个文件单独计算可提供的 chunk 数(file_size // (dtype_size * sequence_length),向下取整),再将这些 chunk 按文件顺序串联成一个全局 instance 索引空间。__getitem__ 通过全局索引找到对应的文件和局部 chunk 索引,然后从文件中读取 [start, start + sequence_length) 这个 slice。
  • 适用场景:追求最高吞吐,不关心文档边界,适合绝大部分预训练场景。
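这条"全局索引 → (文件, 局部 chunk 索引)"的映射可以用一个简化示意来说明(build_offsets / locate 均为演示用的虚构函数,各文件的 instance 数也仅为演示数字):

```python
from typing import List, Tuple

def build_offsets(instances_per_file: List[int]) -> List[Tuple[int, int]]:
    """为每个文件计算其在全局 instance 索引空间中的 [start, end) 区间。"""
    offsets, start = [], 0
    for n in instances_per_file:
        offsets.append((start, start + n))
        start += n
    return offsets

def locate(index: int, offsets: List[Tuple[int, int]]) -> Tuple[int, int]:
    """把全局 instance 索引映射为 (文件序号, 文件内局部 chunk 索引)。"""
    for file_idx, (start, end) in enumerate(offsets):
        if start <= index < end:
            return file_idx, index - start
    raise IndexError(index)

offsets = build_offsets([100, 50, 70])  # [(0, 100), (100, 150), (150, 220)]
print(locate(120, offsets))             # (1, 20):第 1 个文件里的第 20 个 chunk
```

真实实现中,文件的 instance 区间是由文件字节数推算出来的,并会被缓存起来,但索引换算的思路与此一致。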

5.4.2 NumpyFSLDatasetMixture

它在 FSL 的基础上接入了来自 SourceMixtureDatasetConfig 的 token 配额控制。与普通 NumpyFSLDataset 不同,这一实现不再单纯按文件真实大小决定某个 path 能提供多少个 chunk,而是先接受上游 mixture 规划好的预算,再据此限制各 path 在本轮可提供的 fixed-length instances。

  • 核心逻辑:继承 NumpyFSLDataset,但重写了 _get_file_size_and_length,用 _path_offset_index 中记录的 token 配额替代真实文件大小。这样,每个 path 能提供的 chunk 数就不再由磁盘上的实际 token 数决定,而是由 source mixture 的预算决定。它同时还会生成 document 级别的索引缓存(segment_documents_into_instances),但目前读取路径仍复用父类的 _read_chunk_from_array,因此实际 instance 仍是固定长度的 chunk。
  • 适用场景:需要精确控制每个 source 的 token 比例和重复次数,且希望复用 FSL 的高效索引机制。
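"用虚拟文件大小替代真实文件大小"的思路可以抽象为如下示意(FSLBase / FSLMixture 等类名与方法签名均为本文演示假设,并非源码实现):

```python
class FSLBase:
    def __init__(self, file_sizes, sequence_length, dtype_size=2):
        self.file_sizes = file_sizes          # path -> 磁盘上的真实字节数
        self.sequence_length = sequence_length
        self.dtype_size = dtype_size          # 如 np.uint16 -> 每 token 2 字节

    def _get_file_size(self, path):
        return self.file_sizes[path]

    def num_instances(self, path):
        # "父类"逻辑:文件字节数 -> 可切出的定长 chunk 数
        return self._get_file_size(path) // (self.dtype_size * self.sequence_length)

class FSLMixture(FSLBase):
    def __init__(self, file_sizes, sequence_length, token_budgets, dtype_size=2):
        super().__init__(file_sizes, sequence_length, dtype_size)
        self.token_budgets = token_budgets    # path -> mixture 规划的 token 配额

    def _get_file_size(self, path):
        # 用"配额对应的字节数"替代真实文件大小,父类的索引计算逻辑无需任何改动
        return self.token_budgets[path] * self.dtype_size

# 文件实际有 1000 万字节,但配额只给了 262144 个 token
ds = FSLMixture({"a.npy": 10_000_000}, 2048, {"a.npy": 262_144})
print(ds.num_instances("a.npy"))  # 128:由配额而非真实文件大小决定
```

这正是"重写 _get_file_size_and_length"这类做法的好处:子类只需替换一个"大小来源",父类的全局索引机制就能透明地适配混合预算。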

5.4.3 NumpyPaddedFSLDataset

它的策略是"一个 document 对应一个 instance"。每个 document 被单独构造成一个定长序列,不足部分通过 padding 补齐,超过长度则进行截断。

  • 核心逻辑:在 prepare 中调用 segment_documents_into_instances,为每个 document 生成 (start, end) 边界索引并缓存。__getitem__ 时,先通过索引文件找到 document 的边界,再读取对应的 token slice,最后进行 padding 对齐到 sequence_length
  • 适用场景:需要保持文档完整性,但对吞吐要求不高(padding 较多)。
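"一个 document 一个 instance"的截断/补齐逻辑大致如下(示意代码,doc_to_instance 为演示用的虚构函数,token 值与 pad_token_id 均为假设):

```python
from typing import List

def doc_to_instance(doc_tokens: List[int], sequence_length: int, pad_token_id: int) -> List[int]:
    """超长则截断到 sequence_length,不足则用 pad_token_id 补齐。"""
    truncated = doc_tokens[:sequence_length]
    return truncated + [pad_token_id] * (sequence_length - len(truncated))

print(doc_to_instance([1, 2, 3], 6, 0))              # [1, 2, 3, 0, 0, 0]
print(doc_to_instance([1, 2, 3, 4, 5, 6, 7], 6, 0))  # [1, 2, 3, 4, 5, 6]
```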

5.4.4 NumpyPackedFSLDataset

它采用 Optimized Best-Fit Decreasing (OBFD) 算法,将多个 document 打包进一个 fixed-length instance 中,以尽量减少 padding 和 truncation。

  • 核心逻辑:将多个 document 按长度分组,用 OBFD 算法决定哪些 document 应该放入同一个 instance,然后生成三份缓存文件:document-indices(记录每个 document 的边界)、instance-offsets(记录每个 instance 在文档列表中的切片区间)、docs-by-instance(扁平化的文档 ID 列表)。__getitem__ 时,通过这三层索引还原出当前 instance 包含的所有 document 的 token span,拼接后补齐到 sequence_length。
  • 适用场景:希望兼顾 token 利用率和文档边界,适合需要减少 padding 浪费的训练。

这里我梳理一下后面代码中的逻辑,防止具体阅读时混乱:

复制代码
NumpyPackedFSLDataset 的核心目标是:给定一组 path 指向的 token 文件,
如何将其中的文档(document)重新组织成固定长度的 instance,使得 padding 最少。
这是一个典型的"装箱"问题:箱子容量为 sequence_length,货物是各个文档的 token 长度,
我们要用最少的箱子(即最少的 padding)装下所有文档。

首先需要处理文档可能超过 sequence_length 的情况。有两种策略:

截断(truncate):只取文档的前 sequence_length 个 token,丢弃后面的部分。
切分(fragment):将长文档按 sequence_length 切成多个片段,每个片段视为一个独立文档。

经过处理后,我们得到一个文档列表,每个文档由其起始和结束位置表示:
[start_0, end_0, start_1, end_1, ...]。可以将其 reshape 为:

text
[
  [start_0, end_0],  # 文档 0
  [start_1, end_1],  # 文档 1
  [start_2, end_2],  # 文档 2
  ...
]

现在问题转化为:有一批"货物"(文档),每个货物的"体积"为 end_i - start_i,
需要放入若干个容量为 sequence_length 的"箱子"(instance)中,并尽可能减少箱子数量(即减少 padding)。
OLMo 3 使用了一种基于线段树的 Optimized Best-Fit Decreasing(OBFD)算法来实现这一打包过程。

打包算法 pack_documents_into_instances 输入文档边界数组,输出三个结果:

instances:一个列表,每个元素是一个文档 ID 列表,表示该 instance 包含哪些文档。例如:

instances = [[0, 3], [1], [2, 4, 5]]
表示 instance 0 由文档 0 和文档 3 组成,
instance 1 由文档 1 组成,
instance 2 由文档 2、4、5 组成。

document_indices:一个二维数组,记录每个文档在原始 token 流中的实际边界。
注意这里的文档 ID 与上面的 instances 中的 ID 一一对应。例如:

document_indices = [
  [260, 400],  # doc_id = 0
  [150, 260],  # doc_id = 1
  [0, 100],    # doc_id = 2
  [100, 150],  # doc_id = 3
  [400, 450],  # doc_id = 4
  [450, 500],  # doc_id = 5
]

这里 document_indices[0] 就是文档 0 的起止位置,document_indices[3] 是文档 3 的起止位置,依此类推。

total_tokens:打包前所有文档的总 token 数。

由于 NumPy 不支持二维数组直接存储 instances 这种变长结构,
需要将 instances 编码为两个一维数组,便于缓存到磁盘:

docs_by_instance:将 instances 展平为一个长数组,即按顺序列出每个 instance 包含的文档 ID。
对于上面的例子:docs_by_instance = [0, 3, 1, 2, 4, 5]

instance_offsets:记录每个 instance 在 docs_by_instance 中的区间范围。
每两个元素表示一个区间 [start, end)。
对于上面的例子:instance_offsets = [0, 2, 2, 3, 3, 6]

这意味着:
instance 0 对应 docs_by_instance[0:2] → [0, 3]
instance 1 对应 docs_by_instance[2:3] → [1]
instance 2 对应 docs_by_instance[3:6] → [2, 4, 5]

这样,在 __getitem__ 中就可以通过 instance_offsets 和 docs_by_instance 
还原出 instance 包含的文档 ID,再通过 document_indices 得到每个文档的全局边界,
最后定位到具体的文件和文件内偏移,读取 token 数据并拼接成最终的 instance。

具体恢复过程如下:
给定 instance 索引 i,从 instance_offsets 中取出区间:

start, end = instance_offsets[2*i : 2*i+2]
从 docs_by_instance 中取出该区间对应的文档 ID 列表:

doc_ids = docs_by_instance[start:end]
对每个 doc_id,通过 document_indices[doc_id] 得到文档在全局 token 流中的起止位置。

由于全局 token 流是多个文件首尾相连构成的,需要判断该文档属于哪个文件。
可以通过比较全局边界与每个文件的起始偏移来确定。
确定文件后,计算文档在文件内的偏移(全局边界减去文件起始偏移),
然后从对应文件中读取 token slice。
将所有文档的 token 拼接,如果总长度不足 sequence_length,则用 pad_token_id 填充。
这套机制使得 NumpyPackedFSLDataset 能够在 instance 内部高效复用文档,
同时将 packing 结果持久化,避免重复计算。
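上述三层索引的还原过程,可以直接用前面的示例数组写成可运行的演示代码(docs_of_instance 为演示用的虚构函数,不是源码接口):

```python
docs_by_instance = [0, 3, 1, 2, 4, 5]
instance_offsets = [0, 2, 2, 3, 3, 6]
document_indices = [
    [260, 400],  # doc_id = 0
    [150, 260],  # doc_id = 1
    [0, 100],    # doc_id = 2
    [100, 150],  # doc_id = 3
    [400, 450],  # doc_id = 4
    [450, 500],  # doc_id = 5
]

def docs_of_instance(i):
    """给定 instance 索引,还原其包含的文档 ID 以及各文档的全局边界。"""
    start, end = instance_offsets[2 * i], instance_offsets[2 * i + 1]
    doc_ids = docs_by_instance[start:end]
    return [(d, tuple(document_indices[d])) for d in doc_ids]

print(docs_of_instance(0))  # [(0, (260, 400)), (3, (100, 150))]
print(docs_of_instance(2))  # [(2, (0, 100)), (4, (400, 450)), (5, (450, 500))]
```

拿到每个文档的全局边界后,剩下的工作就是按文件偏移读出 token slice、拼接并 padding,与正文描述一致。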

5.4.5 NumpyInterleavedFSLDataset

它继承自 NumpyPaddedFSLDataset,但进一步将多个 document 的 chunk 进行 token 级别的交错(interleave),从而构造出跨文档深度混合的 instance。

  • 核心逻辑:先按 NumpyPaddedFSLDataset 的方式为每个 document 生成一个 instance,然后根据配置将某些 document 标记为 exempt(不参与 interleaving)或 interleavable。对于 interleavable 的 document,会随机打乱顺序,并按每 docs_per_instance 个 document 分为一组。在每个组内,从每个 document 中截取 sequence_length / docs_per_instance 个 token,然后按 chunk 级交错(round-robin)拼接成最终的 instance,最后再添加 BOS/EOS 和 padding。

  • 适用场景:希望降低模型对文档边界偏置的敏感性,让实例看起来更像 token-level 的流式数据。

    doc1: [a1 a2 a3 a4]
    doc2: [b1 b2 b3 b4]
    doc3: [c1 c2 c3 c4]
    chunk → interleave:
    [a1 b1 c1 a2 b2 c2 a3 b3 c3 ...]
    这里的a1 a2 a3 a4都是一个个的token chunk,是块级别的拼接
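上面的交错过程可以用一个极简示意函数复现(interleave 为演示用的虚构实现;真实实现还要处理随机打乱、exempt 文档、BOS/EOS 与 padding 等细节):

```python
def interleave(docs, chunk_size):
    """把每个 doc 切成 chunk_size 大小的块,再按 round-robin 交错拼接。"""
    chunked = [[d[i:i + chunk_size] for i in range(0, len(d), chunk_size)] for d in docs]
    out = []
    for group in zip(*chunked):  # 每一轮从每个 doc 各取一个 chunk
        for chunk in group:
            out.extend(chunk)
    return out

docs = [["a1", "a2", "a3", "a4"], ["b1", "b2", "b3", "b4"], ["c1", "c2", "c3", "c4"]]
print(interleave(docs, 1))
# ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'a3', 'b3', 'c3', 'a4', 'b4', 'c4']
print(interleave(docs, 2))
# ['a1', 'a2', 'b1', 'b2', 'c1', 'c2', 'a3', 'a4', 'b3', 'b4', 'c3', 'c4']
```

chunk_size 越小,文档之间混合得越深;chunk_size 等于文档长度时则退化为简单拼接。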

5.5 小结

numpy_dataset.py 是数据层从"预算"到"实例"的关键转换点。它不再关心数据从哪里来、预算如何分配,而是专注于如何从 token 文件中高效地取出固定长度的 chunk,并以不同方式组织成 instance。通过提供五种不同的 dataset 实现,它在吞吐、文档完整性、token 利用率之间提供了灵活的权衡,为预训练循环准备了可以直接消费的样本。

5.6 详细注释

详细内容见代码注释:

python 复制代码
from __future__ import annotations

import concurrent.futures # layer 1中提到的并发执行框架,Future一个"未来才会得到结果"的对象,也就是前面提到的future对象
import hashlib
import logging
import math
import os
import random
import tempfile
from abc import ABC, abstractmethod # 抽象类和抽象方法
from copy import deepcopy
from dataclasses import dataclass
from functools import partial # 偏函数,也就是固定一个函数的若干参数以得到一个新函数,这样就可以避免反复填参数
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset # 只要实现了__len__()和__getitem__()就可以被DataLoader当成数据集来迭代,此处的类大多都在实现这个

from olmo_core.exceptions import OLMoConfigurationError, OLMoEnvironmentError

from ..aliases import PathOrStr
from ..config import Config, StrEnum
from ..distributed.utils import barrier, get_fs_local_rank # 分布式的同步屏障与file system的local rank,后面看到分布式具体讲
from ..io import (
    _get_s3_client,
    deterministic_glob_directory,
    get_file_size,
    is_url,
    normalize_path,
)
from .mixes import DataMix, DataMixBase
from .source_mixture import SourceMixtureDatasetConfig # 详见上一个文件
from .tokenizer import TokenizerConfig
from .types import LongDocStrategy, NumpyDatasetDType, NumpyUIntTypes
from .utils import ( # 这里用到的工具函数我放到附录中,功能就是函数名,具体实现按需查阅
    bucket_documents,
    chunk_array,
    chunked,
    divide_into_buckets,
    find_periodic_sequences,
    get_doc_lengths_from_indices,
    get_document_lengths,
    get_rng,
    load_array_slice_into_tensor,
    memmap_to_write,
    pack_documents_into_instances,
    run_worker_func,
    segment_documents_into_instances,
    write_array_to_disk,
)

__all__ = [
    "NumpyDatasetBase",
    "NumpyFSLDatasetBase",
    "NumpyFSLDataset",
    "NumpyFSLDatasetMixture",
    "NumpyPaddedFSLDataset",
    "NumpyPackedFSLDataset",
    "NumpyInterleavedFSLDataset",
    "VSLCurriculum",
    "VSLNaturalCurriculum",
    "VSLGrowthCurriculum",
    "VSLGrowP2Curriculum",
    "VSLGrowLinearCurriculum",
    "NumpyVSLDataset",
    "NumpyDatasetConfig",
    "NumpyFSLDatasetConfig",
    "NumpyPaddedFSLDatasetConfig",
    "NumpyPackedFSLDatasetConfig",
    "NumpyInterleavedFSLDatasetConfig",
    "NumpyVSLDatasetConfig",
    "VSLCurriculumType",
    "VSLCurriculumConfig",
]


log = logging.getLogger(__name__)


T = TypeVar("T")


@dataclass
class InstanceFilterConfig(Config): # instance级别的过滤规则参数,用于检测样本中的重复以过滤坏样本,原因可能是数据清洗不彻底、packing等问题
    repetition_max_period: int = 13 # 最大周期长度 超过这个长度的模式就不再检查
    repetition_min_period: int = 1 # 最小周期长度 设为1表示形如AAAAA的重复也要检查,设为2的话就是从ABABABAB开始检查
    repetition_max_count: int = 32 # 允许重复多少次,达到这个次数就认为是坏样本


class NumpyDatasetBase(ABC):
    """
    An abstract base class for datasets backed by numpy arrays on disk of token IDs.

    In general the instances that these datasets produce are sequences of token IDs from one
    or more numpy arrays, sometimes with additional metadata attached.
    The way those instances are formed depends on the implementation details of the subclass.

    .. warning::
        When using :class:`NumpyDatasetBase` implementations in a distributed setting be sure
        that the :data:`work_dir` is shared among all local ranks and :data:`fs_local_rank` is set
        accordingly. Once those fields are set you should then call :meth:`prepare()` in the
        main process before doing anything else.

    .. tip::
        Use the dataset config helpers (e.g. :class:`NumpyFSLDatasetConfig`) to configure and
        construct datasets instead of constructing them directly.
    """

    def __init__( # 接收并保存底层的公共状态
        self,
        *paths: PathOrStr, # 通过 * 接收可变数量的参数,并强制后面的参数为关键字传参
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        dtype: NumpyUIntTypes = np.uint16,
        bos_token_id: Optional[int] = None,
    ):
        if not paths:
            raise OLMoConfigurationError("At least one path is required")

        self._array_paths = tuple(paths) # 用于收集参数的paths本身就是tuple,这么写可能是明确类型或是防御性编程
        self._pad_token_id = pad_token_id
        self._eos_token_id = eos_token_id
        self._bos_token_id = bos_token_id
        self._vocab_size = vocab_size
        self._dtype = dtype
        self._fs_local_rank = get_fs_local_rank()
        self._work_dir: Optional[Path] = None # 可选的临时工作路径
        self._work_dir_set = False # 明确这个临时工作路径是否是用户配置过的
        self._array_file_sizes: Optional[Tuple[int, ...]] = None # 用于缓存文件大小的lazy cache

    @property
    @abstractmethod
    def max_sequence_length(self) -> int: # 在最终实例化之前必须实现的抽象方法,注意不是返回当前样本长度,而是控制当前dataset生成instance的最大长度
        """
        The maximum sequence length of any instances generated by this dataset.
        """
        raise NotImplementedError

    @property
    def paths(self) -> Tuple[PathOrStr, ...]: # 私有字段封装,把"内部存储"与"对外语义接口"分开,对外提供稳定接口、逻辑插入、验证、防修改等
        """
        Paths and/or URLs to the numpy arrays.
        """
        return self._array_paths # 作为索引主线

    @property
    def file_sizes(self) -> Tuple[int, ...]: # 私有字段封装,与lazy cache
        """
        The size, in bytes, of each numpy array.
        """
        if self._array_file_sizes is None: # 未缓存则缓存, lambda匿名函数,map的作用是对每个path调用一个func,且约定func需要两个参数
            self._array_file_sizes = tuple(self.map(lambda path, _: get_file_size(path))) # 所以这里除了path之外还要用 _ 接收索引参数(仅占位,不使用)
        return self._array_file_sizes

    @property
    def pad_token_id(self) -> int: # 私有字段封装
        return self._pad_token_id

    @property
    def eos_token_id(self) -> int: # 私有字段封装
        return self._eos_token_id

    @property
    def bos_token_id(self) -> Optional[int]: # 私有字段封装
        return self._bos_token_id

    @property
    def vocab_size(self) -> int: # 私有字段封装
        return self._vocab_size

    @property
    def dtype( # 私有字段封装
        self,
    ) -> NumpyUIntTypes:
        """
        The numpy datatype of the arrays.
        """
        return self._dtype

    @property
    def fingerprint_version(self) -> str: # 明确当前fingerprint协议
        """
        The version of the :data:`fingerprint`.
        """
        return "v2.0"

    @property
    def fingerprint_fields(self) -> Tuple[str, ...]: # 明确哪些字段用于构建内容签名
        """
        Extra values to include when calculating the data contents :data:`fingerprint`.
        """
        return ("vocab_size", "pad_token_id", "eos_token_id", "dtype", "bos_token_id")

    @property
    def fingerprint(self) -> str: # 使用会影响内容解析的字段来构造内容签名
        """
        Used to compare the contents of a dataset.
        """
        sha256_hash = hashlib.sha256()
        sha256_hash.update(f"class={self.__class__.__name__}".encode()) # 包含类名
        for field_name in self.fingerprint_fields: # fingerprint_fields明确的字段
            field_value = getattr(self, field_name)
            sha256_hash.update(f"{field_name}={field_value},".encode())
        for path, size in zip(self.paths, self.file_sizes): # 使用path和size来代替文档内容进行hash
            sha256_hash.update(f"path={os.path.basename(path)},size={size},".encode()) # 依赖basename而不是绝对路径
        return sha256_hash.hexdigest()

    @property
    def fs_local_rank(self) -> int: # 私有字段封装
        return self._fs_local_rank

    @fs_local_rank.setter # 指明fs_local_rank作为property被修改时应调用的函数
    def fs_local_rank(self, fs_local_rank: int):
        self._fs_local_rank = fs_local_rank

    @property
    def work_dir(self) -> Path:
        if self._work_dir is not None: # 用户设置了则使用用户设置的临时路径
            return self._work_dir
        else: # 没有则使用默认路径
            return Path(tempfile.gettempdir())

    @work_dir.setter # 因为work_dir是property,当需要对work_dir封装的_work_dir进行修改时会调用由@work_dir.setter装饰的函数
    def work_dir(self, work_dir: PathOrStr):
        if is_url(work_dir):
            raise OLMoConfigurationError(
                f"'work_dir' should be a local path, not a URL ('{work_dir}')."
            )
        self._work_dir = Path(normalize_path(work_dir))
        self._work_dir_set = True # 明确设置了临时工作路径的话就要将明确设置置为True

    @property
    def work_dir_set(self) -> bool: # 用于检查临时工作路径是不是明确设置的
        """
        Check if the working directory was explicitly set.
        """
        return self._work_dir_set

    @property
    def num_tokens(self) -> int: # 交由子类具体实现,用于统计当前dataset中有多少tokens
        """
        Get the total number of tokens in the dataset.
        """
        raise NotImplementedError

    def _get_file_size(self, path: PathOrStr): # 获取path对应文件的大小
        path_idx = self.paths.index(path) # 取对应path的索引
        return self.file_sizes[path_idx] #  按索引取文件大小

    def _warmup_clients(self): # 如果路径里有远程存储(s3/r2/weka)就提前初始化client
        # Maybe create client up front to work around a threading issue in boto.
        if any(str(p).startswith("s3://") for p in self.paths):
            _get_s3_client("s3")

        if any(str(p).startswith("r2://") for p in self.paths):
            try:
                _get_s3_client("r2")
            except OLMoEnvironmentError:
                # R2 might not be needed, so ignore this error. We will get an error
                # later if R2 is needed.
                pass

        if any(str(p).startswith("weka://") for p in self.paths):
            try:
                _get_s3_client("weka")
            except OLMoEnvironmentError:
                # Weka might not be needed, so ignore this error. We will get an error
                # later if Weka is needed.
                pass

    def map( # 对dataset的每个path执行同一个函数,并按原顺序收集结果。许多任务比如获取每个path的大小,统计每个path的token数,都需要map到path
        self,
        func: Callable[[PathOrStr, int], T], # 一个输入是PathOrStr和int的函数,也就是以path和索引为输入
        *,
        max_workers: Optional[int] = None,
        method: Literal["threads", "processes"] = "threads",
        _paths: Optional[Sequence[PathOrStr]] = None, # 可以另外指定path而不是仅仅是使用默认的当前dataset保存的path
    ) -> List[T]:
        """
        Call a function on each path in the dataset, returning a list of the results, in order.

        :param func: The function to map to the paths and their indices.
        :param max_workers: The number of workers threads/processes. Set to 0 to execute synchronously
            in the main thread/process.
        :param method: Whether to use multi-threading or multi-processing.

        :returns: The results, in the same order as :data:`paths`.
        """
        paths = _paths or self.paths # 要么指定了路径,要么就使用默认的自身的路径,优先前者

        if max_workers == 0: # 关闭并行,逐一处理即可
            return [func(path, idx) for idx, path in enumerate(paths)]

        executor_class: Union[
            Type[concurrent.futures.ThreadPoolExecutor],
            Type[concurrent.futures.ProcessPoolExecutor],
        ]
        if method == "threads":
            self._warmup_clients()
            executor_class = concurrent.futures.ThreadPoolExecutor
        elif method == "processes":
            executor_class = concurrent.futures.ProcessPoolExecutor
        else:
            raise ValueError(method)

        with executor_class(max_workers=max_workers) as executor:
            futures = [executor.submit(func, path, idx) for idx, path in enumerate(paths)]

        return [future.result() for future in futures] # 使用的是future占位,所以还是原来的顺序


    def prepare(self): # 不要求子类强制实现,但允许子类在真正开始被DataLoader使用前,做一次显式预处理
        """
        Perform any necessary preparation.

        .. warning::
            Be sure to set :data:`work_dir` properly before calling this and only call this from the
            main process (not a worker process).
        """
        pass

    @abstractmethod # 一个类中,不管是本身还是父类还是父类的父类,只要还存在未实现的abstractmethod,就不能被实例化
    def __len__(self) -> int: # pytorch Dataset的接口,在最终实例化之前必须实现的抽象方法,表示这个dataset里有多少个instance。
        """
        Get the number of instances in the dataset.
        """
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Any]: # pytorch Dataset的接口,返回第index个instance
        """
        Get an instance from the dataset. At a minimum this will contain the field "input_ids", a
        integer tensor of token IDs.
        """
        raise NotImplementedError

    def _validate_instance( # 使用InstanceFilterConfig来验证instance是否符合约束
        self, input_ids: torch.Tensor, instance_filter_config: InstanceFilterConfig
    ) -> bool:
        for m in find_periodic_sequences( # find_periodic_sequences是一个generator,详见附录
            input_ids.numpy(), # 转回numpy做验证
            max_period=instance_filter_config.repetition_max_period,
            min_period=instance_filter_config.repetition_min_period,
        ):
            if m.times >= instance_filter_config.repetition_max_count:
                return False
        return True


class NumpyFSLDatasetBase(NumpyDatasetBase, Dataset[Dict[str, Any]]):
    """
    A base class for fixed sequence length (FSL) numpy array-backed datasets.
    尽管技术上不需要继承Dataset,但是继承Dataset有以下几个好处:
    1.相当于显式声明符合Dataset协议
    2.向pytorch生态兼容,方便工具判断isinstance(obj, Dataset)
    3.Dataset[Dict[str, Any]]做类型注解,表明__getitem__ 返回 Dict[str, Any]
    """

    def __init__(
        self,
        *paths: PathOrStr,
        sequence_length: int, # 相比上面新补充的信息,确定固定长度的长度
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        dtype: NumpyUIntTypes = np.uint16,
        metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, # 相比上面新补充的信息,path/global metadata
        include_instance_metadata: Optional[bool] = None, # 相比上面新补充的信息 getitem时是否要将meta data包含进来
        generate_doc_lengths: bool = False, # 相比上面新补充的信息,是否生成 document length 相关信息
        bos_token_id: Optional[int] = None,
        instance_filter_config: Optional[InstanceFilterConfig] = None, # 相比上面新补充的信息,可选接入的instance_filter
        label_mask_paths: Optional[List[PathOrStr]] = None, # 相比上面新补充的信息,FSL相关辅助数据
    ): # 把"一个样本是固定长度序列"这个事实固化到对象状态里
        if include_instance_metadata is None and metadata: # 设置了metadata但是没设include_instance_metadata true则自动设置
            include_instance_metadata = True

        if isinstance(metadata, list): # 如果meta data是list(每个path都有metadata)则需要长度校验
            if len(metadata) != len(paths):
                raise OLMoConfigurationError(
                    "'metadata' should have the same length as the number of file paths"
                )
        else: # 如果是所有path共享metadata则广播到path长度次,如果为空则广播{}
            metadata = [metadata or {}] * len(paths) # 这里会重复同一个对象引用,但是只要不改metadata就没什么问题

        if label_mask_paths is not None and len(label_mask_paths) != len(paths): # label mask也需要与path严格对齐
            raise OLMoConfigurationError(
                "There must be the same number of 'label_mask_paths' as there are 'paths'"
            )

        super().__init__( # 给父类传参数,执行父类初始化
            *paths,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            vocab_size=vocab_size,
            dtype=dtype,
            bos_token_id=bos_token_id,
        )
        self._metadata = tuple(metadata) # 既支持"每个path一份metadata",也支持"所有path共用一份metadata"。
        self._sequence_length = sequence_length
        self._include_instance_metadata = include_instance_metadata # getitem时是否携带metadata
        self._generate_doc_lengths = generate_doc_lengths # 相比上面新补充的信息,是否生成 document length 相关信息
        self.instance_filter_config = instance_filter_config # 可选接入的instance_filter
        self._label_mask_paths = label_mask_paths # 记录与 source paths 对齐的 mask 文件路径
        self._label_mask_path_to_source_path: Dict[PathOrStr, PathOrStr] = {} # label mask文件路径与它对应的source token文件路径的映射
        if self._label_mask_paths: # 如果给定了label mask的路径,则构建映射
            for label_mask_path, source_path in zip(self._label_mask_paths, self._array_paths):
                self._label_mask_path_to_source_path[label_mask_path] = source_path
                # 预处理/索引文件名是按source token file来命名的,不是按label mask file来命名的。
                # 如果后面某个流程传进来的是label mask path,想找与之对应的instance indices文件,就必须先找到它对应的source token path

    @property
    def sequence_length(self) -> int: # 对外接口,返回sequence长度
        return self._sequence_length

    @property
    def max_sequence_length(self) -> int: # 对外接口,因为是fix sequence length,所以与上面相同
        return self.sequence_length

    @property
    def max_target_sequence_length(self) -> Optional[int]: # 因为是FSL,所以没有,但留了个引子,可能会出现input和target不同的情况
        return None

    def _get_indices_path( # 根据source path集合和一些额外标识,生成一个稳定的本地索引文件路径
        self, name: str, *source_paths: PathOrStr, extra_ids: Optional[Sequence[str]] = None
    ) -> Path: # 核心在于对source paths和extra ids做哈希
        sha256_hash = hashlib.sha256()
        for source_path in source_paths: # 用 source内容身份+附加配置标识,生成一个唯一hash
            # NOTE: the pre-processed data file names are based on the corresponding source (token IDs) file name,
            # so to get the right instance indices file name for a label mask file, we need to map
            # the label mask file name to its corresponding source file name.
            # 如果传进来的实际上是一个label mask文件路径则先映射回source路径
            # 同一个token source相关的索引文件,不会因为"传的是source path还是label mask path"而得到不同缓存名
            if source_path in self._label_mask_path_to_source_path:
                source_path = self._label_mask_path_to_source_path[source_path]
            sha256_hash.update(str(source_path).encode())
            sha256_hash.update(str(self._get_file_size(source_path)).encode())
        for extra_id in extra_ids or []:
            sha256_hash.update(extra_id.encode())
        path_hash = sha256_hash.hexdigest()
        return self.work_dir / "dataset-common" / f"{name}-{self.sequence_length}-{path_hash}.npy"
        # 得到的文件形如 <work_dir>/dataset-common/<name>-<sequence_length>-<hash>.npy
        # 比如/tmp/dataset-common/instance-indices-2048-a8f3....npy
        # 之后dataset 的"中间索引/预处理结果"的缓存文件就可以写到这里,里面存的是[start_offset_1, start_offset_2, ...]
        # 或是(path_idx, offset),用于存放从raw token中切出instance的"索引表",也就是indices file,所以才需要hash


class NumpyFSLDataset(NumpyFSLDatasetBase):
    """
    A fixed sequence length (FSL) numpy array-backed dataset.

    In this implementation the token IDs from all arrays are concatenated together and then chunked
    into contiguous blocks of ``sequence_length`` tokens to create instances. Therefore documents
    may be split over multiple instances.

    .. seealso::
        :class:`NumpyPaddedFSLDataset`

    .. important::
        If the length of an array is not a multiple of ``sequence_length`` or
        ``max_target_sequence_length`` the remainder of the tokens will be ignored.

    .. important::
        No special tokens are added to the input IDs so it's assumed that if you want
        EOS tokens between documents, for example, those will already be in the array.

    :param paths: Paths or URLs to numpy token ID arrays.
    :param sequence_length: The number of tokens to chunk together into a single instance.
        Generally this should correspond to your model's maximum input length.
    :param pad_token_id: The ID of the padding token.
    :param eos_token_id: The ID of the EOS token.
    :param dtype: The numpy datatype of the arrays.
    :param metadata: Metadata to add to each item. This should be a dictionary or a list of dictionaries
        with the same number of items as there are paths.
    :param include_instance_metadata: If ``True`` (the default), each instance returned from
        :meth:`__getitem__()` will include the metadata from its source.
    :param max_target_sequence_length: Optional upper bound used when precomputing cached offsets.
        If you're planning a sequence-length warm-up, set this to the final chunk size so future
        datasets with larger ``sequence_length`` values can reuse the exact same document ordering.
        The current dataset still returns ``sequence_length``-token windows; this hint simply keeps
        token boundaries and cache files deterministic across warm-up stages. Leave ``None`` if you
        won't rebuild at a larger length.
    """
    '''
    主要作用是把多个token id array当作若干独立的token stream,每个stream按sequence_length切成固定长度块;
    dataset的全局index再映射到某个具体文件里的某个块。
    不是把所有文件真的物理拼接成一个大数组再切,而是每个文件单独切块,再把这些块按文件顺序串成一个全局dataset索引空间
    这里才能看到一个训练sample是怎么从磁盘token array中取出来的。
    '''

    def __init__(
        self,
        *paths: PathOrStr,
        sequence_length: int,
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        dtype: NumpyUIntTypes = np.uint16,
        metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        include_instance_metadata: Optional[bool] = None,
        generate_doc_lengths: bool = False,
        bos_token_id: Optional[int] = None,
        max_target_sequence_length: Optional[int] = None, # 相比上面新增,具体作用看下面注释
        instance_filter_config: Optional[InstanceFilterConfig] = None,
        label_mask_paths: Optional[List[PathOrStr]] = None,
    ):
        super().__init__(
            *paths,
            sequence_length=sequence_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            vocab_size=vocab_size,
            dtype=dtype,
            metadata=metadata,
            include_instance_metadata=include_instance_metadata,
            generate_doc_lengths=generate_doc_lengths,
            bos_token_id=bos_token_id,
            instance_filter_config=instance_filter_config,
            label_mask_paths=label_mask_paths,
        )

        if max_target_sequence_length is not None and (
            max_target_sequence_length < sequence_length
            or max_target_sequence_length % sequence_length != 0
        ):
            raise OLMoConfigurationError(
                "'max_target_sequence_length' should be a multiple of 'sequence_length'"
            )
        '''
        这里的max_target_sequence_length不能小于sequence_length,且必须是sequence_length的整数倍
        该参数不是当前dataset的最大长度,而是为了未来更大的sequence length预先固定chunk边界,比如:
        当前训练阶段使用的sequence_length = 1024,但是后续会切换到sequence_length = 4096
        那么此时就可以设置max_target_sequence_length = 4096,此时dataset在计算每个文件可提供多少个1024实例时
        会先看能组成多少个完整4096大块,之后再将4096大块拆成1024小块,这样在后续使用4096训练时chunk的边界依然是不变的
        '''

        self._max_target_sequence_length = max_target_sequence_length
        self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None # 每个文件在全局dataset索引空间中的起止instance区间
        # 比如 file0: 100 个 instances file1: 50 个 instances file2: 70 个 instances
        # 此时 offsets可能为[(0, 100), (100, 150), (150, 220)]
        self._num_instances: Optional[int] = None # 总instance数缓存

    @property
    def fingerprint_fields(self) -> Tuple[str, ...]: # 新增的max_target_sequence_length会影响文件到instance的切分方式,故纳入fingerprint字段
        return (
            "vocab_size",
            "pad_token_id",
            "eos_token_id",
            "dtype",
            "max_target_sequence_length",
            "bos_token_id",
        )

    @property
    def num_tokens(self) -> int: # 不是文件中的token数,而是能构成instance的token数,即 instance数 × 每个instance长度
        return len(self) * self.sequence_length

    @property
    def sequence_length(self) -> int:
        return self._sequence_length

    @property
    def max_sequence_length(self) -> int:
        return self.sequence_length

    @property
    def max_target_sequence_length(self) -> Optional[int]:
        return self._max_target_sequence_length

    @property
    def file_sizes(self) -> Tuple[int, ...]: # 使用的是下面的_sizes_and_offsets
        """
        The size, in bytes, of each numpy array.
        """
        return self._sizes_and_offsets[0]

    @property
    def offsets(self) -> Tuple[Tuple[int, int], ...]: # 使用的是下面的_sizes_and_offsets
        """
        Gives the global start and end instance indices for each data file in the dataset.
        """
        return self._sizes_and_offsets[1] # 这里就是刚才提到的每个文件对应的全局instance索引区间

    @property
    def metadata(self) -> Tuple[Dict[str, Any], ...]: # 暴露成per-path metadata,父类中无论如何都会变成per-path格式
        return self._metadata

    def prepare(self): # len()会连锁触发offsets、_sizes_and_offsets、文件大小统计、label mask校验(若有),自然就完成了prepare
        len(self)

    def __len__(self) -> int:
        if self._num_instances is None: # lazy compute,没算过则使用最后一个offset的end作为总的instance数量
            self._num_instances = self.offsets[-1][1]
        return self._num_instances # 有缓存则直接返回

    def __getitem__(self, index: int) -> Dict[str, Any]: # 核心,实现global index → sample
        index = int(index)  # in case this is a numpy int type. 先对传进来的index做类型转换
        pos_index = index if index >= 0 else len(self) + index # 处理负索引

        # The index of the array within 'self.paths'.
        # 处理全局索引,找落在哪个文件,比如[(0, 100), (100, 150), (150, 220)],索引dataset[123],就应该落在file 1,局部索引是23
        array_index: Optional[int] = None
        # The index within the corresponding array.
        array_local_index: Optional[int] = None
        for i, (offset_start, offset_end) in enumerate(self.offsets):
            if offset_start <= pos_index < offset_end:
                array_index = i
                array_local_index = pos_index - offset_start
                break

        if array_index is None or array_local_index is None: # 越界检查
            raise IndexError(f"{index} is out of bounds for dataset of size {len(self)}")

        # Read the data from file. 这里才真正读磁盘,读文件内部第array_local_index个chunk
        input_ids = self._read_chunk_from_array(self.paths[array_index], array_local_index)
        out: Dict[str, Any] = {"input_ids": input_ids} # 从array_local_index * sequence_length开始读一段固定长度slice

        if self._label_mask_paths is not None: # 有对应的label mask就读对应mask chunk
            label_mask = self._read_chunk_from_array(
                self._label_mask_paths[array_index], array_local_index, dtype=np.bool_
            )
            out["label_mask"] = label_mask

        if self.instance_filter_config is not None: # 如果设置了instance filter,就附加instance_mask,这里是一个布尔变量
            out["instance_mask"] = self._validate_instance(input_ids, self.instance_filter_config)

        if self._include_instance_metadata: # 要求包含metadata则把对应文件的metadata加进去
            metadata = self._metadata[array_index]
            out["metadata"] = deepcopy(metadata) # 注意deepcopy,防止污染共享元素,见上面的[metadata] * len(paths)

        if self._generate_doc_lengths: # dataset的基本sample是token chunk,但仍然可以在chunk内部恢复document boundary
            out["doc_lens"] = get_document_lengths(
                input_ids, self.eos_token_id, bos_token_id=self.bos_token_id
            )

        return out # 交给dataloader的dict格式的sample
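__getitem__中"全局index → (文件编号, 文件内局部index)"的查找可以用一个脱离类的小例子演示。下面的offsets是假设值,locate是把上面for循环逻辑抽出来的示意函数,并非真实dataset对象:

```python
# 示意:全局instance index落到哪个文件、文件内第几个chunk(offsets为假设值)
def locate(offsets, pos_index):
    for i, (start, end) in enumerate(offsets):
        if start <= pos_index < end:
            return i, pos_index - start  # (文件编号, 文件内局部index)
    raise IndexError(pos_index)

# file0提供100个instance,file1提供50个,file2提供70个
offsets = [(0, 100), (100, 150), (150, 220)]
print(locate(offsets, 123))  # 全局第123个instance落在file1,局部index为23 → (1, 23)
```

找到局部index后,`_read_chunk_from_array`才真正去磁盘读`[23 * sequence_length : 24 * sequence_length]`这一段slice。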

    @property
    def _sizes_and_offsets(self) -> Tuple[Tuple[int, ...], Tuple[Tuple[int, int], ...]]:
        # 计算并缓存:1. 每个文件的字节大小(file_sizes用)2. 每个文件在全局dataset索引空间中的offsets;token数只在内部算出,用于label mask长度校验
        if self._array_offsets is None or self._array_file_sizes is None: # lazy compute
            array_sizes: List[int] = []
            array_offsets: List[Tuple[int, int]] = []
            array_file_sizes: List[int] = []
            item_size = self.dtype(0).itemsize

            start_offset = 0
            for size, length in self.map(self._get_file_size_and_length): # 使用map对所有path做_get_file_size_and_length
                array_sizes.append(size // item_size) # 文件字节大小除以单个元素字节数,得到token数
                end_offset = start_offset + length # 起点 + 长度 得到终点
                array_offsets.append((start_offset, end_offset)) # 添加偏移tuple
                array_file_sizes.append(size) # 添加文档大小
                start_offset += length # 起点偏移,为下个文件做准备
            # 到这里就算完了,存起来
            self._array_offsets = tuple(array_offsets)
            self._array_file_sizes = tuple(array_file_sizes)
            # 做label mask的校验
            mask_item_size = np.bool_(True).itemsize # 获取np.bool_(True)占用的内存字节数,1个字节
            if self._label_mask_paths is not None: # 给定了label mask
                for i, (size, _) in enumerate(
                    self.map(
                        partial(self._get_file_size_and_length, dtype=np.bool_), # 偏函数把dtype固定为np.bool_
                        _paths=self._label_mask_paths, # 使用map对所有label mask path做_get_file_size_and_length
                    )
                ):
                    size = size // mask_item_size # 文件大小除以每个布尔变量的大小,就能得到总共有多少布尔变量
                    if array_sizes[i] != size: # 校验每个文件和其label mask长度是否对齐
                        raise RuntimeError(
                            f"mismatch between size of source file ('{self._array_paths[i]}', {array_sizes[i]:,d}) and "
                            f"size of corresponding label mask file ('{self._label_mask_paths[i]}', {size:,d})"
                        )

        return self._array_file_sizes, self._array_offsets # 缓存了则直接返回

    def _read_chunk_from_array(self, path: PathOrStr, index: int, dtype=None) -> torch.Tensor:
        # 把局部的chunk索引(index)转成token slice(tensor)
        start_idx = index * self.sequence_length
        return load_array_slice_into_tensor( # load_array_slice_into_tensor逻辑非常简单,可以自己去看下
            path,
            start_idx,
            start_idx + self.sequence_length,
            dtype or self.dtype,
        ) # 最终会把对应的token slice转成tensor返回

    def _get_file_size_and_length(self, path: PathOrStr, idx: int, dtype=None) -> Tuple[int, int]:
        del idx # 明确在这段逻辑中idx不使用
        dtype = dtype or self.dtype
        item_size = dtype(0).itemsize # 按数据类型获取其单个数据大小
        file_size = get_file_size(path) # 获取文件大小,注意这里返回的是字节大小
        if ( # 如果没有指定max_target_sequence_length,或者其与sequence length长度一致
            self.max_target_sequence_length is None
            or self.max_target_sequence_length == self.sequence_length
        ):
            return file_size, file_size // (item_size * self.sequence_length) # 除一下就得到总共有多少instance了
        elif self.max_target_sequence_length > self.sequence_length: # 指定了max_target_sequence_length
            num_max_seq_len_instances = file_size // (item_size * self.max_target_sequence_length) # 先算大块的数量
            return (
                file_size,
                num_max_seq_len_instances
                * (self.max_target_sequence_length // self.sequence_length), # 再换算成小块
            )
        else:
            raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'")
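_get_file_size_and_length的算术可以脱离文件系统演示。下面是一个纯算术的示意(文件大小、seq_len等均为假设值),展示指定max_target_sequence_length后"先切大块再拆小块"与直接切小块在instance数上的差别:

```python
# 示意:_get_file_size_and_length的instance数计算(假设dtype为uint16,即每token 2字节)
item_size = 2                   # np.uint16的itemsize
file_size = 10_000 * item_size  # 假设文件里有10000个token
seq_len, max_target = 1024, 4096

# 不指定max_target_sequence_length:直接按1024切
plain = file_size // (item_size * seq_len)       # 10000 // 1024 = 9个instance

# 指定max_target_sequence_length=4096:先数4096大块,再拆成1024小块
num_big = file_size // (item_size * max_target)  # 10000 // 4096 = 2个大块
warmup = num_big * (max_target // seq_len)       # 2 * 4 = 8个instance

print(plain, warmup)  # 9 8
```

可以看到预留大块边界会牺牲一点token利用率(9 → 8个instance),换来的是后续切换到4096训练时chunk边界完全不变。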


class NumpyFSLDatasetMixture(NumpyFSLDataset):
    """
    A version of :class:`NumpyFSLDataset` built from a mixture of sources and their expected token ratios relative to each other.
    A ``path_offset_index`` is used to determine the number of instances to retain from a path when constructing the local indices.
    NumpyFSLDatasetMixture承接source_mixture.py产出的path-level token budget并将其转换为instance
    继承自NumpyFSLDataset,但一个核心不同是每个path可提供多少实例,不再由文件真实大小直接决定,而是由source_mixture.py算出来的
    也就是 NumpyFSLDataset计算每个path的instance数由file_size / (dtype_size * sequence_length)确定
    而 NumpyFSLDatasetMixture则是path_offset_index[(path, idx)] // sequence_length,由source mixture允许使用的token预算决定
    SourceMixtureDataset.to_paths()会得到[pathA, pathB, pathA, pathC],注意path允许重复
    to_index()则会得到一个映射:
        {
            (pathA, 0): tokensA1,
            (pathB, 1): tokensB,
            (pathA, 2): tokensA2,
            (pathC, 3): tokensC,
        } 指明了每个path能提供多少tokens(也就是指定的额度)
    构成了这里开展工作的基础,这里最后会约定每个path能提供多少instance,也就是有多少个fixed-length chunk
    """

    def __init__(
        self,
        *paths: PathOrStr, # SourceMixtureDataset.to_paths()会传到这里
        path_offset_index: Dict[Tuple[str, int], int], # 新增,SourceMixtureDataset.to_index()会传到这里
        seed: int, # 这个也是新增的,用于设置采样种子
        sequence_length: int,
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        dtype: NumpyUIntTypes = np.uint16,
        metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        include_instance_metadata: Optional[bool] = None,
        generate_doc_lengths: bool = False,
        bos_token_id: Optional[int] = None,
        max_target_sequence_length: Optional[int] = None, # 这个参数详见下面注释
        instance_filter_config: Optional[InstanceFilterConfig] = None,
    ):
        if max_target_sequence_length is not None and ( # 与上面同样的校验
            max_target_sequence_length < sequence_length
            or max_target_sequence_length % sequence_length != 0
        ):
            raise OLMoConfigurationError(
                "'max_target_sequence_length' should be a multiple of 'sequence_length'"
            )

        super().__init__( # 父类初始化
            *paths,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            vocab_size=vocab_size,
            dtype=dtype,
            sequence_length=sequence_length,
            metadata=metadata,
            include_instance_metadata=include_instance_metadata,
            generate_doc_lengths=generate_doc_lengths,
            bos_token_id=bos_token_id,
            max_target_sequence_length=max_target_sequence_length,
            instance_filter_config=instance_filter_config,
        ) # 以下是新增的内部状态
        self._num_instances: Optional[int] = None
        self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None
        self._lengths_dtype: Optional[NumpyUIntTypes] = None # 本段代码中没用到
        self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None # 本段代码中没用到
        self._path_offset_index = path_offset_index
        self._seed = seed

    @property
    def indices_dtype( # mixture instance/document indices文件中使用的整数dtype
        self,
    ) -> NumpyUIntTypes:
        return np.uint32 # 通常索引文件里存的是位置、边界、实例编号等,所以用uint32足够且节省空间

    def prepare(self):
        if self.fs_local_rank == 0: # 仅在rank0上写索引,其他local rank在这里取就好了
            log.info("Gathering indices...")
            self._write_document_indices()
        barrier() # 设置同步屏障,保证索引文件写完后别的进程再继续
        len(self) # 因为它自己没有实现__len__,所以用的是父类的

    def _get_instance_indices_path(self, source_path: PathOrStr) -> Path:
        # 生成某个source path对应的mixture instance索引缓存文件路径
        return self._get_indices_path( # 用的是父类的_get_indices_path
            "mixture-instance-indices", source_path, extra_ids=(self.indices_dtype.__name__,)
        )

    def _write_document_indices(self): # 区别于父类的核心逻辑
        # 先找哪些path需要生成indices
        paths_needed: List[Tuple[PathOrStr, int]] = []
        for idx, path in enumerate(self.paths):
            indices_path = self._get_instance_indices_path(path)
            if indices_path.is_file(): # 如果缓存已经存在,就复用
                log.info(f"Reusing document indices for '{path}' at:\n'{indices_path}'")
            elif path not in paths_needed: # 否则记下来需要处理
                paths_needed.append((path, idx))

        if paths_needed: # 为每个需要处理的path计算max instance
            with concurrent.futures.ProcessPoolExecutor() as executor: #多进程处理
                futures = [] # future对象的占位列表
                for path, idx in paths_needed:
                    indices_path = self._get_instance_indices_path(path) # 尽管path可能重复,但是文件只需要生成一次
                    log.info(f"Gathering instance indices for '{path}'...")
                    # NOTE: We limit the number of instances by total target token count // sequence length
                    max_instances = ( # 按(path idx)取出当前path能提供的token数(预算),再除以sequence_length,得到最大instance数
                        self._path_offset_index[(str(path), idx)] // self.sequence_length # 这里path可能是重复的
                    )

                    # Sampling from small npy files can result in 0 instance indices.
                    # We skip processing these to avoid writing empty mmapped files.
                    '''
                    segment_documents_into_instances() 的行为像"每document对应一个clipped instance",这会导致长文档
                    token利用率偏低;因此这段实现要么是有意采用document-level sampling,要么就是doc分布上以短doc为主
                    也有可能是path指向的.npy文件本身就不大,所以不存在超长token array,只是一个shard,详细逻辑行为见附录。
                    所以数据层的逻辑设计从数据的预处理阶段就开始了,是否"设计有问题",也取决于它想优化什么。
                    '''
                    if max_instances > 0: # 只有当前path能提供时才执行
                        future = executor.submit(
                            run_worker_func,
                            segment_documents_into_instances, # 对给定的path,实现一个文档变成一个instance
                            path, # 这是读的path
                            indices_path, # 这是写入的path
                            max_sequence_length=self.sequence_length,
                            eos_token_id=self.eos_token_id,
                            dtype=self.dtype,
                            indices_dtype=self.indices_dtype,
                            sample=(max_instances, self._seed),
                        )
                        futures.append(future)

                concurrent.futures.wait(futures, return_when="FIRST_EXCEPTION")

                # Log results.
                for path, future in zip([item[0] for item in paths_needed], futures):
                    _, total_instances = future.result()  # 把每个path的instances数打印到日志
                    log.info(
                        f"Created {total_instances:,d} instances of sequence length up to "
                        f"{self.sequence_length} from '{path}'"
                    )

    '''注意下面的read部分被注释掉了,所以这个类尽管在上面做了_write_document_indices,'''
    '''但对instance的读取方式依然是依赖父类的getitem,而不是去读文件中的索引'''

    # def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor:
    #     indices_path = self._get_instance_indices_path(path)
    #     indices = load_array_slice_into_tensor(
    #         indices_path, index * 2, index * 2 + 2, self.indices_dtype
    #     )
    #     start_idx, end_idx = indices
    #     data = load_array_slice_into_tensor(path, int(start_idx), int(end_idx), self.dtype)
    #     return data

    def _get_file_size_and_length(self, path: PathOrStr, idx: int, dtype=None) -> Tuple[int, int]:
        # source_mixture.py控制的采样比例和数量是在这里发挥作用,将约定的预算,包装成文件大小
        dtype = dtype or self.dtype
        item_size = dtype(0).itemsize
        file_size = self._get_size_from_offset_index((path, idx)) # 使用的是path与配额,而不仅仅是path读真实文件大小
        if (
            self.max_target_sequence_length is None
            or self.max_target_sequence_length == self.sequence_length
        ):
            return file_size, file_size // (item_size * self.sequence_length) # 返回虚拟的文件大小和instance数量
        elif self.max_target_sequence_length > self.sequence_length:
            num_max_seq_len_instances = file_size // (item_size * self.max_target_sequence_length)
            return ( # 同样的,先算大块再拆小块
                file_size,
                num_max_seq_len_instances
                * (self.max_target_sequence_length // self.sequence_length),
            )
        else:
            raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'")

    def _get_size_from_offset_index(self, path_index: Tuple[PathOrStr, int]) -> int:
        try:
            path, idx = path_index
            # Get size in bytes from tokens in the supplied index * itemsize
            return self._path_offset_index[(str(path), idx)] * self.dtype(0).itemsize # token预算 × 数据大小 = 虚拟文件大小
        except KeyError:
            raise OLMoEnvironmentError(f"Item not found in path index @ {path_index}")
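把上面两个方法串起来看:token预算 × itemsize = 虚拟文件大小,token预算 // sequence_length = 最大instance数。下面是一个示意小例子(path_offset_index里的路径和token数均为假设值,函数是对类方法逻辑的简化复述):

```python
# 示意:NumpyFSLDatasetMixture用token预算(而非真实文件大小)推出虚拟文件大小和instance数
import numpy as np

path_offset_index = {("pathA", 0): 5_000, ("pathB", 1): 2_048}  # 每个(path, idx)的token预算,假设值
item_size = np.uint16(0).itemsize  # 2字节
seq_len = 1024

def virtual_file_size(key):
    # 对应_get_size_from_offset_index:token预算 × 单token字节数 = 虚拟文件大小
    return path_offset_index[key] * item_size

def max_instances(key):
    # 对应_write_document_indices中的max_instances计算
    return path_offset_index[key] // seq_len

print(virtual_file_size(("pathA", 0)), max_instances(("pathA", 0)))  # 10000 4
```

注意pathA的真实文件可能远大于5000个token,但mixture约定只允许它贡献5000个,后续的所有offset/instance计算都基于这个"虚拟大小"展开。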


class NumpyPaddedFSLDataset(NumpyFSLDataset):
    """
    An FSL dataset that creates a single instance from each document.
    The resulting instances will all have exactly ``sequence_length`` tokens, using padding if needed.
    与上面类似,生成固定长度的instance,但是是每个document切分一个instance,长度不够的做padding
    """

    def __init__(
        self,
        *paths: PathOrStr,
        sequence_length: int,
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        dtype: NumpyUIntTypes = np.uint16,
        bos_token_id: Optional[int] = None,
        metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        include_instance_metadata: Optional[bool] = None,
        instance_filter_config: Optional[InstanceFilterConfig] = None,
        label_mask_paths: Optional[List[PathOrStr]] = None,
    ):
        super().__init__(
            *paths,
            sequence_length=sequence_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            vocab_size=vocab_size,
            dtype=dtype,
            metadata=metadata,
            include_instance_metadata=include_instance_metadata,
            bos_token_id=bos_token_id,
            instance_filter_config=instance_filter_config,
            label_mask_paths=label_mask_paths,
        )
        self._array_instance_offsets: Optional[Tuple[Tuple[int, int], ...]] = None

    @property
    def fingerprint_fields(self) -> Tuple[str, ...]:
        return (
            "vocab_size",
            "pad_token_id",
            "eos_token_id",
            "dtype",
            "max_target_sequence_length",
            "bos_token_id",
            "sequence_length", # 注意这里的fingerprint加了sequence_length作为关键配置项
        )

    @property
    def offsets(self) -> Tuple[Tuple[int, int], ...]: # 这里是按索引文件中的instance数量推算
        if self._array_instance_offsets is None:
            item_size = self.indices_dtype(0).itemsize # 先取每个数据的大小
            num_instances_per_path = self.map( # 对每个path计算
                lambda path, _: get_file_size(self._get_instance_indices_path(path)) # 这里get_file_size读的是真实文件大小
                // (item_size * 2) # 每个item是一对数字[start_idx, end_idx]
            ) # 就得到了这个path下总共有多少instance
            array_instance_offsets = []
            start_offset = 0
            for num_instances in num_instances_per_path: # 构造全局区间,用于getitem时的索引
                array_instance_offsets.append((start_offset, start_offset + num_instances))
                start_offset += num_instances
            self._array_instance_offsets = tuple(array_instance_offsets)
        return self._array_instance_offsets

    @property
    def indices_dtype( # 统一用uint32存
        self,
    ) -> NumpyUIntTypes:
        return np.uint32

    def prepare(self):
        if self.fs_local_rank == 0: # 仅fs local rank 0写索引文件,其他rank复用即可
            log.info("Gathering dataset document indices...")
            self._write_instance_indices() # 在这里把索引写到本地,后续通过len处理
        barrier() # 同步屏障
        len(self) # 用的是父类的len,所以也间接的prepare了父类,并且父类的len调用了offset,用的是该类重写的offset

    def __getitem__(self, index: int) -> Dict[str, Any]:
        item = super().__getitem__(index) # 先用父类的__getitem__取数据(其内部调用的_read_chunk_from_array已被本类重写),之后再做padding
        pad_shape = (0, self.sequence_length - len(item["input_ids"])) # 计算pad量
        if "label_mask" in item: # 有label mask 则将label mask同步pad
            item["label_mask"] = F.pad(item["label_mask"], pad_shape, value=False)
        else: # 否则就使用input_ids对应位置置1,再pad后作为mask
            item["label_mask"] = F.pad(
                torch.ones_like(item["input_ids"], dtype=torch.bool), pad_shape, value=False
            )
        item["input_ids"] = F.pad(item["input_ids"], pad_shape, value=self.pad_token_id) # pad输入
        return item
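这段padding逻辑可以单独跑一个最小示意(sequence_length、pad_token_id和token内容均为演示用的假设值):

```python
# 示意:NumpyPaddedFSLDataset.__getitem__中对input_ids和label_mask的右侧padding
import torch
import torch.nn.functional as F

sequence_length, pad_token_id = 8, 0
input_ids = torch.tensor([5, 6, 7, 2], dtype=torch.long)  # 一个4 token的document(末尾2是假设的eos)

pad_shape = (0, sequence_length - len(input_ids))  # 只在右侧补齐到sequence_length
# 真实token位置mask为True,padding位置为False,loss计算时可据此忽略padding
label_mask = F.pad(torch.ones_like(input_ids, dtype=torch.bool), pad_shape, value=False)
input_ids = F.pad(input_ids, pad_shape, value=pad_token_id)

print(input_ids.tolist())   # [5, 6, 7, 2, 0, 0, 0, 0]
print(label_mask.tolist())  # [True, True, True, True, False, False, False, False]
```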

    def _read_chunk_from_array(self, path: PathOrStr, index: int, dtype=None) -> torch.Tensor: # 被父类getitem间接调用
        # 父类的read是[k*S : (k+1)*S],也就是去取index确定的slice,这里改成了去取tokens[start_idx:end_idx]
        indices_path = self._get_instance_indices_path(path)
        indices = load_array_slice_into_tensor( # 先去索引文件取边界
            indices_path, index * 2, index * 2 + 2, self.indices_dtype
        )
        start_idx, end_idx = indices # 之后就可以去path取边界指定的数据了
        data = load_array_slice_into_tensor(path, int(start_idx), int(end_idx), dtype or self.dtype)
        return data

    def _get_instance_indices_path(self, source_path: PathOrStr) -> Path:
        return self._get_indices_path("instance-indices", source_path)

    def _write_instance_indices(self): # prepare的核心工作
        paths_needed: List[PathOrStr] = []
        for path in self.paths: # 先检查path中那些处理了哪些没处理
            indices_path = self._get_instance_indices_path(path)
            if indices_path.is_file():
                log.info(f"Reusing instance indices for '{path}' at:\n'{indices_path}'")
            elif path not in paths_needed:
                paths_needed.append(path)

        if paths_needed: # 如果还有没处理的部分
            with concurrent.futures.ProcessPoolExecutor() as executor:
                futures = []
                for path in paths_needed:
                    indices_path = self._get_instance_indices_path(path)
                    log.info(f"Gathering instance indices for '{path}'...")
                    future = executor.submit( # 这里使用的也是segment_documents_into_instances按doc切分,相比上面的类没做采样
                        run_worker_func,
                        segment_documents_into_instances,
                        path,
                        indices_path,
                        max_sequence_length=self.sequence_length,
                        eos_token_id=self.eos_token_id,
                        dtype=self.dtype,
                        indices_dtype=self.indices_dtype,
                    )
                    futures.append(future)

                concurrent.futures.wait(futures, return_when="FIRST_EXCEPTION")

                # Log results.
                for path, future in zip(paths_needed, futures):
                    _, total_instances = future.result()
                    log.info(
                        f"Created {total_instances:,d} instances of sequence length up to "
                        f"{self.sequence_length} from '{path}'"
                    )


class NumpyPackedFSLDataset(NumpyFSLDatasetBase): # 注意到直接继承自NumpyFSLDatasetBase而不是NumpyFSLDataset
    """
    An FSL dataset that packs documents into instances using the Optimized Best-Fit Decreasing (OBFD)
    algorithm described in `Fewer Truncations Improve Language Modeling <https://arxiv.org/pdf/2404.10830>`_.
    The resulting instances will all have exactly ``sequence_length`` tokens, using padding if needed.
    把 document 打包进固定长度 instance,并尽量降低 padding。默认是按单个 source file 独立 packing,
    但也支持通过 source_group_size > 1 把多个连续 source file 作为一个更大的 packing 单位。
    该类重新实现了大多数行为(offsets、prepare、__getitem__ 等)。
    .. note::
        By default OBFD is applied to each source file separately since source files from the Dolma toolkit
        are usually large enough for OBFD to achieve very good compactness (minimal padding tokens)
        and so that we can parallelize the packing. However, you can pack instances from multiple
        consecutive source files together by setting ``source_group_size`` to a value greater than 1.

    .. tip::
        Although this shares much of its option plumbing with :class:`NumpyFSLDataset`, it bypasses
        that subclass and derives from :class:`NumpyFSLDatasetBase` so it can provide its own packing
        caches, offsets, and item materialisation logic. Subclassing :class:`NumpyFSLDataset` would
        require overriding nearly every behavior defined there.
    """

    def __init__(
        self,
        *paths: PathOrStr,
        sequence_length: int,
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        dtype: NumpyUIntTypes = np.uint16,
        metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        include_instance_metadata: Optional[bool] = None,
        generate_doc_lengths: bool = False,
        bos_token_id: Optional[int] = None,
        instance_filter_config: Optional[InstanceFilterConfig] = None,
        label_mask_paths: Optional[List[PathOrStr]] = None,
        long_doc_strategy: LongDocStrategy = LongDocStrategy.truncate, # 超长document本身怎么处理,默认是截断
        source_group_size: int = 1, # 默认是单个source file自己做packing,大于1时指的是若干source file一起做packing
    ):
        super().__init__( # 父类初始化
            *paths,
            sequence_length=sequence_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            vocab_size=vocab_size,
            dtype=dtype,
            metadata=metadata,
            include_instance_metadata=include_instance_metadata,
            generate_doc_lengths=generate_doc_lengths,
            bos_token_id=bos_token_id,
            instance_filter_config=instance_filter_config,
            label_mask_paths=label_mask_paths,
        )

        assert source_group_size >= 1

        self._long_doc_strategy = long_doc_strategy
        self._source_group_size = source_group_size

        self._source_path_groups = list(chunked(self.paths, self.source_group_size)) # 实现按group size分组,很简单,自己看下
        self._label_mask_path_groups: Optional[List[List[PathOrStr]]] = None
        self._metadata_groups = list(chunked(self._metadata, self.source_group_size))

        if self._label_mask_paths: # 如果对label也做了mask,那么自然也要对label mask分组
            self._label_mask_path_groups = list(
                chunked(self._label_mask_paths, self.source_group_size)
            )

        self._source_sizes: Optional[List[int]] = None
        self._source_size_groups: Optional[List[List[int]]] = None
        self._source_instance_offsets: Optional[Tuple[Tuple[int, int], ...]] = None
        self._num_instances: Optional[int] = None
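上面反复用到的chunked,其分组行为大致如下(这是一个示意性的简化实现,真实版本在olmo_core的工具模块中,行为应当一致):

```python
# 示意:chunked按group_size把序列切成连续小组,最后一组可能不满
def chunked(items, group_size):
    items = list(items)
    return [items[i:i + group_size] for i in range(0, len(items), group_size)]

paths = ["a.npy", "b.npy", "c.npy", "d.npy", "e.npy"]
print(chunked(paths, 2))  # [['a.npy', 'b.npy'], ['c.npy', 'd.npy'], ['e.npy']]
print(chunked(paths, 1))  # source_group_size=1时每个文件自成一组,即逐文件独立packing
```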

    @property
    def fingerprint_fields(self) -> Tuple[str, ...]:
        fields: Tuple[str, ...] = (
            "vocab_size",
            "pad_token_id",
            "eos_token_id",
            "dtype",
            "long_doc_strategy",
            "bos_token_id",
            "sequence_length",
        )
        # For backwards compat, only add this when it's not the default.
        if self._source_group_size > 1: # 指定了分组行为的话也要把这个字段纳入fingerprint
            fields = fields + ("source_group_size",)
        return fields

    @property
    def long_doc_strategy(self) -> LongDocStrategy:
        return self._long_doc_strategy

    @property
    def source_group_size(self) -> int:
        return self._source_group_size

    @property
    def indices_dtype(
        self,
    ) -> NumpyUIntTypes:
        return np.uint64 # 这里用的是uint64

    @property
    def source_instance_offsets(self) -> Tuple[Tuple[int, int], ...]: # 计算pack版的offset
        if self._source_instance_offsets is None:
            item_size = self.indices_dtype(0).itemsize
            num_instances_per_group = self.map(
                lambda path, _: get_file_size(path) // (item_size * 2), # lambda匿名函数,因为map要求func有两个参数,所以再包一层
                _paths=[ # 使用这里指定的路径
                    self._get_instance_offsets_path(*paths) # 传入的路径是分组后的缓存路径,访问的是每个packed instance
                    for paths in chunked(self.paths, self.source_group_size) # 在 docs_by_instance 中对应的 doc-ID 区间
                ],# 形如[start_0, end_0, start_1, end_1, ...],因为一对数字对应一个instance,所以就可以推断出该组下总共有多少instance
            ) # 另一个需要注意的是这里计数的是已经pack好的instance
            array_instance_offsets = []
            start_offset = 0
            for num_instances in num_instances_per_group: # 之后就可以正常计算offset了
                array_instance_offsets.append((start_offset, start_offset + num_instances))
                start_offset += num_instances
            self._source_instance_offsets = tuple(array_instance_offsets)
        return self._source_instance_offsets

    @property
    def source_sizes(self) -> List[int]:
        if self._source_sizes is None:
            item_size = self.dtype(0).itemsize
            self._source_sizes = self.map(lambda path, _: get_file_size(path) // item_size) # 逐文件计算token数量
        return self._source_sizes

    @property
    def source_size_groups(self) -> List[List[int]]:
        if self._source_size_groups is None:
            self._source_size_groups = list(chunked(self.source_sizes, self.source_group_size)) # 按照group分组
        return self._source_size_groups

    def prepare(self):
        if self.fs_local_rank == 0:
            log.info("Packing document into instances...")
            self._pack_all_documents_into_instances() # 先做packing
        barrier()
        len(self) # 最后再计算len,同样会触发source_instance_offsets计算offset

    def __len__(self) -> int:
        if self._num_instances is None: # 逻辑是一样的,只需要取最后一个offset的end就知道有多少instance了
            self._num_instances = self.source_instance_offsets[-1][1]
        return self._num_instances

    def __getitem__(self, index: int) -> Dict[str, Any]: # 根据全局索引index取instance
        index = int(index)  # in case this is a numpy int type.
        index = index if index >= 0 else len(self) + index # 处理负索引

        # The index of the source group.
        source_group_index: Optional[int] = None
        # The instance index within the source group.
        instance_index: Optional[int] = None
        for i, (instance_offset_start, instance_offset_end) in enumerate(
            self.source_instance_offsets
        ): # 全局索引 → 第几个source group + 该group内第几个packed instance
            if instance_offset_start <= index < instance_offset_end:
                source_group_index = i
                instance_index = index - instance_offset_start
                break

        if source_group_index is None or instance_index is None:
            raise IndexError(f"{index} is out of bounds for dataset of size {len(self)}")

        # All npy source file paths within the group.
        source_paths = self._source_path_groups[source_group_index]
        # The number of tokens in each npy source file within the group.
        source_sizes = self.source_size_groups[source_group_index]
        # All label mask paths for the group.
        label_mask_paths = ( # label masks are looked up per group as well
            None
            if self._label_mask_path_groups is None
            else self._label_mask_path_groups[source_group_index]
        )
        # Resolve the cache files for these paths.
        document_indices_path = self._get_document_indices_path(*source_paths) # (start_idx, end_idx) of every document
        instance_offsets_path = self._get_instance_offsets_path(*source_paths) # each instance's slice into the long doc-ID array
        docs_by_instance_path = self._get_docs_by_instance_path(*source_paths) # which docs make up each instance

        # Load start and end document indices corresponding to instance.
        instance_indices = load_array_slice_into_tensor( # first read the (start, end) pair at this offset
            instance_offsets_path,
            instance_index * 2,
            instance_index * 2 + 2,
            self.indices_dtype,
        ).tolist() # i.e. this packed instance's [instance_start, instance_end) slice into the docs_by_instance array
        instance_start, instance_end = instance_indices

        # Load document IDs corresponding to instance.
        document_ids = load_array_slice_into_tensor( # read the document IDs in [start, end): the docs this packed instance is made of
            docs_by_instance_path,
            instance_start,
            instance_end,
            self.indices_dtype,
        ).tolist()

        # Load token IDs and maybe label masks for each document. With the document IDs we can now look up the real token boundaries.
        document_token_ids: List[torch.Tensor] = []
        document_label_masks: Optional[List[torch.Tensor]] = (
            None if label_mask_paths is None else []
        )
        for document_id in document_ids: # for each doc, look up its group-wide boundaries
            document_indices = load_array_slice_into_tensor( # read the group-global (start, end) first
                document_indices_path, document_id * 2, document_id * 2 + 2, self.indices_dtype
            ).tolist()
            document_start, document_end = document_indices

            # Pick out the right source file from the source group by comparing the starting
            # index (in tokens) of the document to the starting index of each source within the group.
            source_path: Optional[PathOrStr] = None
            label_mask_path: Optional[PathOrStr] = None
            source_start = 0 # the files are logically concatenated; map the global offset back to a physical file
            for i, (source_path, source_size) in enumerate(zip(source_paths, source_sizes)):
                if source_start <= document_start < (source_start + source_size): # the document falls inside this file
                    document_start -= source_start # document_start - source_start gives the in-file start
                    document_end -= source_start # likewise for the in-file end
                    if label_mask_paths is not None: # if label masks were given, pick the one for this path
                        label_mask_path = label_mask_paths[i]
                    break
                else:
                    source_start += source_size # not in this file; advance and check the next one
            else:
                raise RuntimeError("we shouldn't be here!")

            assert source_path is not None
            document_token_ids.append( # now we can read from the concrete file between the in-file start and end
                load_array_slice_into_tensor(source_path, document_start, document_end, self.dtype)
            )
            if label_mask_path is not None:
                assert document_label_masks is not None
                document_label_masks.append(
                    load_array_slice_into_tensor( # read the label mask from the corresponding file
                        label_mask_path, document_start, document_end, np.bool_
                    )
                )

        # Combine token IDs and maybe label masks for each document.
        input_ids = torch.cat(document_token_ids) # once everything is read, concatenate; the result is the final input ids
        label_mask = None if document_label_masks is None else torch.cat(document_label_masks)

        # Pad to target sequence length.
        pad_shape = (0, self.sequence_length - input_ids.numel()) # pad if too short, same logic as the previous class
        if label_mask is not None:
            label_mask = F.pad(label_mask, pad_shape, value=False)
        else:
            label_mask = F.pad(torch.ones_like(input_ids, dtype=torch.bool), pad_shape, value=False)
        input_ids = F.pad(input_ids, pad_shape, value=self.pad_token_id)

        # Prepare final output.
        out: Dict[str, Any] = {"input_ids": input_ids, "label_mask": label_mask}
        if self.instance_filter_config is not None:
            out["instance_mask"] = self._validate_instance(input_ids, self.instance_filter_config)
        if self._include_instance_metadata:
            metadata = self._metadata_groups[source_group_index]
            out["metadata"] = deepcopy(metadata)
        if self._generate_doc_lengths:
            out["doc_lens"] = get_document_lengths(
                input_ids, self.eos_token_id, bos_token_id=self.bos_token_id
            )
        return out

    def _get_document_indices_path(self, *source_paths: PathOrStr) -> Path:
        return self._get_indices_path(
            "document-indices",
            *source_paths,
            extra_ids=(self._long_doc_strategy, self.indices_dtype.__name__),
        )

    def _get_instance_offsets_path(self, *source_paths: PathOrStr) -> Path:
        return self._get_indices_path(
            "instance-offsets",
            *source_paths,
            extra_ids=(self._long_doc_strategy, self.indices_dtype.__name__),
        )

    def _get_docs_by_instance_path(self, *source_paths: PathOrStr) -> Path:
        return self._get_indices_path(
            "documents-by-instance",
            *source_paths,
            extra_ids=(self._long_doc_strategy, self.indices_dtype.__name__),
        )

    def _pack_documents_from_source_into_instances(
        self, *source_paths: PathOrStr
    ) -> Tuple[int, int]:
        document_indices_path = self._get_document_indices_path(*source_paths)
        instance_offsets_path = self._get_instance_offsets_path(*source_paths)
        docs_by_instance_path = self._get_docs_by_instance_path(*source_paths)

        instances, document_indices, total_tokens = pack_documents_into_instances( # instances records which documents each instance consists of
            *source_paths,
            max_sequence_length=self.sequence_length,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            dtype=self.dtype,
            indices_dtype=self.indices_dtype,
            long_doc_strategy=self._long_doc_strategy,
        ) # document_indices records each document's (start, end) boundaries; total_tokens is the total token count before packing
        document_indices = document_indices.reshape(-1)
        '''
        Next, instances is turned into two cache arrays. The first is instance_offsets, where each
        pair of numbers is an instance's [start, end) slice into the flattened doc-ID array.
        For example, instances = [[0, 3], [1], [2, 4, 5]] means:

        instance0 consists of documents 0 and 3
        instance1 consists of document 1
        instance2 consists of documents 2, 4 and 5

        So instances contains 3 instances. To write this to the cache it is flattened further:
        docs_by_instance = [0, 3, 1, 2, 4, 5]  (instances flattened into one long array)
        instance_offsets = [0, 2, 2, 3, 3, 6]  (consecutive pairs give each instance's doc slice)
        meaning:
        instance0 uses docs [0, 3],    i.e. docs_by_instance[0:2]
        instance1 uses docs [1],       i.e. docs_by_instance[2:3]
        instance2 uses docs [2, 4, 5], i.e. docs_by_instance[3:6]
        '''
        instance_start_offset = 0
        instance_offsets_list: List[int] = []
        documents_by_instance_list: List[int] = []
        for instance in instances:
            instance_offsets_list.append(instance_start_offset)
            instance_offsets_list.append(instance_start_offset + len(instance))
            instance_start_offset += len(instance)
            documents_by_instance_list.extend(instance) # and the second array, documents_by_instance

        # shape: (num_instances * 2,)
        instance_offsets = np.array(instance_offsets_list, dtype=self.indices_dtype)
        # shape: (num_documents,)
        docs_by_instance = np.array(documents_by_instance_list, dtype=self.indices_dtype)
        # Now the caches can be written to disk.
        write_array_to_disk(document_indices, document_indices_path)
        write_array_to_disk(instance_offsets, instance_offsets_path)
        write_array_to_disk(docs_by_instance, docs_by_instance_path)

        return len(instances), total_tokens

    def _pack_all_documents_into_instances(self):
        # Collect all sources that need to be packed (no cache hit).
        sources_needed: List[List[PathOrStr]] = []
        for source_paths in chunked(self.paths, self.source_group_size):
            document_indices_path = self._get_document_indices_path(*source_paths)
            instance_offsets_path = self._get_instance_offsets_path(*source_paths)
            docs_by_instance_path = self._get_docs_by_instance_path(*source_paths)
            if ( # check whether the cache already exists
                document_indices_path.is_file()
                and instance_offsets_path.is_file()
                and docs_by_instance_path.is_file()
            ):
                log.info(f"Reusing cached packing results for {source_paths}")
            elif source_paths not in sources_needed: # not cached yet: add to the to-do list
                sources_needed.append(source_paths)

        if sources_needed:
            with concurrent.futures.ProcessPoolExecutor() as executor:
                futures = []
                for source_paths in sources_needed:
                    log.info(f"Packing documents from {source_paths} into instances...")
                    future = executor.submit(
                        run_worker_func, # registers weka so cached_path can resolve those paths
                        self._pack_documents_from_source_into_instances,  # process group by group
                        *source_paths,
                    )
                    futures.append(future)

                concurrent.futures.wait(futures, return_when="FIRST_EXCEPTION")

                # Log results.
                for source_paths, future in zip(sources_needed, futures):
                    total_instances, total_tokens = future.result()
                    total_padding = self.sequence_length * total_instances - total_tokens
                    avg_padding = total_padding / total_instances
                    log.info(
                        f"Packed {total_tokens:,} tokens from {source_paths} into {total_instances:,d} instances "
                        f"of sequence length {self.sequence_length:,d} using an average of "
                        f"{avg_padding:.1f} padding tokens per instance."
                    )
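
To make the two cache arrays above concrete, here is a tiny standalone trace (plain Python, not OLMo-core code) of how `instances` is flattened into `instance_offsets` / `docs_by_instance`, and how `__getitem__` slices an instance back out:

```python
# instance -> the documents it was packed from (same toy example as the docstring above)
instances = [[0, 3], [1], [2, 4, 5]]

instance_offsets_list, docs_by_instance, start = [], [], 0
for instance in instances:
    # each instance contributes a [start, end) pair into the flattened doc-ID array
    instance_offsets_list += [start, start + len(instance)]
    docs_by_instance += instance
    start += len(instance)

print(instance_offsets_list)  # [0, 2, 2, 3, 3, 6]
print(docs_by_instance)       # [0, 3, 1, 2, 4, 5]

# __getitem__ reads instance i back by slicing with the i-th offset pair:
i = 2
start_i, end_i = instance_offsets_list[2 * i], instance_offsets_list[2 * i + 1]
print(docs_by_instance[start_i:end_i])  # [2, 4, 5]
```

In the real code these lists are written to disk as numpy arrays, and `load_array_slice_into_tensor` reads exactly the `[2*i, 2*i+2)` and `[start_i, end_i)` slices without loading the whole file.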


class NumpyInterleavedFSLDataset(NumpyPaddedFSLDataset): # note that this inherits from NumpyPaddedFSLDataset
    """
    A version of :class:`NumpyPaddedFSLDataset` that creates a single instance by chunking documents and
    interleaving these chunks. The resulting instances may be padded out to ``sequence_length``.
    """

    def __init__(
        self,
        *paths: PathOrStr,
        sequence_length: int,
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        seed: int,
        docs_per_instance: int, # how many documents make up one instance
        chunks_per_doc: int, # how many chunks each document is split into
        dtype: NumpyUIntTypes = np.uint16,
        metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        include_instance_metadata: Optional[bool] = None,
        instance_filter_config: Optional[InstanceFilterConfig] = None,
        label_mask_paths: Optional[List[PathOrStr]] = None,
        bos_token_id: Optional[int] = None,
        interleaving_exempt_paths: Optional[List[PathOrStr]] = None,
    ):
        if sequence_length % docs_per_instance != 0: # sanity check: every doc gets the same number of tokens
            raise OLMoConfigurationError(
                "'sequence_length' must be a multiple of 'docs_per_instance'"
            )
        if sequence_length % chunks_per_doc != 0: # every chunk has the same size
            raise OLMoConfigurationError("'sequence_length' must be a multiple of 'chunks_per_doc'")

        super().__init__(
            *paths,
            sequence_length=sequence_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            vocab_size=vocab_size,
            dtype=dtype,
            bos_token_id=bos_token_id,
            metadata=metadata,
            include_instance_metadata=include_instance_metadata,
            instance_filter_config=instance_filter_config,
            label_mask_paths=label_mask_paths,
        )
        self._docs_per_instance = docs_per_instance
        self._chunks_per_doc = chunks_per_doc
        self._seed = seed
        self._interleaving_exempt_paths = interleaving_exempt_paths
        self._num_interleaving_exempt_instances: Optional[int] = None
        self._num_interleavable_instances: Optional[int] = None

    @property
    def fingerprint_fields(self) -> Tuple[str, ...]:
        return (
            "vocab_size",
            "pad_token_id",
            "eos_token_id",
            "dtype",
            "_docs_per_instance",
            "_seed",
            "_interleaving_exempt_paths",
            "max_target_sequence_length",
            "bos_token_id",
            "sequence_length",
        )

    def __len__(self) -> int:
        if self._num_instances is None:
            item_size = self.indices_dtype(0).itemsize

            interleavable_indices_path = self._get_interleaveable_indices_path() # count the interleavable doc indices
            num_interleavable_instances = get_file_size(interleavable_indices_path) // item_size

            interleaving_exempt_indices_path = self._get_interleaving_exempt_indices_path()
            num_interleaving_exempt_instances = ( # count the instances that do not take part in interleaving
                (get_file_size(interleaving_exempt_indices_path) // item_size)
                if interleaving_exempt_indices_path.is_file() # only if exempt docs were specified
                else 0
            )

            self._num_interleavable_instances = num_interleavable_instances
            self._num_interleaving_exempt_instances = num_interleaving_exempt_instances
            self._num_instances = ( # one group of docs forms one instance, so divide the doc count by the group size to get the instance count
                num_interleavable_instances // self._docs_per_instance
                + num_interleaving_exempt_instances
            )
        return self._num_instances

    def prepare(self):
        if self.fs_local_rank == 0:
            log.info("Gathering dataset document and interleaving indices...")
            self._write_instance_indices() # inherited from the parent class: one instance per document
            self._write_docs_interleaving_indices() # this one is specific to this class
        barrier()
        len(self)

    def _write_docs_interleaving_indices(self):
        interleavable_indices_path = self._get_interleaveable_indices_path()
        interleaving_exempt_indices_path = self._get_interleaving_exempt_indices_path()
        if interleavable_indices_path.is_file() and ( # check whether the caches already exist
            self._interleaving_exempt_paths is None or interleaving_exempt_indices_path.is_file()
        ): # i.e. both index files exist, or the interleavable index exists and no exemptions are needed
            log.info(
                f"Reusing all document interleaving indices at:\n'{interleavable_indices_path}'"
            )
        else:
            log.info(
                f"Generating all document interleaving indices to:\n'{interleavable_indices_path}'..."
            )

            if self._interleaving_exempt_paths: # exempt paths were specified
                interleaving_exempt_doc_indices = [ # collect their global doc indices
                    instance_num
                    for i_offset, (start, end) in enumerate(self.offsets)
                    for instance_num in range(start, end)
                    if self.paths[i_offset] in self._interleaving_exempt_paths
                ]

                with memmap_to_write( # write the exempt indices
                    interleaving_exempt_indices_path,
                    dtype=self.indices_dtype,
                    shape=(len(interleaving_exempt_doc_indices),),
                ) as interleaving_exempt_indices:
                    interleaving_exempt_indices[:] = interleaving_exempt_doc_indices

                interleavable_doc_indices = sorted( # the set difference is exactly the global indices of docs allowed to interleave
                    set(range(self.offsets[-1][1])) - set(interleaving_exempt_doc_indices)
                )
            else: # otherwise every doc takes part in interleaving
                interleavable_doc_indices = list(range(self.offsets[-1][1]))

            with memmap_to_write( # write the (shuffled) interleavable indices
                interleavable_indices_path,
                dtype=self.indices_dtype,
                shape=(len(interleavable_doc_indices),),
            ) as interleavable_indices:
                interleavable_indices[:] = get_rng(self._seed).permutation(
                    interleavable_doc_indices
                )

    def _remove_special_tokens_and_interleave(
        self,
        tensors: List[torch.Tensor],
        tensors_non_special_indices: List[Tuple[torch.Tensor, ...]],
    ) -> torch.Tensor:
        cleaned_tensors: List[torch.Tensor] = [ # first extract the actual content
            tensor[non_special_indices]
            for tensor, non_special_indices in zip(tensors, tensors_non_special_indices)
        ]

        chunked_tensors = [ # split each doc's content into the configured number of chunks
            cleaned_tensor.tensor_split(self._chunks_per_doc) for cleaned_tensor in cleaned_tensors
        ]
        return torch.cat( # docA [A1, A2, A3], docB [B1, B2, B3], docC [C1, C2, C3] → [A1, B1, C1, A2, B2, C2, A3, B3, C3]
            [
                chunked_tensor[i]
                for i in range(self._chunks_per_doc)
                for chunked_tensor in chunked_tensors
            ]
        )

    def __getitem__(self, index: int) -> Dict[str, Any]: # instance layout: [exempt instances][interleaved instances]
        index = int(index)  # in case this is a numpy int type.
        pos_index = index if index >= 0 else len(self) + index

        assert self._num_interleaving_exempt_instances is not None # this must be initialized by now
        if self._interleaving_exempt_paths and pos_index < self._num_interleaving_exempt_instances: # is this an exempt instance?
            interleaving_exempt_indices_path = self._get_interleaving_exempt_indices_path()
            doc_index = load_array_slice_into_tensor( # if so, look up its doc index and delegate to the parent class
                interleaving_exempt_indices_path,
                pos_index,
                pos_index + 1,
                self.indices_dtype,
            ).tolist()[0]

            return super().__getitem__(doc_index)
        # Instance layout: [exempt instances][interleaved instances],
        pos_index -= self._num_interleaving_exempt_instances # so subtract the exempt count to get a local index into the interleaved part
        assert self._num_interleavable_instances is not None
        assert pos_index < self._num_interleavable_instances

        interleaving_indices_path = self._get_interleaveable_indices_path()
        interleaving_indices = load_array_slice_into_tensor( # read the doc indices for this group
            interleaving_indices_path,
            pos_index * self._docs_per_instance,
            pos_index * self._docs_per_instance + self._docs_per_instance,
            self.indices_dtype,
        ).tolist()

        docs: List[Dict[str, Any]] = []
        for doc_index in interleaving_indices: # process the docs one by one
            doc = super().__getitem__(doc_index) # each doc is read out via the parent class

            # Since _docs_per_instance documents are combined into one instance, each raw doc
            # has to be truncated by that factor so the combined length does not overflow.
            # Shrink the documents down, so that interleaving them does not exceed the sequence length.
            doc["input_ids"] = doc["input_ids"][: self.sequence_length // self._docs_per_instance]
            doc["label_mask"] = doc["label_mask"][: self.sequence_length // self._docs_per_instance]

            docs.append(doc)

        for doc in docs: # sanity check: all docs must have the same set of keys
            if doc.keys() != docs[0].keys():
                raise RuntimeError(
                    f"Trying to interleave documents when dataset docs have different keys: {docs[0].keys()}, {doc.keys()}."
                )

        item: Dict[str, Any] = {}

        docs_non_special_token_indices = []
        for doc in docs: # build a mask of the special tokens: pad, eos and bos
            special_tokens_mask = torch.logical_or(
                doc["input_ids"] == self.pad_token_id, # elementwise comparison against the pad id: True where equal, False elsewhere
                doc["input_ids"] == self.eos_token_id,
            )
            if self.bos_token_id is not None:
                special_tokens_mask = torch.logical_or(
                    special_tokens_mask,
                    doc["input_ids"] == self.bos_token_id,
                )
            # Strip the special tokens first, interleave, then re-add them at the end; this avoids
            # layouts like [BOS, BOS, token, token, EOS, EOS, PAD, PAD, ...].
            non_special_token_indices = torch.nonzero( # torch.nonzero returns the positions of nonzero (True) entries
                torch.logical_not(special_tokens_mask), # negated first, so with nonzero this selects the False, i.e. non-special, positions
                as_tuple=True,
            )
            docs_non_special_token_indices.append(non_special_token_indices)

        item["input_ids"] = self._remove_special_tokens_and_interleave( # gather the content at the collected indices and interleave
            [doc["input_ids"] for doc in docs], docs_non_special_token_indices
        )
        item["label_mask"] = self._remove_special_tokens_and_interleave( # same for the label mask
            [doc["label_mask"] for doc in docs], docs_non_special_token_indices
        )

        # Add bos and eos tokens if there is space after interleaving.
        if self.bos_token_id is not None and len(item["input_ids"]) < self.sequence_length:
            item["input_ids"] = F.pad(item["input_ids"], (1, 0), value=self.bos_token_id)
            item["label_mask"] = F.pad(item["label_mask"], (1, 0), value=True)
        if len(item["input_ids"]) < self.sequence_length:
            item["input_ids"] = F.pad(item["input_ids"], (0, 1), value=self.eos_token_id)
            item["label_mask"] = F.pad(item["label_mask"], (0, 1), value=True)

        pad_shape = (0, self.sequence_length - len(item["input_ids"])) # the padding we stripped earlier also has to come back
        item["input_ids"] = F.pad(item["input_ids"], pad_shape, value=self.pad_token_id)
        item["label_mask"] = F.pad(item["label_mask"], pad_shape, value=False)

        if "instance_mask" in docs[0]: # instance_mask controls whether an instance participates in training
            item["instance_mask"] = all([doc["instance_mask"] for doc in docs]) # every constituent doc must be trainable

        if "metadata" in docs[0]: # if there is metadata, all docs must agree on it
            metadata = docs[0]["metadata"]
            for doc in docs:
                doc_metadata = docs[0]["metadata"] # upstream quirk: this should presumably be doc["metadata"]; as written the check is vacuous
                if metadata != doc_metadata:
                    raise RuntimeError(
                        f"Trying to interleave documents when dataset docs have different metadata: {metadata}, {doc_metadata}."
                    )
            item["metadata"] = metadata

        if "doc_lens" in docs[0]: # doc_lens is not supported when interleaving
            raise RuntimeError("Document lengths unexpectedly found.")

        return item

    def _get_instance_indices_path(self, source_path: PathOrStr) -> Path:
        return self._get_indices_path(
            "instance-indices", source_path, extra_ids=(str(self._docs_per_instance),)
        )

    def _get_interleaveable_indices_path(self) -> Path:
        return self.work_dir / f"dataset-{self.fingerprint}" / "interleavable-docs-indices.npy"

    def _get_interleaving_exempt_indices_path(self) -> Path:
        return (
            self.work_dir / f"dataset-{self.fingerprint}" / "interleaving-exempt-docs-indices.npy"
        )
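
The chunk-interleaving order implemented by `_remove_special_tokens_and_interleave` can be sketched without torch. This is a minimal standalone illustration (plain Python lists instead of tensors; `split` is a rough stand-in for `torch.Tensor.tensor_split`, which gives the first `len % n` chunks one extra element):

```python
def interleave(docs, chunks_per_doc):
    def split(seq, n):  # stand-in for tensor_split: first len(seq) % n chunks get one extra element
        k, m = divmod(len(seq), n)
        out, pos = [], 0
        for i in range(n):
            size = k + (1 if i < m else 0)
            out.append(seq[pos:pos + size])
            pos += size
        return out

    chunked = [split(doc, chunks_per_doc) for doc in docs]
    result = []
    for i in range(chunks_per_doc):      # chunk index first ...
        for doc_chunks in chunked:       # ... then doc index, exactly like the torch.cat above
            result += doc_chunks[i]
    return result

docs = [["A1", "A2", "A3"], ["B1", "B2", "B3"], ["C1", "C2", "C3"]]
print(interleave(docs, 3))  # ['A1', 'B1', 'C1', 'A2', 'B2', 'C2', 'A3', 'B3', 'C3']
```

The nested loop order (chunk index outer, doc index inner) is what produces the A1-B1-C1-A2-... layout rather than simple concatenation.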


@dataclass
class VSLCurriculum:
    """
    Base class for variable sequence length curriculums. These determine the sampling
    probability of batches from each bucket throughout training with a :class:`NumpyVSLDataset`.
    """

    @abstractmethod
    def batches_per_bucket(
        self, dataset: NumpyVSLDataset, global_batch_size: int
    ) -> List[Tuple[int, int]]:
        raise NotImplementedError

    @abstractmethod
    def get_batch_indices(
        self, batches_per_bucket: Sequence[Tuple[int, int]], seed: int
    ) -> np.ndarray:
        raise NotImplementedError

    def get_total_batches(self, batches_per_bucket: Sequence[Tuple[int, int]]) -> int:
        return sum([batches for _, batches in batches_per_bucket])

    def log_buckets(
        self,
        dataset: NumpyVSLDataset,
        global_batch_size: int,
        batches_per_bucket: Sequence[Tuple[int, int]],
    ):
        natural_batches_per_bucket = VSLNaturalCurriculum().batches_per_bucket(
            dataset, global_batch_size
        )
        for i, (seq_len, num_batches) in enumerate(batches_per_bucket):
            num_natural_batches = natural_batches_per_bucket[i][1]
            if num_batches != num_natural_batches:
                log.info(
                    f"- bucket {i}:   sequence length {seq_len:>6d} => {num_batches:>6d} batches "
                    f"used ({num_natural_batches:d} total)"
                )
            else:
                log.info(
                    f"- bucket {i}:   sequence length {seq_len:>6d} => {num_batches:>6d} batches"
                )

    @property
    @abstractmethod
    def short_str(self) -> str:
        """
        Return a unique human-readable identifier for the instance.
        """
        raise NotImplementedError


@dataclass
class VSLNaturalCurriculum(VSLCurriculum):
    """
    Implements the natural curriculum from
    `Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum
    <https://arxiv.org/pdf/2405.13226>`_.
    """

    def batches_per_bucket(
        self, dataset: NumpyVSLDataset, global_batch_size: int
    ) -> List[Tuple[int, int]]:
        batches_per_bucket = []
        for seq_len, num_instances in dataset.instances_per_bucket:
            instances_per_batch = global_batch_size // seq_len
            batches = num_instances // instances_per_batch
            batches_per_bucket.append((seq_len, batches))
        return batches_per_bucket

    def get_batch_indices(
        self, batches_per_bucket: Sequence[Tuple[int, int]], seed: int
    ) -> np.ndarray:
        total_batches = self.get_total_batches(batches_per_bucket)
        batch_indices = np.arange(total_batches, dtype=np.uint32)
        rng = get_rng(seed)
        # Put a batch with the largest sequence length first to catch OOMs early.
        idx = rng.integers(total_batches - batches_per_bucket[-1][1], total_batches)
        batch = batch_indices[idx]
        batch_indices[idx] = batch_indices[0]
        batch_indices[0] = batch
        rng.shuffle(batch_indices[1:])
        return batch_indices

    @property
    def short_str(self) -> str:
        return "vsl-natural"
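
The arithmetic in `batches_per_bucket` is worth spelling out: the global batch size is measured in tokens, so a bucket with sequence length `seq_len` fits `global_batch_size // seq_len` instances per batch. A standalone re-run with made-up bucket sizes (the numbers here are hypothetical, not from any OLMo mix):

```python
def batches_per_bucket(instances_per_bucket, global_batch_size):
    out = []
    for seq_len, num_instances in instances_per_bucket:
        instances_per_batch = global_batch_size // seq_len  # batch size is in tokens
        out.append((seq_len, num_instances // instances_per_batch))
    return out

# (seq_len, num_instances) per bucket -- made-up numbers for illustration
buckets = [(256, 10_000), (512, 6_000), (1024, 1_000)]
print(batches_per_bucket(buckets, global_batch_size=8192))
# [(256, 312), (512, 375), (1024, 125)]
```

Note the floor divisions: leftover instances that do not fill a whole batch are simply dropped, which is why the growth curriculums below talk about dropping even more data when balancing.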


@dataclass
class VSLGrowthCurriculum(VSLCurriculum):
    """
    A base class for growth curriculums, like :class:`VSLGrowP2Curriculum` and :class:`VSLGrowLinearCurriculum`.
    """

    num_cycles: int = 8
    """
    The number of cycles in the curriculum.
    """
    balanced: bool = False
    """
    Whether or not to balance the number of batches in each bucket.

    .. note::
        Balancing the number of batches requires dropping more data.
    """

    def batches_per_bucket(
        self, dataset: NumpyVSLDataset, global_batch_size: int
    ) -> List[Tuple[int, int]]:
        actual_batches_per_bucket = VSLNaturalCurriculum().batches_per_bucket(
            dataset, global_batch_size
        )
        if self.balanced:
            batches_per_bucket = min([batches for _, batches in actual_batches_per_bucket])
            batches_per_bucket = self.num_cycles * (batches_per_bucket // self.num_cycles)
            return [(seq_len, batches_per_bucket) for seq_len, _ in actual_batches_per_bucket]
        else:
            return [
                (seq_len, self.num_cycles * (batches_per_bucket // self.num_cycles))
                for seq_len, batches_per_bucket in actual_batches_per_bucket
            ]

    def get_cycle_distribution(
        self, indices: np.ndarray, batches_per_bucket: Sequence[Tuple[int, int]], cycle: int = 0
    ) -> List[List[int]]:
        cycle_length = indices.shape[0] // self.num_cycles
        cycle_indices = indices[cycle * cycle_length : (cycle * cycle_length) + cycle_length]
        distribution: List[List[int]] = []
        for subcycle in np.array_split(cycle_indices, len(batches_per_bucket)):
            distribution.append([])
            bucket_offset_start = 0
            bucket_offset_end = 0
            for _, num_batches in batches_per_bucket:
                bucket_offset_end += num_batches
                count = ((subcycle >= bucket_offset_start) & (subcycle < bucket_offset_end)).sum()
                distribution[-1].append(count)
                bucket_offset_start += num_batches
        return distribution

    def get_batch_indices(
        self, batches_per_bucket: Sequence[Tuple[int, int]], seed: int
    ) -> np.ndarray:
        # Shortest sequence length first.
        assert batches_per_bucket[0][0] < batches_per_bucket[-1][0]

        rng = get_rng(seed)
        num_buckets = len(batches_per_bucket)

        log.info(f"Constructing {self.__class__.__name__} curriculum with {num_buckets} buckets")

        cycles: List[np.ndarray] = []
        for cycle in range(self.num_cycles):
            # Now we need to chunk the batch indices *within* each bucket in this cycle into the batch
            # indices for each sub-cycle.
            # At the same time we'll translate those *within* bucket indices into global batch indices
            # by adding the right offset for each bucket.
            all_bucket_subcycle_batches: List[List[np.ndarray]] = []
            for bucket in range(num_buckets):
                # This is how many batches we'll pull from this bucket for each cycle.
                batch_counts_per_cycle_this_bucket = divide_into_buckets(
                    batches_per_bucket[bucket][1], self.num_cycles
                )
                # These are the batch indices *within* this bucket that we'll use for this cycle.
                batches_this_cycle_this_bucket = chunk_array(
                    np.arange(0, batches_per_bucket[bucket][1], dtype=np.uint32),
                    batch_counts_per_cycle_this_bucket,
                )[cycle]
                bucket_offset = sum([b for _, b in batches_per_bucket[:bucket]])
                bucket_subcycle_batch_counts = self._get_num_bucket_batches_for_cycle(
                    bucket, num_buckets, batch_counts_per_cycle_this_bucket[cycle]
                )
                bucket_subcycle_batches = chunk_array(
                    bucket_offset + batches_this_cycle_this_bucket, bucket_subcycle_batch_counts
                )
                all_bucket_subcycle_batches.append(bucket_subcycle_batches)

            # Now we'll build each full sub-cycle by concatenating all of the bucket sub-cycle batches
            # together and shuffling.
            all_subsycles: List[np.ndarray] = []
            for subcycle in range(num_buckets):
                subsycle_batches: List[np.ndarray] = []
                for bucket in range(num_buckets):
                    subsycle_batches.append(all_bucket_subcycle_batches[bucket][subcycle])
                res = np.concatenate(subsycle_batches)
                rng.shuffle(res)
                all_subsycles.append(res)
            del all_bucket_subcycle_batches

            # Finally we can concatenate all of the sub-cycles together to form the complete cycle.
            cycles.append(np.concatenate(all_subsycles))
            del all_subsycles

        indices = np.concatenate(cycles)
        del cycles

        # Make sure the very first batch has the longest sequence length (is from the last bucket).
        # That way OOMs should happen right away.
        final_bucket_start = sum([b for _, b in batches_per_bucket[:-1]])
        first_long_seq_len_batch = np.argmax(indices >= final_bucket_start)
        batch = indices[first_long_seq_len_batch]
        indices[first_long_seq_len_batch] = indices[0]
        indices[0] = batch

        assert indices.shape[0] == self.get_total_batches(batches_per_bucket)
        return indices

    @classmethod
    @abstractmethod
    def _get_bucket_odds_for_cycle(cls, bucket_idx: int, num_buckets: int) -> List[int]:
        raise NotImplementedError

    @classmethod
    def _get_num_bucket_batches_for_cycle(
        cls, bucket_idx: int, num_buckets: int, num_batches: int
    ) -> List[int]:
        odds = cls._get_bucket_odds_for_cycle(bucket_idx, num_buckets)
        divisor = sum(odds)
        props = [o / divisor for o in odds]
        out = []
        total = 0
        for p in props:
            n = round(p * num_batches)
            total += n
            out.append(n)
        if total < num_batches:
            out[-1] += num_batches - total
        return out
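
`_get_num_bucket_batches_for_cycle` turns a bucket's odds into per-sub-cycle batch counts: proportional rounding, with any shortfall absorbed by the last sub-cycle (note that, as in the original, an overshoot from rounding up is not corrected). A standalone re-implementation (`split_batches_by_odds` is a hypothetical name) shows the behavior:

```python
def split_batches_by_odds(odds, num_batches):
    """Proportionally split num_batches according to integer odds,
    absorbing any rounding shortfall into the last sub-cycle."""
    divisor = sum(odds)
    out = []
    total = 0
    for o in odds:
        n = round(o / divisor * num_batches)
        out.append(n)
        total += n
    if total < num_batches:
        out[-1] += num_batches - total
    return out

# E.g. Grow-P2 odds [1, 2, 4, 8] over 100 batches:
print(split_batches_by_odds([1, 2, 4, 8], 100))  # -> [7, 13, 27, 53]
```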


@dataclass
class VSLGrowP2Curriculum(VSLGrowthCurriculum):
    """
    Implements the "Grow-P2" curriculum from
    `Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum
    <https://arxiv.org/pdf/2405.13226>`_.
    """

    @classmethod
    def _get_bucket_odds_for_cycle(cls, bucket_idx: int, num_buckets: int) -> List[int]:
        all_odds = []
        start_odds = num_buckets - bucket_idx
        for cycle in range(num_buckets):
            exp = (
                start_odds + cycle
                if start_odds + cycle <= num_buckets
                else start_odds - ((start_odds + cycle) % num_buckets)
            )
            all_odds.append(2 ** (exp - 1))
        return all_odds

    @property
    def short_str(self) -> str:
        return f"vsl-grow-p2-{self.num_cycles}-cycle{'-balanced' if self.balanced else ''}"


@dataclass
class VSLGrowLinearCurriculum(VSLGrowthCurriculum):
    """
    Implements the "Grow-Linear" curriculum from
    `Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum
    <https://arxiv.org/pdf/2405.13226>`_.
    """

    @classmethod
    def _get_bucket_odds_for_cycle(cls, bucket_idx: int, num_buckets: int) -> List[int]:
        all_odds = []
        start_odds = num_buckets - bucket_idx
        for cycle in range(num_buckets):
            odds = (
                start_odds + cycle
                if start_odds + cycle <= num_buckets
                else start_odds - ((start_odds + cycle) % num_buckets)
            )
            all_odds.append(odds)
        return all_odds

    @property
    def short_str(self) -> str:
        return f"vsl-grow-linear-{self.num_cycles}-cycle{'-balanced' if self.balanced else ''}"
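
The two curricula differ only in how the cycled value is scaled: Grow-Linear uses it directly, Grow-P2 raises 2 to it. For `num_buckets = 4`, bucket 0 (shortest sequences) gets its highest odds in the first sub-cycle while the longest bucket peaks last, which is the "grow" in the names. A sketch that mirrors both `_get_bucket_odds_for_cycle` implementations (`grow_odds` is a hypothetical helper):

```python
def grow_odds(bucket_idx, num_buckets, p2=False):
    """Odds of drawing from `bucket_idx` in each sub-cycle, mirroring
    the Grow-Linear (p2=False) and Grow-P2 (p2=True) schedules."""
    start = num_buckets - bucket_idx
    odds = []
    for cycle in range(num_buckets):
        o = (
            start + cycle
            if start + cycle <= num_buckets
            else start - ((start + cycle) % num_buckets)
        )
        odds.append(2 ** (o - 1) if p2 else o)
    return odds

for b in range(4):
    print(b, grow_odds(b, 4), grow_odds(b, 4, p2=True))
```

For instance, bucket 0 yields linear odds `[4, 3, 2, 1]` (P2: `[8, 4, 2, 1]`) and bucket 3 yields `[1, 2, 3, 4]` (P2: `[1, 2, 4, 8]`).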


class NumpyVSLDataset(NumpyDatasetBase, Dataset[Dict[str, Any]]):
    """
    A variable sequence length (VSL) numpy array-backed dataset. This is used to inject a sequence
    length-based curriculum during training as introduced in
    `Dataset Decomposition: Faster LLM Training with Variable Sequence Length Curriculum
    <https://arxiv.org/pdf/2405.13226>`_.

    This dataset creates instances of token IDs with lengths that are powers of 2
    between ``min_sequence_length`` (which must be a power of 2) and ``max_sequence_length``
    (also a power of 2). Some tokens will be discarded unless ``min_sequence_length`` is 1.

    .. important::
        No special tokens are added to the input IDs so it's assumed that if you want
        EOS tokens between documents, for example, those will already be in the array.

    :param paths: Paths or URLs to numpy token ID arrays.
    :param pad_token_id: The ID of the padding token.
    :param eos_token_id: The ID of the EOS token.
    :param max_sequence_length: The maximum allowed sequence length. A power of 2, e.g. '4096'.
    :param min_sequence_length: The minimum allowed sequence length. A power of 2, e.g. '256'.
    :param curriculum: The variable sequence length curriculum. Determines the sampling
        probability of batches from each bucket throughout training.
    :param dtype: The numpy datatype of the arrays.
    :param metadata: Metadata to add to each item. This should be a dictionary or a list of dictionaries
        with the same number of items as there are paths.
    :param include_instance_metadata: If ``True`` (the default), each instance returned from
        :meth:`__getitem__()` will include the metadata from its source.
    """

    def __init__(
        self,
        *paths: PathOrStr,
        pad_token_id: int,
        eos_token_id: int,
        vocab_size: int,
        max_sequence_length: int,
        min_sequence_length: int = 256,
        curriculum: Optional[VSLCurriculum] = None,
        dtype: NumpyUIntTypes = np.uint16,
        metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        include_instance_metadata: Optional[bool] = None,
        instance_filter_config: Optional[InstanceFilterConfig] = None,
    ):
        if math.log(max_sequence_length, 2) % 1 != 0:
            raise OLMoConfigurationError("'max_sequence_length' must be a power of 2")

        if math.log(min_sequence_length, 2) % 1 != 0:
            raise OLMoConfigurationError("'min_sequence_length' must be a power of 2")

        if max_sequence_length <= min_sequence_length:
            raise OLMoConfigurationError(
                "'max_sequence_length' should be bigger than 'min_sequence_length'"
            )

        if include_instance_metadata is None and metadata:
            include_instance_metadata = True

        if isinstance(metadata, list):
            if len(metadata) != len(paths):
                raise OLMoConfigurationError(
                    "'metadata' should have the same length as the number of file paths"
                )
        else:
            metadata = [metadata or {}] * len(paths)

        super().__init__(
            *paths,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            vocab_size=vocab_size,
            dtype=dtype,
        )
        self._metadata = metadata
        self._include_instance_metadata = include_instance_metadata
        self._max_sequence_length = max_sequence_length
        self._min_sequence_length = min_sequence_length
        self._curriculum = curriculum or VSLNaturalCurriculum()
        self._num_instances: Optional[int] = None
        self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None
        self._lengths_dtype: Optional[NumpyUIntTypes] = None
        self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None
        self.instance_filter_config = instance_filter_config

    @property
    def fingerprint_fields(self) -> Tuple[str, ...]:
        """
        Extra values to include when calculating the data contents :data:`fingerprint`.
        """
        return (
            "vocab_size",
            "pad_token_id",
            "eos_token_id",
            "dtype",
            "min_sequence_length",
            "max_sequence_length",
            "curriculum",
        )

    @property
    def max_sequence_length(self) -> int:
        return self._max_sequence_length

    @property
    def min_sequence_length(self) -> int:
        return self._min_sequence_length

    @property
    def curriculum(self) -> VSLCurriculum:
        return self._curriculum

    @property
    def all_sequence_lengths(self) -> List[int]:
        min_exp = int(math.log(self.min_sequence_length, 2))
        max_exp = int(math.log(self.max_sequence_length, 2))
        return [2**exp for exp in range(min_exp, max_exp + 1)]

    @property
    def offsets(self) -> Tuple[Tuple[int, int], ...]:
        """
        Gives the global start and end instance indices for each data file in the dataset.
        """
        if self._array_offsets is None:
            array_offsets = []
            item_size = self.indices_dtype(0).itemsize
            start_offset = 0
            for path in self.paths:
                doc_indices_path = self._get_document_indices_path(path)
                instances_in_file = (get_file_size(doc_indices_path) // item_size) // 2
                end_offset = start_offset + instances_in_file
                array_offsets.append((start_offset, end_offset))
                start_offset += instances_in_file
            self._array_offsets = tuple(array_offsets)
        return self._array_offsets

    def prepare(self):
        if self.fs_local_rank == 0:
            log.info("Gathering dataset document indices and buckets...")
            self._write_document_indices()
            self._write_instance_lengths()
            self._write_instance_buckets(self.get_instance_lengths())
        barrier()
        len(self)  # touch __len__ once so the instance count gets cached

    def __len__(self):
        if self._num_instances is None:
            self._num_instances = self.offsets[-1][1]
        return self._num_instances

    def __getitem__(self, index: int) -> Dict[str, Any]:
        index = int(index)  # in case this is a numpy int type.
        pos_index = index if index >= 0 else len(self) + index

        # The index of the array within 'self.paths'.
        array_index: Optional[int] = None
        # The index within the corresponding array.
        array_local_index: Optional[int] = None
        for i, (offset_start, offset_end) in enumerate(self.offsets):
            if offset_start <= pos_index < offset_end:
                array_index = i
                array_local_index = pos_index - offset_start
                break

        if array_index is None or array_local_index is None:
            raise IndexError(f"{index} is out of bounds for dataset of size {len(self)}")

        # Read the data from file.
        input_ids = self._read_chunk_from_array(self.paths[array_index], array_local_index)
        out: Dict[str, Any] = {"input_ids": input_ids}

        if self.instance_filter_config is not None:
            out["instance_mask"] = self._validate_instance(input_ids, self.instance_filter_config)

        if self._include_instance_metadata:
            metadata = self._metadata[array_index]
            out["metadata"] = deepcopy(metadata)

        return out

    def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor:
        indices_path = self._get_document_indices_path(path)
        indices = load_array_slice_into_tensor(
            indices_path, index * 2, index * 2 + 2, self.indices_dtype
        )
        start_idx, end_idx = indices
        data = load_array_slice_into_tensor(path, int(start_idx), int(end_idx), self.dtype)
        return data
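
`_read_chunk_from_array` relies on the cached document-indices file being a flat array of interleaved `(start, end)` token offsets, so instance `i` lives at slots `2*i` and `2*i + 1`. A toy version of the same layout with plain in-memory arrays (no memmap, names made up):

```python
import numpy as np

# Token stream plus the flat document-indices array the dataset caches:
# (start, end) offset pairs stored back-to-back.
tokens = np.arange(20, dtype=np.uint16)
doc_indices = np.array([0, 4, 4, 12, 12, 20], dtype=np.uint32)  # three instances

def read_instance(index):
    # Slot 2*i is the start offset, slot 2*i + 1 the end offset.
    start, end = doc_indices[index * 2 : index * 2 + 2]
    return tokens[int(start) : int(end)]

print(read_instance(1))  # tokens 4..11
```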

    def _get_document_indices_path(self, path: PathOrStr) -> Path:
        sha256_hash = hashlib.sha256()
        sha256_hash.update(str(path).encode())
        sha256_hash.update(str(self._get_file_size(path)).encode())
        for seq_len in self.all_sequence_lengths:
            sha256_hash.update(str(seq_len).encode())
        path_hash = sha256_hash.hexdigest()
        return self.work_dir / "dataset-common" / f"bucketed-doc-indices-{path_hash}.npy"

    def _get_instance_lengths_path(self) -> Path:
        return self.work_dir / f"dataset-{self.fingerprint}" / "instance-lengths.npy"

    def _get_instance_bucket_path(self, seq_len: int) -> Path:
        return self.work_dir / f"dataset-{self.fingerprint}" / f"bucket{seq_len}-indices.npy"

    def _write_document_indices(self):
        paths_needed: List[PathOrStr] = []
        for path in self.paths:
            indices_path = self._get_document_indices_path(path)
            if indices_path.is_file():
                log.info(f"Reusing document indices for '{path}' at:\n'{indices_path}'")
            elif path not in paths_needed:
                paths_needed.append(path)

        if paths_needed:
            with concurrent.futures.ProcessPoolExecutor() as executor:
                futures = []
                for path in paths_needed:
                    indices_path = self._get_document_indices_path(path)
                    log.info(f"Gathering document indices for '{path}'...")
                    future = executor.submit(
                        run_worker_func,
                        bucket_documents,
                        path,
                        indices_path,
                        buckets=self.all_sequence_lengths,
                        eos_token_id=self.eos_token_id,
                        dtype=self.dtype,
                        indices_dtype=self.indices_dtype,
                    )
                    futures.append(future)

                concurrent.futures.wait(futures, return_when="FIRST_EXCEPTION")

                # Log results.
                for path, future in zip(paths_needed, futures):
                    total_og_docs, total_bucketed_docs = future.result()
                    log.info(
                        f"Created {total_bucketed_docs:,d} bucketed documents by sequence length from "
                        f"{total_og_docs:,d} original documents in '{path}'"
                    )

    def _write_instance_lengths(self):
        instance_lengths_path = self._get_instance_lengths_path()
        if instance_lengths_path.is_file():
            log.info(f"Reusing all instance lengths at:\n'{instance_lengths_path}'")
        else:
            log.info(f"Gathering all instance lengths to:\n'{instance_lengths_path}'...")
            with memmap_to_write(
                instance_lengths_path, dtype=self.lengths_dtype, shape=(len(self),)
            ) as instance_lengths:
                for path, (offset_start, offset_end) in zip(self.paths, self.offsets):
                    indices_path = self._get_document_indices_path(path)
                    indices_mmap = np.memmap(indices_path, dtype=self.indices_dtype, mode="r")
                    instance_lengths[offset_start:offset_end] = get_doc_lengths_from_indices(
                        indices_mmap
                    )
                    del indices_mmap

    def _write_instance_buckets(self, instance_lengths: np.ndarray):
        for seq_len in self.all_sequence_lengths:
            bucket_path = self._get_instance_bucket_path(seq_len)
            if bucket_path.is_file():
                log.info(
                    f"Reusing instance indices for seq len {seq_len} bucket at:\n'{bucket_path}'"
                )
            else:
                log.info(f"Gathering instance indices for seq len {seq_len} bucket...")
                bucket_path.parent.mkdir(exist_ok=True, parents=True)
                instance_indices = (instance_lengths == seq_len).nonzero()[0]
                with memmap_to_write(
                    bucket_path,
                    dtype=self.indices_dtype,
                    shape=instance_indices.shape,
                ) as bucket:
                    bucket[:] = instance_indices
                log.info(
                    f"Instance indices for seq len {seq_len} bucket written to:\n'{bucket_path}'"
                )

    def get_instance_lengths(self) -> np.ndarray:
        """
        Get a numpy memory-mapped array with the length of every instance in the dataset.
        """
        return np.memmap(self._get_instance_lengths_path(), dtype=self.lengths_dtype, mode="r")

    def get_instance_bucket(self, seq_len: int) -> np.ndarray:
        """
        Get the instance indices in a bucket.
        """
        return np.memmap(
            self._get_instance_bucket_path(seq_len), dtype=self.indices_dtype, mode="r"
        )

    def get_instance_buckets(self) -> List[Tuple[int, np.ndarray]]:
        """
        Get the buckets of instance indices that all have the same length.
        The buckets will be sorted from smallest sequence length to longest.
        """
        buckets = []
        for seq_len in self.all_sequence_lengths:
            buckets.append((seq_len, self.get_instance_bucket(seq_len)))
        return buckets

    @property
    def instances_per_bucket(self) -> Tuple[Tuple[int, int], ...]:
        """
        The number of instances in each bucket.
        """
        if self._instances_per_bucket is None:
            instances_per_bucket = []
            item_size = self.indices_dtype(0).itemsize
            for seq_len in self.all_sequence_lengths:
                instances_per_bucket.append(
                    (seq_len, get_file_size(self._get_instance_bucket_path(seq_len)) // item_size)
                )
            self._instances_per_bucket = tuple(instances_per_bucket)
        return self._instances_per_bucket

    @property
    def indices_dtype(self) -> NumpyUIntTypes:
        return np.uint32

    @property
    def lengths_dtype(self) -> NumpyUIntTypes:
        if self._lengths_dtype is None:
            for dtype in (
                np.uint8,
                np.uint16,
                np.uint32,
                np.uint64,
            ):
                if (self.max_sequence_length - 1) <= np.iinfo(dtype).max:
                    self._lengths_dtype = dtype
                    break
            assert self._lengths_dtype is not None
        return self._lengths_dtype
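
Putting `all_sequence_lengths` and `_write_instance_buckets` together: the sequence lengths are the powers of 2 from min to max, and each bucket is just the indices where the instance-length array equals that power. A small self-contained sketch of the same grouping:

```python
import math
import numpy as np

min_len, max_len = 256, 1024
# Powers of 2 from min_sequence_length up to max_sequence_length.
seq_lens = [2 ** e for e in range(int(math.log2(min_len)), int(math.log2(max_len)) + 1)]
# -> [256, 512, 1024]

# Per-instance lengths, as they would be written to instance-lengths.npy.
instance_lengths = np.array([256, 1024, 512, 256, 1024], dtype=np.uint16)

# One bucket of instance indices per sequence length, shortest first.
buckets = {L: (instance_lengths == L).nonzero()[0] for L in seq_lens}
print({L: idx.tolist() for L, idx in buckets.items()})
```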


class VSLCurriculumType(StrEnum):
    """
    An enumeration of the different VSL curriculum implementations.
    """

    natural = "natural"
    """
    The natural curriculum ➡️ :class:`VSLNaturalCurriculum`.
    """

    grow_p2 = "grow_p2"
    """
    The "Grow-P2" curriculum ➡️ :class:`VSLGrowP2Curriculum`.
    """

    grow_linear = "grow_linear"
    """
    The "Grow-Linear" curriculum ➡️ :class:`VSLGrowLinearCurriculum`.
    """


@dataclass
class VSLCurriculumConfig(Config):
    name: VSLCurriculumType = VSLCurriculumType.natural
    num_cycles: Optional[int] = None
    balanced: Optional[bool] = None

    def validate(self):
        if self.name == VSLCurriculumType.natural:
            self.num_cycles = None
            self.balanced = None

    def build(self) -> VSLCurriculum:
        """
        Build the VSL curriculum.
        """
        if self.name == VSLCurriculumType.natural:
            if self.num_cycles is not None:
                raise OLMoConfigurationError(
                    f"'num_cycles' is not a valid field for the {self.name} curriculum"
                )
            if self.balanced is not None:
                raise OLMoConfigurationError(
                    f"'balanced' is not a valid field for the {self.name} curriculum"
                )
            return VSLNaturalCurriculum()

        if self.name in (VSLCurriculumType.grow_p2, VSLCurriculumType.grow_linear):
            if self.num_cycles is None:
                raise OLMoConfigurationError(
                    f"'num_cycles' is required for the {self.name} curriculum"
                )
            if self.balanced is None:
                raise OLMoConfigurationError(
                    f"'balanced' is required for the {self.name} curriculum"
                )

            if self.name == VSLCurriculumType.grow_p2:
                return VSLGrowP2Curriculum(num_cycles=self.num_cycles, balanced=self.balanced)
            else:  # grow_linear
                return VSLGrowLinearCurriculum(num_cycles=self.num_cycles, balanced=self.balanced)

        raise NotImplementedError(self.name)


NumpyDatasetConfigT = TypeVar("NumpyDatasetConfigT", bound="NumpyDatasetConfig")


@dataclass(kw_only=True)
class NumpyDatasetConfig(Config, ABC):  # abstract base class
    """
    Abstract base configuration class for numpy-based datasets.

    This abstract base class provides common configuration options and utilities
    for creating :class:`NumpyDatasetBase` datasets.
    """

    tokenizer: TokenizerConfig
    """
    The tokenizer config.
    """
    paths: Optional[List[str]] = None  # mutually exclusive with 'mix' (the data-mix name) below
    """
    The paths/URLs to the numpy token ID arrays.
    """
    mix: Optional[Union[str, DataMixBase]] = None  # mutually exclusive with 'paths' above
    """
    The name of a data mix (e.g. ``"dolma17"``).
    """
    mix_base_dir: Optional[str] = None
    """
    The base directory of the data mix.
    """
    expand_glob: bool = False
    """
    If True, treat the :data:`paths` as globs.
    """
    dtype: Optional[NumpyDatasetDType] = None  # can be set explicitly or inferred automatically
    """
    The numpy datatype of the token ID arrays.
    """
    metadata: Optional[List[Dict[str, Any]]] = None
    """
    Metadata for the numpy arrays.
    """
    include_instance_metadata: bool = True
    """
    Whether or not to include the :data:`metadata` in the instances returned from
    :meth:`NumpyDatasetBase.__getitem__()`.
    """
    instance_filter_config: Optional[InstanceFilterConfig] = None
    """
    The instance filter config (aka the "ngram filter") that will be applied to the dataset. This
    can be used to filter out instances with too many repeated token ngrams.
    """
    source_permutation_seed: Optional[int] = None
    """
    Used to shuffle the source files before handing off to the dataset class.
    """
    work_dir: Optional[str] = None
    """
    The dataset working directory. This is used to cache working files like shuffled indices,
    instance buckets, etc.

    .. tip::
        You can save a lot of time and disk space by setting this to a common directory across
        all of your runs.
    """
    ignore_fingerprint_mismatch: bool = False
    """
    If True, ignore dataset fingerprint mismatches when loading from a checkpoint.
    This is used when intentionally switching to a different dataset mix.
    """

    @abstractmethod
    def build(self) -> NumpyDatasetBase:  # abstract method that concrete subclasses must implement
        """
        Build and return a NumpyDatasetBase instance from this configuration.

        :returns: The constructed dataset instance.
        """
        raise NotImplementedError

    def get_dtype(self) -> NumpyUIntTypes:  # decide which unsigned integer dtype to use when reading the token arrays
        if self.dtype is not None:  # if set explicitly, use it
            return NumpyDatasetDType(self.dtype).as_np_dtype()

        for dtype in (
            NumpyDatasetDType.uint8,
            NumpyDatasetDType.uint16,
            NumpyDatasetDType.uint32,
            NumpyDatasetDType.uint64,
        ):  # otherwise pick the smallest dtype that can hold every token ID, based on the vocab size
            if (self.tokenizer.vocab_size - 1) <= np.iinfo(dtype.as_np_dtype()).max:
                log.info(f"Assuming dtype '{dtype}' based on vocab size")
                return dtype.as_np_dtype()

        raise ValueError("vocab size too big!")
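
`get_dtype` picks the smallest unsigned dtype that can represent every token ID (IDs range over `0 .. vocab_size - 1`). A self-contained version of the same selection, using a hypothetical helper name:

```python
import numpy as np

def smallest_uint_dtype(vocab_size):
    """Pick the smallest unsigned dtype that can hold every token ID
    (0 .. vocab_size - 1), mirroring get_dtype()."""
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if vocab_size - 1 <= np.iinfo(dtype).max:
            return dtype
    raise ValueError("vocab size too big!")

print(smallest_uint_dtype(50_257))   # GPT-2-sized vocab fits in uint16
print(smallest_uint_dtype(100_352))  # anything past 65536 needs uint32
```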

    def _expand_globs(self, patterns: Sequence[str]) -> List[str]:  # expand glob patterns into a list of files
        expanded: List[str] = []
        for pattern in patterns:
            log.info(f"Expanding '{pattern}'...")
            matches = deterministic_glob_directory(pattern)  # uses the matching logic from the io module
            if not matches:
                error_msg = f"Pattern '{pattern}' did not match any files"
                # Add helpful hint for mix-0625 which has unavailable files
                if "0625" in pattern:
                    error_msg += (
                        "\n\nNOTE: Some files in OLMo-mix-0625 are not available. "
                        "If you are resuming training from a checkpoint that used mix-0625, you will need to "
                        "switch to a newer mix such as OLMo-mix-0925. To continue training with a different "
                        "dataset mix, set 'ignore_fingerprint_mismatch=True' in your NumpyDataLoaderConfig "
                        "to bypass the fingerprint mismatch error. This will probably result in a different data order!"
                    )
                raise FileNotFoundError(error_msg)
            for match in matches:
                log.info(f" - '{match}'")
            expanded.extend(matches)
        return expanded

    def _resolve_paths_metadata(  # resolve into paths + metadata + label_mask_paths
        self,
        *,
        allow_mix: bool,
        label_mask_paths: Optional[Sequence[PathOrStr]] = None,
    ) -> Tuple[List[str], Optional[List[Dict[str, Any]]], Optional[List[PathOrStr]]]:
        if self.paths is not None and self.mix is not None:  # only one of explicit paths or a named mix may be given
            raise OLMoConfigurationError("Only one of 'paths' or 'mix' can be set")

        metadata: Optional[List[Dict[str, Any]]] = self.metadata  # optional metadata
        resolved_label_masks: Optional[List[PathOrStr]] = None  # holds the resolved result

        if self.paths is not None:  # the 'paths' resolution branch
            raw_paths = [str(path) for path in self.paths]
            if self.expand_glob:  # glob expansion was requested
                paths = self._expand_globs(raw_paths)  # collect the matched paths
                if label_mask_paths is not None:  # if label mask paths were given, glob-expand them the same way
                    mask_patterns = [str(path) for path in label_mask_paths]
                    expanded_masks = self._expand_globs(mask_patterns)
                    resolved_label_masks = [cast(PathOrStr, mask) for mask in expanded_masks]
            else:  # otherwise the paths are directly accessible
                paths = raw_paths
                if label_mask_paths is not None:
                    resolved_label_masks = [cast(PathOrStr, path) for path in label_mask_paths]
        else:  # the 'mix' resolution branch
            if self.mix is None:  # at this point a mix must have been given
                raise OLMoConfigurationError("Either 'paths' or 'mix' must be set")
            if not allow_mix:  # whether mixes are supported is left to subclasses; here they may be disallowed
                raise OLMoConfigurationError("'mix' is not supported for this dataset type")
            if self.mix_base_dir is None:  # building from a mix requires a base directory
                raise OLMoConfigurationError(
                    "'mix_base_dir' is required to build a dataset from a mix"
                )
            if self.tokenizer.identifier is None:
                raise OLMoConfigurationError(
                    "Missing tokenizer identifier required to construct data mix"
                )
            mix = self.mix
            if not isinstance(mix, DataMixBase):  # type check and conversion, using the classes from mixes/
                mix = DataMix(mix)
            paths, labels = mix.build(self.mix_base_dir, self.tokenizer.identifier)  # resolve paths and labels (e.g. longmino)
            paths = [str(path) for path in paths]  # convert the paths to str first
            if metadata is None:  # inject labels into the metadata; these labels tag the source dataset, not training targets
                metadata = [{"label": label} for label in labels]
            if label_mask_paths is not None:  # convert the label mask paths
                resolved_label_masks = [cast(PathOrStr, path) for path in label_mask_paths]

        if self.source_permutation_seed is not None:  # if a seed for shuffling the paths was given, shuffle them
            order = list(range(len(paths)))  # generate an ordering
            rng = random.Random(self.source_permutation_seed)
            rng.shuffle(order)  # shuffle it
            paths = [paths[i] for i in order]  # reorder the paths by the shuffled order
            if metadata is not None:  # reorder the metadata and label masks the same way
                metadata = [metadata[i] for i in order]
            if resolved_label_masks is not None:
                resolved_label_masks = [resolved_label_masks[i] for i in order]

        return paths, metadata, resolved_label_masks
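
The source-permutation step shuffles an index list with a seeded `random.Random` and then applies the same order to every parallel list, so paths, metadata, and label masks stay aligned and the shuffle is reproducible across runs. A minimal sketch with made-up file names:

```python
import random

paths = ["a.npy", "b.npy", "c.npy", "d.npy"]
metadata = [{"label": p} for p in paths]

# Shuffle an index list with a fixed seed, then reorder every parallel
# list the same way so paths and metadata stay aligned.
order = list(range(len(paths)))
random.Random(42).shuffle(order)
paths = [paths[i] for i in order]
metadata = [metadata[i] for i in order]

# Each metadata entry still points at its own path.
assert all(m["label"] == p for m, p in zip(metadata, paths))
```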

    def _finalize(self, dataset: NumpyDatasetBase) -> NumpyDatasetBase:
        if self.work_dir is not None:  # a working cache directory was specified
            dataset.work_dir = Path(self.work_dir)  # attach it as a Path object
        return dataset

    @classmethod
    def glob(  # a classmethod: the final return constructs a config object whose paths are treated as glob patterns
        cls: Type[NumpyDatasetConfigT], *glob_paths: str, **kwargs: Any
    ) -> NumpyDatasetConfigT:
        """
        Initialize a dataset config with glob paths.

        .. note::
            Globs are not expanded until :meth:`build()` is called.
            If any of the globs don't expand to any matches, a :class:`FileNotFoundError`
            is raised.

        :param glob_paths: The glob patterns.
        :returns: A new dataset config.
        """
        return cls(paths=list(glob_paths), mix=None, mix_base_dir=None, expand_glob=True, **kwargs)

    @classmethod
    def from_data_mix(  # similarly, this classmethod constructs a config object that uses a named mix as the dataset
        cls: Type[NumpyDatasetConfigT],
        mix: Union[str, DataMixBase],
        *,
        tokenizer: TokenizerConfig,
        **kwargs: Any,
    ) -> NumpyDatasetConfigT:
        """
        Initialize a dataset config from an official data mix.

        :param mix: The data mix.
        :param tokenizer: The tokenizer config.
        :returns: A new dataset config.
        """
        if tokenizer.identifier is None:
            raise OLMoConfigurationError(
                "Missing tokenizer identifier required to construct data mix"
            )
        return cls(mix=mix, paths=None, tokenizer=tokenizer, **kwargs)


@dataclass
class NumpyFSLDatasetConfig(NumpyDatasetConfig):  # validates the FSL params and builds a NumpyFSLDataset or NumpyFSLDatasetMixture depending on the config
    sequence_length: int  # new relative to the base class
    """
    The length of a single instance. Generally this should correspond to your model's maximum input length.
    """
    max_target_sequence_length: Optional[int] = None  # also new relative to the base class
    """
    Optional upper bound used when precomputing cached offsets. You can leave this unset if you
    aren't doing a warm-up, e.g. starting at sequence_length = 1024 and later growing to 2048, then 4096.
    If you're planning a sequence-length warm-up, set this to the final chunk size so future
    datasets with larger ``sequence_length`` values can reuse the exact same document ordering.
    The current dataset still returns ``sequence_length``-token windows; this hint simply keeps
    token boundaries and cache files deterministic across warm-up stages. Leave ``None`` if you
    won't rebuild at a larger length.
    """
    generate_doc_lengths: bool = False  # whether items returned from __getitem__() include per-document length info
    """
    Include individual document lengths in the instances returned from
    :meth:`NumpyDatasetBase.__getitem__()`.
    """
    label_mask_paths: Optional[List[str]] = None  # paths to bool mask files aligned with the token files, marking which tokens should be masked
    """
    The paths/URLs to numpy bool files indicating which tokens should be masked.
    """
    source_mixture_config: Optional[SourceMixtureDatasetConfig] = None  # from source_mixture.py; describes the source mixture
    """
    A source mixture dataset config. If set, the dataset will be built from a mixture of sources.
    """

    @classmethod
    def from_src_mix(  # a classmethod like those above: constructs an instance from a SourceMixtureDatasetConfig
        cls, src_mix: SourceMixtureDatasetConfig, **kwargs: Any
    ) -> NumpyFSLDatasetConfig:
        """
        Initialize a dataset config from a custom fine-grained data mix.

        :param src_mix: The fine-grained SourceMixtureDatasetConfig.
        :returns: A new dataset config.
        """  # note that paths=None and mix=None are passed explicitly, requiring construction from source_mixture_config
        return cls(source_mixture_config=src_mix, paths=None, mix=None, mix_base_dir=None, **kwargs)

    def validate(self):
        if self.sequence_length <= 0:  # sanity-check the parameter
            raise OLMoConfigurationError("'sequence_length' must be positive")
        if self.source_mixture_config is not None:  # exactly one of paths / mix / source_mixture_config may be set
            if self.paths is not None or self.mix is not None:
                raise OLMoConfigurationError(
                    "Specify only one of 'paths', 'mix', or 'source_mixture_config'"
                )
            if self.label_mask_paths is not None:  # label_mask_paths cannot be combined with source_mixture_config;
                raise OLMoConfigurationError(  # to configure label masks you must build from paths / mix
                    "'label_mask_paths' is not supported alongside 'source_mixture_config'"
                )

    def build(self) -> NumpyDatasetBase:
        self.validate()

        if self.source_mixture_config is not None:  # build from the config provided by source_mixture.py
            mixture = self.source_mixture_config.build(
                npdtype=self.get_dtype(), sequence_length=self.sequence_length
            )
            dataset = NumpyFSLDatasetMixture(  # just unpack and pass the arguments; source_mixture_config only supports NumpyFSLDatasetMixture
                *mixture.to_paths(),
                seed=self.source_mixture_config.seed,
                path_offset_index=mixture.to_index(),
                sequence_length=self.sequence_length,
                max_target_sequence_length=self.max_target_sequence_length,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                vocab_size=self.tokenizer.vocab_size,
                dtype=self.get_dtype(),
                metadata=self.metadata,
                include_instance_metadata=self.include_instance_metadata,
                generate_doc_lengths=self.generate_doc_lengths,
                bos_token_id=self.tokenizer.bos_token_id,
                instance_filter_config=self.instance_filter_config,
            )
            return self._finalize(dataset)
        # otherwise construct from 'paths' / 'mix'; this route allows external label masks
        paths, metadata, label_masks = self._resolve_paths_metadata(
            allow_mix=True, label_mask_paths=self.label_mask_paths
        )
        # the 'paths' / 'mix' route here only builds a NumpyFSLDataset
        dataset = NumpyFSLDataset(
            *paths,
            sequence_length=self.sequence_length,
            max_target_sequence_length=self.max_target_sequence_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            vocab_size=self.tokenizer.vocab_size,
            dtype=self.get_dtype(),
            metadata=metadata,
            include_instance_metadata=self.include_instance_metadata,
            generate_doc_lengths=self.generate_doc_lengths,
            bos_token_id=self.tokenizer.bos_token_id,
            instance_filter_config=self.instance_filter_config,
            label_mask_paths=label_masks,
        )
        return self._finalize(dataset)
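# ---------------------------------------------------------------------------
# Illustrative sketch, NOT part of numpy_dataset.py: the "exactly one of
# paths / mix / source_mixture_config" checks above boil down to counting how
# many of the mutually exclusive options were set. `check_exclusive` below is
# a hypothetical helper, shown only to make the validation pattern explicit.
def check_exclusive(**options):
    given = sorted(name for name, value in options.items() if value is not None)
    if len(given) > 1:
        raise ValueError(f"Specify only one of {sorted(options)}, got {given}")

check_exclusive(source_mixture_config="cfg", paths=None, mix=None)  # passes
# check_exclusive(paths=["a.npy"], mix="dolma17", source_mixture_config=None)  # would raise
# ---------------------------------------------------------------------------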

# The following config classes are largely similar; a quick skim is enough.
@dataclass(kw_only=True)
class NumpyPaddedFSLDatasetConfig(NumpyDatasetConfig):
    sequence_length: int
    """
    The length of a single instance. Generally this should correspond to your model's maximum input length.
    """
    label_mask_paths: Optional[List[str]] = None
    """
    The paths/URLs to numpy bool files indicating which tokens should be masked.
    """

    def validate(self):
        if self.sequence_length <= 0:
            raise OLMoConfigurationError("'sequence_length' must be positive")

    def build(self) -> NumpyDatasetBase:
        self.validate()
        paths, metadata, label_masks = self._resolve_paths_metadata(
            allow_mix=True, label_mask_paths=self.label_mask_paths
        )
        dataset = NumpyPaddedFSLDataset(
            *paths,
            sequence_length=self.sequence_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            vocab_size=self.tokenizer.vocab_size,
            dtype=self.get_dtype(),
            bos_token_id=self.tokenizer.bos_token_id,
            metadata=metadata,
            include_instance_metadata=self.include_instance_metadata,
            instance_filter_config=self.instance_filter_config,
            label_mask_paths=label_masks,
        )
        return self._finalize(dataset)
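# ---------------------------------------------------------------------------
# Illustrative sketch, NOT part of numpy_dataset.py: the "padded" FSL variant
# keeps one document per instance, truncating long documents and right-padding
# short ones with pad_token_id up to sequence_length. All names below are
# hypothetical; shown only to contrast with the concatenating FSL dataset.
import numpy as np

def pad_or_truncate(doc: np.ndarray, sequence_length: int, pad_token_id: int) -> np.ndarray:
    doc = doc[:sequence_length]                        # truncate long documents
    out = np.full(sequence_length, pad_token_id, dtype=doc.dtype)
    out[: len(doc)] = doc                              # copy tokens; the rest stays padding
    return out

instance = pad_or_truncate(np.array([5, 6, 7], dtype=np.uint16), sequence_length=6, pad_token_id=0)
# instance -> [5, 6, 7, 0, 0, 0]
# ---------------------------------------------------------------------------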


@dataclass(kw_only=True)
class NumpyPackedFSLDatasetConfig(NumpyDatasetConfig):
    sequence_length: int
    """
    The length of a single instance. Generally this should correspond to your model's maximum input length.
    """
    generate_doc_lengths: bool = False
    """
    Include individual document lengths in the instances returned from
    :meth:`NumpyDatasetBase.__getitem__()`.
    """
    label_mask_paths: Optional[List[str]] = None
    """
    The paths/URLs to numpy bool files indicating which tokens should be masked.
    """
    long_doc_strategy: LongDocStrategy = LongDocStrategy.truncate
    """
    The strategy to use for handling long documents.
    """
    source_group_size: int = 1
    """
    The number of source npy files to process together when packing.
    """

    def validate(self):
        if self.sequence_length <= 0:
            raise OLMoConfigurationError("'sequence_length' must be positive")
        if self.source_group_size < 1:
            raise OLMoConfigurationError("'source_group_size' must be at least 1")

    def build(self) -> NumpyDatasetBase:
        self.validate()

        paths, metadata, label_masks = self._resolve_paths_metadata(
            allow_mix=True, label_mask_paths=self.label_mask_paths
        )

        dataset = NumpyPackedFSLDataset(
            *paths,
            sequence_length=self.sequence_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            vocab_size=self.tokenizer.vocab_size,
            dtype=self.get_dtype(),
            metadata=metadata,
            include_instance_metadata=self.include_instance_metadata,
            generate_doc_lengths=self.generate_doc_lengths,
            bos_token_id=self.tokenizer.bos_token_id,
            instance_filter_config=self.instance_filter_config,
            long_doc_strategy=self.long_doc_strategy,
            label_mask_paths=label_masks,
            source_group_size=self.source_group_size,
        )
        return self._finalize(dataset)
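# ---------------------------------------------------------------------------
# Illustrative sketch, NOT part of numpy_dataset.py: a greedy document-packing
# loop in the spirit of NumpyPackedFSLDataset. With LongDocStrategy.truncate,
# a document longer than sequence_length is cut down rather than dropped.
# (Real packing also pads instances and tracks label masks; omitted here.)
def pack_documents(docs, sequence_length):
    instances, current = [], []
    for doc in docs:
        doc = doc[:sequence_length]        # the "truncate" strategy for long documents
        if len(current) + len(doc) > sequence_length:
            instances.append(current)      # current instance is full; start a new one
            current = []
        current.extend(doc)
    if current:
        instances.append(current)
    return instances

packed = pack_documents([[1, 2], [3, 4, 5], [6, 7, 8, 9, 10, 11]], sequence_length=4)
# packed -> [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
# ---------------------------------------------------------------------------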


@dataclass(kw_only=True)
class NumpyInterleavedFSLDatasetConfig(NumpyDatasetConfig):
    sequence_length: int
    """
    The length of a single instance. Generally this should correspond to your model's maximum input length.
    """
    docs_per_instance: int
    """
    The number of documents to include in each instance.
    """
    chunks_per_doc: int
    """
    The number of chunks to include in each document.
    """
    seed: int
    """
    The seed to use for the random number generator.
    """
    label_mask_paths: Optional[List[str]] = None
    """
    The paths/URLs to numpy bool files indicating which tokens should be masked.
    """
    interleaving_exempt_paths: Optional[List[str]] = None
    """
    The paths/URLs to numpy bool files indicating which tokens should be exempt from interleaving.
    """

    def validate(self):
        if self.sequence_length <= 0:
            raise OLMoConfigurationError("'sequence_length' must be positive")
        if self.docs_per_instance <= 0:
            raise OLMoConfigurationError("'docs_per_instance' must be positive")
        if self.chunks_per_doc <= 0:
            raise OLMoConfigurationError("'chunks_per_doc' must be positive")

    def build(self) -> NumpyDatasetBase:
        self.validate()

        paths, metadata, label_masks = self._resolve_paths_metadata(
            allow_mix=True, label_mask_paths=self.label_mask_paths
        )

        interleaving_exempt_paths = cast(Optional[List[PathOrStr]], self.interleaving_exempt_paths)

        dataset = NumpyInterleavedFSLDataset(
            *paths,
            sequence_length=self.sequence_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            vocab_size=self.tokenizer.vocab_size,
            seed=self.seed,
            docs_per_instance=self.docs_per_instance,
            chunks_per_doc=self.chunks_per_doc,
            dtype=self.get_dtype(),
            metadata=metadata,
            include_instance_metadata=self.include_instance_metadata,
            instance_filter_config=self.instance_filter_config,
            label_mask_paths=label_masks,
            bos_token_id=self.tokenizer.bos_token_id,
            interleaving_exempt_paths=interleaving_exempt_paths,
        )
        return self._finalize(dataset)
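# ---------------------------------------------------------------------------
# Illustrative sketch, NOT part of numpy_dataset.py: the interleaving idea
# behind NumpyInterleavedFSLDataset. Each instance draws docs_per_instance
# documents, splits each into chunks_per_doc chunks, and alternates chunks
# round-robin across documents. The real dataset additionally shuffles with
# `seed` and skips documents from interleaving_exempt_paths; this toy version
# also drops any remainder tokens when a document does not split evenly.
def interleave(docs, chunks_per_doc):
    chunked = []
    for doc in docs:
        n = len(doc) // chunks_per_doc
        chunked.append([doc[i * n:(i + 1) * n] for i in range(chunks_per_doc)])
    instance = []
    for i in range(chunks_per_doc):        # round-robin over documents
        for chunks in chunked:
            instance += chunks[i]
    return instance

out = interleave([[1, 2, 3, 4], [5, 6, 7, 8]], chunks_per_doc=2)
# out -> [1, 2, 5, 6, 3, 4, 7, 8]
# ---------------------------------------------------------------------------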


@dataclass(kw_only=True)
class NumpyVSLDatasetConfig(NumpyDatasetConfig):
    max_sequence_length: int
    """
    The maximum sequence length. Generally this should correspond to your model's maximum input length.
    """
    min_sequence_length: int
    """
    The minimum sequence length.
    """
    vsl_curriculum: Optional[VSLCurriculumConfig] = None
    """
    The VSL curriculum config.
    """

    def validate(self):
        if self.max_sequence_length <= 0:
            raise OLMoConfigurationError("'max_sequence_length' must be positive")
        if self.min_sequence_length <= 0:
            raise OLMoConfigurationError("'min_sequence_length' must be positive")
        if self.min_sequence_length > self.max_sequence_length:
            raise OLMoConfigurationError(
                "'min_sequence_length' cannot exceed 'max_sequence_length'"
            )
        if self.tokenizer.bos_token_id is not None:
            raise OLMoConfigurationError("'bos_token_id' is not supported for the VSL dataset")
        if self.vsl_curriculum is not None:
            self.vsl_curriculum.validate()

    def build(self) -> NumpyDatasetBase:
        self.validate()

        paths, metadata, _ = self._resolve_paths_metadata(allow_mix=True)

        dataset = NumpyVSLDataset(
            *paths,
            max_sequence_length=self.max_sequence_length,
            min_sequence_length=self.min_sequence_length,
            curriculum=None if self.vsl_curriculum is None else self.vsl_curriculum.build(),
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            vocab_size=self.tokenizer.vocab_size,
            dtype=self.get_dtype(),
            metadata=metadata,
            include_instance_metadata=self.include_instance_metadata,
            instance_filter_config=self.instance_filter_config,
        )
        return self._finalize(dataset)
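# ---------------------------------------------------------------------------
# Illustrative sketch, NOT part of numpy_dataset.py: one plausible reading of
# how a VSL (variable sequence length) dataset groups instances. Documents are
# bucketed into power-of-two sequence lengths between min_sequence_length and
# max_sequence_length, and each batch then draws from a single bucket; the
# curriculum config controls the order in which buckets are visited.
def vsl_buckets(min_sequence_length, max_sequence_length):
    length, buckets = min_sequence_length, []
    while length <= max_sequence_length:
        buckets.append(length)
        length *= 2
    return buckets

buckets = vsl_buckets(256, 4096)
# buckets -> [256, 512, 1024, 2048, 4096]
# ---------------------------------------------------------------------------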

Due to the document length limit, we continue in the next part.
