Stanford CS336 Assignment 1 BPE Tokenizer trainning on tinystories dataset

3. Tinystories 数据集上训练

3.1. 阅读分析题目要求

这个题目有两个要求

  • a. vocab size 最大 10,000,确保将 speicial token "<|endoftext|>" 加入到 vocabulary。 资源要求 :训练时长 ≤ 30 minutes (no GPUs), 占用内存 ≤ 30GB RAM
    Tips:如果要在 2 分钟内完成训练,可以考虑多线程处理 pretokenize。
  • b. "tokenizer 训练过程中哪一部分最耗时?"

按照作业要求,我将分三步来完成:

  1. 编写训练脚本:包含加载数据、训练、保存模型、统计时间和内存。

  2. 运行并分析性能 (Profiling):回答瓶颈在哪里。

  3. 检查结果:找出最长的 Token。

3.2 逐步实现

  • 查看训练数据

    训练前,可以用 head 命令查看下训练数据,确认数据与之前测试的类型基本一致。

  • 代码关键实现分析

  • 获取项目根路径
    为了正确方便读写文件,工程中通常首先获取项目根路径。

    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(file), "..")))

  • 获取进程运行时的内存

    def get_memory_usage_mb():
    """Get current process memory usage in MB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024

  • 将训练获得的 vocabulary 和 merges 规则写入磁盘进行持久化。
    vocabulary 结构是 Dict,所以保存为 json格式,并且为了人类阅读,将bytes decode 为字符串保存在 vocab.json 中,无法解码的保留 repr 形式。

    def save_vocab_and_merges(vocab, merges, output_dir="results"):
    """Save vocabulary and merges to disk"""
    Path(output_dir).mkdir(exist_ok=True, parents=True)

    复制代码
      # Save vocab.json
      vocab_str = {}
      for idx, token_bytes in vocab.items():
          try:
              vocab_str[idx] = token_bytes.decode('utf-8')
          except UnicodeDecodeError:
              vocab_str[idx] = str(token_bytes)
    
      with open(f"{output_dir}/vocab.json", "w", encoding="utf-8") as f:
          json.dump(vocab_str, f, ensure_ascii=False, indent=2)
          
      # Save merges.txt
      with open(f"{output_dir}/merges.txt", "w", encoding="utf-8") as f:
          for p1, p2 in merges:
              f.write(f"{p1.decode('utf-8', errors='ignore')} {p2.decode('utf-8', errors='ignore')}\n")
    
      print(f"Artifacts saved to {output_dir}/")
  • 定义训练 & 打印统计信息

    def run_training(input_path, vocab_size, special_tokens, output_dir):

    复制代码
      # record the initial memory usage before training
      print(f"Initial Memory: {get_memory_usage_mb():.2f} MB")
    
      # Initialize the BPE trainer
      trainer = BPETrainer_Optimized()
      
      # start training and record the time and memory usage
      start_time = time.time()
      print(f"Starting training on {input_path}...")
      
      vocab, merges = trainer.train(
          input_path=input_path,
          vocab_size=vocab_size,
          special_tokens=special_tokens
      )
      
      end_time = time.time()
      duration = end_time - start_time
      peak_memory = get_memory_usage_mb()
      
      print("-" * 30)
      print(f"Training Complete.")
      print(f"Time Taken: {duration:.2f} seconds ({duration/60:.2f} minutes)")
      print(f"Final Memory: {peak_memory:.2f} MB")
      print("-" * 30)
    
      save_vocab_and_merges(vocab, merges, output_dir)
      
      # Output Statistics information
      print("\n=== Statistics (Problem b) ===")
      
      # 1. Longest token
      longest_token_bytes = max(vocab.values(), key=len)
      try:
          longest_token_str = longest_token_bytes.decode('utf-8')
      except:
          longest_token_str = str(longest_token_bytes)
      print(f"Longest Token: {longest_token_str!r}")
      print(f"Length in bytes: {len(longest_token_bytes)}")
    
      # 2. Most frequent token (approximate, based on merge priority if we tracked it, 
      # but here we can just say the last merged token was the most frequent *at that step*)
      # The assignment asks for "most frequent token in the dataset"? 
      # Usually BPE doesn't keep full frequency counts of final vocab unless we re-tokenize.
      # We will just print the last merge which represents the most frequent pair remaining.
      print(f"Total Merges: {len(merges)}")

3.3 运行训练脚本

  1. 执行如下命令

    uv run python scripts/train_bpe_tinystories.py
    --input_path data/TinyStoriesV2-GPT4-train.txt
    --vocab_size 10000
    --profile

控制台可以看到如下输出

复制代码
Enabling cProfile...
Initial Memory: 23.41 MB
Starting training on data/TinyStoriesV2-GPT4-train.txt...
Starting BPE training (Optimized)...
Merge 0/9743: (b' ', b't')
Merge 100/9743: (b'ri', b'end')
...
Merge 9600/9743: (b' pain', b'ful')
Merge 9700/9743: (b'St', b'ill')
------------------------------
Training Complete.
Time Taken: 731.09 seconds (12.18 minutes)
Final Memory: 96.74 MB
------------------------------
Artifacts saved to tinystories_tokenizer/

=== Statistics (Problem b) ===
Longest Token: ' accomplishment'
Length in bytes: 15
Total Merges: 9743

Profiling data saved to training.prof
Use 'snakeviz training.prof' to visualize.

3.4 检查确认

  1. 0,255\] 为 256 个字符。

    输出 258: "256": "<|endoftext|>", 确认 "<|endoftext|>" 加入到 vocabulary, idx 为 256,符合预期。
    文件不大也可以打开查看。
  2. 257:\]: 训练完成得到的 vocabulary。

由于 toml 文件中没有 snakeviz,我们需要单独安装 snakeviz,

  1. 安装 snakeviz (如果还没安装)

    uv pip install snakeviz

  2. 启动可视化服务器

    uv run snakeviz training.prof

然后浏览器打开

如果本地不方便,可以使用 python 内置的 pstats 查看文本报告

复制代码
# 查看累计耗时(cumulative)排名前 10 的函数
uv run python -c "import pstats; p = pstats.Stats('training.prof'); p.sort_stats('cumulative').print_stats(10)"  

耗时最长的前10

复制代码
         1368017311 function calls (1368017210 primitive calls) in 731.144 seconds

   Ordered by: cumulative time
   List reduced from 251 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.029    0.029  731.144  731.144 /home/fdq/cources/cs336/assignment1-basics/scripts/train_bpe_tinystories.py:51(run_training)
        1    0.372    0.372  731.066  731.066 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:18(train)
        1  369.361  369.361  460.691  460.691 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:205(_pretokenize)
9746/9745   83.735    0.009  265.975    0.027 {built-in method builtins.max}
369218707  118.986    0.000  182.240    0.000 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:52(<lambda>)
372023448   63.614    0.000   63.614    0.000 {method 'get' of 'dict' objects}
536592168   44.123    0.000   44.123    0.000 {method 'group' of '_regex.Match' objects}
  2761194    2.553    0.000   39.444    0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:340(finditer)
  2804690    8.909    0.000   36.517    0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:449(_compile)
  2761196    3.285    0.000   12.991    0.000 /home/fdq/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/locale.py:679(getpreferredencoding)

根据 Profiling (性能分析) 结果揭示了程序的运行状况!这对于回答作业中的问题 (b) 以及后续优化至关重要。

1. 深度分析 Profiling 结果

根据 pstats 文本和 snakeviz 可视化页面,我们可以清晰地得出结论:

瓶颈一:Pre-tokenization (最耗时)

  • 原因 1:这是单线程运行的。Python 的正则引擎虽然快,但要处理的是几 GB 的文本数据。一个 CPU 核心逐字逐句地扫描、匹配、计数,必然是慢的。
  • 证据_pretokenize 函数占据了 460.69 秒 (约占总时间的 63%)。
  • 作业提示 :这就是为什么作业 Hint 提到 "using multiprocessing during pretokenization"
  • 原因 2:正则匹配耗时
  • 证据:finditer 和 _compile 以及 group 方法被调用了数百万次,占用了大量时间。我是在循环外编译的,

瓶颈二:寻找最佳 Pair (max 操作)

  • 证据builtins.max 耗时 265.97 秒

  • 细节分析

  • 调用了 max 约 1 万次(每次合并一次)。

  • 但关键在于那个 lambda 函数被调用了 3.69 亿次

  • 369218707 ... <lambda>372023448 ... {method 'get' of 'dict' objects}

  • 罪魁祸首

找到这行代码

复制代码
max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, (b'', b''))))

当pair 频次相同时,为了打破平局(Tie-breaking),再按 pair 的字典序排序,对于字典里的每一个 Pair,都调用了一次 pair_strings.get(x)。这在 3.69 亿次调用下,带来了巨大的累积开销 (约 63秒花在 get 上,118秒花在 lambda 本身逻辑上)。


2. 回答作业问题 (b)

(b) Profile your code. What part of the tokenizer training process takes the most time?

中文: "性能分析显示,预分词 (pre-tokenization) 步骤 是最耗时的部分,约占总运行时间的 63% (460秒),这是由于单线程正则处理大量语料造成的。第二大的瓶颈是合并循环中的 max 操作 (265秒),主要是因为在处理平局的 lambda 函数中频繁进行字典查找 (pair_strings.get) 带来的开销。

" 英文回答: "The profiling results indicate that the pre-tokenization step is the most time-consuming part, accounting for approximately 63% of the total runtime (460s) due to the single-threaded regex processing of the large corpus. The second largest bottleneck is the max operation in the merge loop (265s), specifically caused by the overhead of dictionary lookups (pair_strings.get) inside the tie-breaking lambda function."


3. (可选但推荐) 冲击 "2分钟" 目标:代码优化

虽然现在的 12 分钟已经远低于作业要求的 30 分钟,但如果想体验一下极致优化的快感(以及满足 Hint 里的 2 分钟目标),需要做两件事:

第一步:优化 max 的 Tie-breaking (极速修正)

这个只需要修改一行代码,我就先做了,结果翻车了。。。。

max 里做这种复杂的 get 查找太慢了。我们可以利用 Python 元组比较的特性,直接比较 pair 本身(它是 int tuple,比如 (12, 34)),这比查字典找 bytes 快得多。

修改 train_bpe_optimize.py 中的这一行:

复制代码
# 原来的代码 (慢)
# merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, (b'', b''))))

# 修改后的代码 (快)
# 解释:如果频率一样,Python 会自动比较 key 本身 (id_1, id_2)。
# 这种比较是 C 语言层面的,极快,且结果是确定的。
merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], x))

profiling 结果:

复制代码
   List reduced from 251 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.029    0.029  713.010  713.010 /home/fdq/cources/cs336/assignment1-basics/scripts/train_bpe_tinystories.py:51(run_training)
        1    0.355    0.355  712.921  712.921 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:18(train)
        1  364.844  364.844  452.898  452.898 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:205(_pretokenize)
9746/9745   83.417    0.009  255.732    0.026 {built-in method builtins.max}
369218707  114.033    0.000  172.315    0.000 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:52(<lambda>)
372023448   58.602    0.000   58.602    0.000 {method 'get' of 'dict' objects}
536592168   44.608    0.000   44.608    0.000 {method 'group' of '_regex.Match' objects}
  2761194    2.350    0.000   35.984    0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:340(finditer)
  2804690    8.114    0.000   33.361    0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:449(_compile)
  2761196    3.020    0.000   11.858    0.000 /home/fdq/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/locale.py:679(getpreferredencoding)

但是,我修改了之后并没有改善,而且从 profiling 看,还在调用 dict.get,改动没有生效????我打算暂时放过这里,把经历放在 pretokenize 的多进程上,这是assignment 的建议。

接下来我把重点放在使用 multiprocessing 实现 pretokenize上。

相关推荐
nuowenyadelunwen1 天前
CS336 Assignment 1:BEP Tokenizer 2 - 分块读取文档 & 预分词
bpe tokenizer·stanford 336·分块读取文件·assignment 1
nuowenyadelunwen1 天前
Stanford CS336 Assignment 1 BPE Tokenizer 优化
bpe tokenizer·stanford 335·llm scratch
nuowenyadelunwen9 天前
Stanford CS336 Assignment 1: BPE Tokenizer
llm·bpe tokenizer·stanford cs336
nuowenyadelunwen9 天前
Stanford CS336 Language Models from Scratch-Assignment 1 OVerview
大语言模型llm·stanford cs336·cs336assignment
爱听歌的周童鞋23 天前
斯坦福大学 | CS336 | 从零开始构建语言模型 | Spring 2025 | 笔记 | Assignment 1: BPE Tokenizer
llm·assignment·cs336·bpe tokenizer
ahead~7 个月前
【大模型入门】访问GPT_API实战案例
人工智能·python·gpt·大语言模型llm
Fuliy962 年前
大语言模型 API
人工智能·语言模型·自然语言处理·api·学习资料·大语言模型llm
姚家湾2 年前
利用大语言模型(KIMI)构建控制信息模型
大数据·json·大语言模型llm·物联网·