Stanford CS336 Assignment 1 BPE Tokenizer 优化

1. 瓶颈在哪里

回顾下 BPETrainer_Simple 中,循环体内部是这样的:

复制代码
for merge_idx in range(num_merges):
    # 1. 极其耗时:每次都要遍历所有单词,重新计算所有 Pair 的频率
    pair_counts = self._count_pairs(word_counts, word_encodings, ...)
    
    # 2. 找到频率最高的 pair
    merge_pair = max(pair_counts, ...)
    
    # 3. 极其耗时:再次遍历所有单词,进行字符串替换
    for word in word_encodings:
        self._merge_encoding(...)

可以发现

  1. 在计算 pair 的频率时每次都要遍历所有 word,并找到频率最大的 pair;
  2. pair merge 操作后更新包含该 pair 的word 的 encoding 时又是遍历所有 word。

假设有 10,000 个单词,要进行 1,000 次合并。你的代码执行了 1,000×10,000≈1071,000×10,000≈107 次扫描操作。

2. 如何优化

  1. 全局 Pair 计数缓存 (pair_counts):不再每次重新数,而是维护一个全局的计数器。
  2. 倒排索引 (pair_to_words):记录"哪个 Pair 出现在了哪些单词里"。当合并一个 pair 时,只去更新那些包含这个 pair 的单词,而不是扫描所有单词。

引入**倒排索引(Inverted Index)**的思想是:

  • pair_to_words: Dict[Tuple, Set[str]]
    • key: 是一个 pair, 如 ('e', 's')
    • value 是一个 set, 包含所有包含这个 pair 的word,如 {'newest', 'widest', ...}

2.1 优化后的流程

  1. 初始化(只做一次) ,扫描所有 word,建立全局的 pair_counts 和 pair_to_words
  2. 循环合并
  3. pair_counts 中直接拿到最大频率的 pair,时间复杂度(O(1)O(K))。
  4. 从 pair_to_words 找到只受影响的 words(Affected Words)。
  5. 只对这些受影响的 words 进行更新。
    • 减去这些 words 中旧 pair 的计数。
    • 在 words 中执行合并。
    • 加上这些 words 新产生的 pair 的计数。
      这种方式下,如果合并 ('e', 's') 只影响了字典里 1% 的单词,你就只需要处理这 1% 的数据,速度提升是数量级的。

2.2 举例说明

用文档中方的语料

复制代码
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest

在经过 preTokenize 后得到的词频统计如下:
{low: 5, lower: 2, newest: 6, widest: 3}

为了方便演示,我这里假设此时字符已经编码为 ID,但这里为了直观,我直接用字符表示。

第一步:初始化 (Pre-tokenize & Count)

优化代码中的 _pretokenize 和初始 _count_pairs 执行后,内存里的状态如下:
1. word_encodings (当前单词的状态):

  • low: ['l', 'o', 'w']
  • lower: ['l', 'o', 'w', 'e', 'r']
  • newest: ['n', 'e', 'w', 'e', 's', 't']
  • widest: ['w', 'i', 'd', 'e', 's', 't']

2. pair_counts (全局计数器):

  • ('l', 'o'): 5 (from low) + 2 (from lower) = 7
  • ('o', 'w'): 5 (from low) + 2 (from lower) = 7
  • ('e', 's'): 6 (from newest) + 3 (from widest) = 9
  • ('s', 't'): 6 (from newest) + 3 (from widest) = 9
  • ('w', 'e'): 2 (from lower) = 2 ... 等等

3. pair_to_words (倒排索引 - 关键优化):

  • ('e', 's') -> {'newest', 'widest'} <-- 注意这里,只有这两个词包含 e,s
  • ('l', 'o') -> {'low', 'lower'}
  • ...

第二步:开始训练循环

Iteration 1:

  1. Find Max:pair_counts 中找到频率最高的。假设是 ('e', 's'),计数为 9。
  2. Merge: 我们决定将 es 合并为新 Token es (假设 ID 为 256)。
  3. Find Affected Words:
  • pair_to_words[('e', 's')]
  • 得到受影响的单词列表:{'newest', 'widest'}
  • 重点: 我们完全忽略 lowlower,因为它们不包含 ('e', 's')
  1. Update Affected Words (函数 _updated_affected_word_count): 我们以 newest (频率 6) 为例,演示如何增量更新:
  • 旧状态 : ['n', 'e', 'w', 'e', 's', 't']
  • 待合并 Pair : ('e', 's')

(A) 扣除旧 Pair 的计数 (Decrement): 在进行合并前,我们要把 newest 里将会被破坏的 Pair 的计数减掉。

  • ('w', 'e'): 它的右边 e 要被合并了,这个 Pair 即将消失。pair_counts[('w', 'e')] -= 6
  • ('e', 's'): 这是我们要合并的,计数减掉。pair_counts[('e', 's')] -= 6
  • ('s', 't'): 它的左边 s 要被合并了,这个 Pair 即将消失。pair_counts[('s', 't')] -= 6
  • 同时,从 pair_to_words 对应的 Set 中移除 newest

(B) 执行合并 (Merge):

  • newest 变为: ['n', 'e', 'w', 'es', 't'] (注意:这里假设第一个 e 不受影响,只合并后面的 es)

(C) 增加新 Pair 的计数 (Increment): 新产生了哪些 Pair?

  • ('w', 'es'): 新产生的。pair_counts[('w', 'es')] += 6。并将 newest 加入 pair_to_words[('w', 'es')]
  • ('es', 't'): 新产生的。pair_counts[('es', 't')] += 6。并将 newest 加入 pair_to_words[('es', 't')]

widest (频率 3) 重复上述 A, B, C 步骤。

Iteration 1 结束:

  • pair_counts 已经是最新的了(不需要重新扫描全库)。
  • word_encodings 里的 newestwidest 已经更新。
  • lowlower 保持原样,甚至没有被读取过。

总结

之前的 BPETrainer_Sampler 和当前优化版的区别就像是:

  • 简单版(老师点名) :每次有人迟到,老师就重新从花名册第一个人点到最后一个人,数一遍谁没来。
  • 当前优化版(请假条制度) :老师手里有个总人数。每次有人迟到(合并),只需要处理那几个迟到的人 ,总人数减去请假的人,加上新来的人(如果有),不需要管那些正常坐在座位上的同学。
    我写完,让 Gemini 帮我review我的笔记和代码,并告诉它背景,我打算以博客形式分享到网上的,有这么一句,我觉得它说的挺好的
    我们将像修房子一样,先打地基(数据结构),再修墙(辅助函数),最后封顶(主循环)。
    而且它给的代码加的注释也挺好的。

2.3 代码实现

第一步:引入倒排索引(改造"地基"-来自Gemini)

需要在代码中引入一个新的字典: pair_to_words

旧逻辑:只有 pair_counts(记分牌)

新逻辑:增加 pair_to_words(索引表),用来记录某个 pair 出现在哪些 words 中。

修改 _count_pairs 函数,不仅返回计数 pair_counts,还需要返回谁拥有这些 pair 的记录表 pair_to_words

实现如下:

复制代码
# before:just return  pair_counts
# now:add pair_to_words

def _count_pairs(
    self,
    word_counts: Dict[str, int],
    word_encodings: Dict[str, List[int]],
    pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]], 
    vocabulary: Dict[int, bytes],
    pair_to_words: Dict[Tuple[int, int], set] # <--- new patameter added
) -> Dict[Tuple[int, int], int]:
    
    pair_counts = defaultdict(int)
    
    for word, count in word_counts.items():
        encoding = word_encodings[word]
        # loop all Pairs
        for i in range(len(encoding) - 1):
            pair = (encoding[i], encoding[i+1])
            
            # 1. Count (same to last version)
            pair_counts[pair] += count
            
            # 2. [added] record inverted index:this pair included in this word
            pair_to_words[pair].add(word) 
            
            # 3. record  the bytes of the pair(same to last version)
            if pair not in pair_strings:
                pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
                
    return pair_counts

第二步:编写核心"手术刀"函数 (增量更新)

这是最难的一步。以前是"全量扫描",现在我们要写一个函数,专门处理单个单词的更新。

想象这个函数是一个会计,它在修改单词结构时,必须保证账目(pair_counts)平配。 逻辑是:先退款(减旧计数),再合并,最后重新下单(加新计数)。

在类中添加下面新函数:

复制代码
def _update_word_for_merge(
    self,
    word: str,
    merge_pair: Tuple[int, int],
    new_token_id: int,
    word_counts: Dict[str, int],
    word_encodings: Dict[str, List[int]],
    pair_counts: Dict[Tuple[int, int], int],
    pair_to_words: Dict[Tuple[int, int], set],
    pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
    vocabulary: Dict[int, bytes]
):
    encoding = word_encodings[word]
    count = word_counts[word]
    
    # --- 阶段 1: "退款" (Decrement) ---
    # 在修改编码前,先把它内部所有旧 Pair 的计数减掉,并从索引中移除
    # minus count for these word include the pair
    for i in range(len(encoding) - 1):
        old_pair = (encoding[i], encoding[i+1])
        pair_counts[old_pair] -= count
        pair_to_words[old_pair].discard(word) # 从该 pair 的名单里把这个词删掉
        # 如果计数归零,可以清理(可选,为了省内存)
        if pair_counts[old_pair] <= 0:
            del pair_counts[old_pair]

    # --- 阶段 2: "合并" (Merge) ---
    # merge 
    new_encoding = []
    i = 0
    while i < len(encoding):
        if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
            new_encoding.append(new_token_id)
            i += 2
        else:
            new_encoding.append(encoding[i])
            i += 1
    
    # 更新单词的编码
    word_encodings[word] = new_encoding

    # --- 阶段 3: "重新下单" (Increment) ---
    # 遍历新的编码,把新产生的 Pair 计数加上,并加入索引
    # add the count of the new pair 
    for i in range(len(new_encoding) - 1):
        new_pair = (new_encoding[i], new_encoding[i+1])
        pair_counts[new_pair] += count
        pair_to_words[new_pair].add(word) # 把这个词加入新 pair 的名单
        
        # 顺便记录一下 pair 的 string (画图用或调试用)
        if new_pair not in pair_strings:
            pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])

第三步:重组"主循环" (Train Function)

现在我们要把 train 函数里的旧循环拆掉,换成新的逻辑。

主要改动点:

  1. 初始化时创建 pair_to_words

  2. 循环内不再 调用 _count_pairs

  3. 循环内通过 pair_to_words[merge_pair] 找到受影响的词。

    复制代码
     def train(self, input_path, vocab_size, special_tokens, **kwargs):
         # ... (前面的代码不变:初始化 vocabulary, pre-tokenize, word_encodings) ...
         # 1. Initialize the vocabulary
         vocabulary = {i: bytes([i]) for i in range(N_BYTES)}
         size = N_BYTES
         for token in special_tokens:
             vocabulary[size] = token.encode('utf-8')
             size += 1
         
         # 2. Pre-tokenize
         word_counts = self._pretokenize(input_path, special_tokens)
    
         # 3. Initialize the word encodings
         word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}
    
         print(f"Starting BPE training...")
         merges = []
         
         # --- 【改动开始】 ---
         
         # 准备数据结构
         pair_strings = {}
         pair_to_words = defaultdict(set) # <--- 新增
         
         # 4. 初始全量统计 (只做这一次!)
         pair_counts = self._count_pairs(
             word_counts, word_encodings, pair_strings, vocabulary, pair_to_words
         )
    
         num_merges = vocab_size - size
         
         for merge_idx in range(num_merges):
             if not pair_counts:
                 break
                 
             # a. 找到频率最高的 pair
             # 这里的逻辑不变,只是数据源 pair_counts 是实时维护的
             # 注意:这里加了 .get 防止 pair_strings 某些极端情况缺key报错
             merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, (b'', b''))))
             
             # b. 更新 vocabulary 和 merges
             new_token_id = size
             vocabulary[new_token_id] = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
             merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
             size += 1
             
             # c. 【核心优化】只更新受影响的单词
             # 注意:我们要把 set 转成 list,因为我们在遍历过程中会修改 set (discard操作)
             affected_words = list(pair_to_words[merge_pair])
             
             for word in affected_words:
                 self._update_word_for_merge(
                     word, merge_pair, new_token_id,
                     word_counts, word_encodings, pair_counts, 
                     pair_to_words, pair_strings, vocabulary
                 )
                 
             # 打印进度 (可选)
             if merge_idx % 100 == 0:
                 print(f"Merge {merge_idx}/{num_merges}: {merges[-1]}")
    
         # --- 【改动结束】 ---
         
         return vocabulary, merges

测试验证

修改 addpters.py 文件,执行测试命令 uv run pytest tests/test_train_bpe.py 测试结果

复制代码
================================================================================ test session starts =================================================================================
platform linux -- Python 3.11.12, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/fdq/cources/cs336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 3 items                                                                                                                                                                    

tests/test_train_bpe.py::test_train_bpe_speed Starting BPE training (Optimized)...
Merge 0/243: (b' ', b't')
Merge 100/243: (b'i', b'd')
Merge 200/243: (b' ', b'T')
PASSED
tests/test_train_bpe.py::test_train_bpe Starting BPE training (Optimized)...
Merge 0/243: (b' ', b't')
Merge 100/243: (b'i', b'd')
Merge 200/243: (b' ', b'T')
PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens Starting BPE training (Optimized)...
Merge 0/743: (b'h', b'e')
Merge 100/743: (b'a', b'll')
Merge 200/743: (b' c', b'at')
Merge 300/743: (b' ', b'other')
Merge 400/743: (b' say', b's')
Merge 500/743: (b' e', b'ach')
Merge 600/743: (b' h', b'ard')
Merge 700/743: (b'b', b'y')
PASSED

================================================================================= 3 passed in 1.99s =================================================================================

可以看到 三个都通过了。

现在可以在tinystories 数据集上训练了。

附完整代码

python 复制代码
"""
BPE Trainer Optimized Version:
    save time throug record the affected words when merge a pair
"""


import regex as re
from collections import defaultdict, Counter
from typing import Dict, List, Tuple

N_BYTES = 256
CHUNK_SIZE = 1024 * 50

class BPETrainer_Optimized:
    def __init__(self):
        self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    def train(
        self,
        input_path: str,
        vocab_size: int,
        special_tokens: List[str],
        **kwargs
    ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:

        # like simple version, initialize the vocabulary, pre-tokenize
        # and initialize the word encodings
        vocabulary = {i: bytes([i]) for i in range(N_BYTES)}
        size = N_BYTES
        for token in special_tokens:
            vocabulary[size] = token.encode('utf-8')
            size += 1
        
        word_counts = self._pretokenize(input_path, special_tokens)
        word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}

        print(f"Starting BPE training (Optimized)...")
        merges = []
        num_merges = vocab_size - size
        # Initialize the pair_strings and pair_to_words
        # pair_to_words: Dict[Tuple[int, int], Set[str]],
        #  new data structure store the affected words encodings need to be updates
        pair_strings = {}
        pair_to_words = defaultdict(set)
        # Initialize the all pair_counts
        pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary, pair_to_words)
        # Training loop
        for merge_idx in range(num_merges):
            if not pair_counts:
                break
            # a. find the max count pair to be merged
            merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, ('b', 'b'))))
            # b. update the vocabulary and merges
            new_token_id = size
            vocabulary[new_token_id] = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
            merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
            size += 1
            # c. get the affected words
            affected_words = list(pair_to_words[merge_pair])
            # update encodings just for the affected words
            for word in affected_words:
                self._update_word_for_merge(
                    word, merge_pair, new_token_id,
                    word_counts, word_encodings, pair_counts,
                    pair_to_words, pair_strings, vocabulary
                )

            if merge_idx % 100 == 0:
                print(f"Merge {merge_idx}/{num_merges}: {merges[-1]}")

        return vocabulary, merges

    def _count_pairs(
        self,
        word_counts: Dict[str, int],
        word_encodings: Dict[str, List[int]],
        pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]], 
        vocabulary: Dict[int, bytes],
        pair_to_words: Dict[Tuple[int, int], set] # <--- new patameter added
    ) -> Dict[Tuple[int, int], int]:
        
        pair_counts = defaultdict(int)
        
        for word, count in word_counts.items():
            encoding = word_encodings[word]
            # loop all Pairs
            for i in range(len(encoding) - 1):
                pair = (encoding[i], encoding[i+1])
                
                # 1. Count (same to last version)
                pair_counts[pair] += count
                
                # 2. [added] record inverted index:this pair included in this word
                pair_to_words[pair].add(word) 
                
                # 3. record  the bytes of the pair(same to last version)
                if pair not in pair_strings:
                    pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
                    
        return pair_counts

    def _update_word_for_merge(
        self,
        word: str,
        merge_pair: Tuple[int, int],
        new_token_id: int,
        word_counts: Dict[str, int],
        word_encodings: Dict[str, List[int]],
        pair_counts: Dict[Tuple[int, int], int],
        pair_to_words: Dict[Tuple[int, int], set],
        pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
        vocabulary: Dict[int, bytes]
    ):
        """
        Args:
            word: the word to be update its encoding
            pair_to_words: record the words that include the pair, using for query the words encodings need to be updated
        """
        encoding = word_encodings[word]
        count = word_counts[word]
        # 1. Decrement old pairs
        for i in range(len(encoding) - 1):
            old_pair = (encoding[i], encoding[i+1])
            pair_count -= count
            pair_to_words[old_pair].discard(word)
            if pair_counts[old_pair] <= 0:
                del pair_counts[old_pair]

        # 2. Merge encoding
        new_encoding = []
        i = 0
        while i < len(encoding):
            if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
                new_encoding.append(new_token_id)
                i += 2
            else:
                new_encoding.append(encoding[i])
                i += 1
        # 3. Increment new pairs
        for i in range(len(new_encoding) - 1):
            new_pair = (new_encoding[i], new_encoding[i+1])
            pair_counts[new_pair] += count
            pair_to_words[new_pair].add(word)
            if new_pair not in pair_strings:
                pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])

            word_encodings[word] = new_encoding


    def _update_word_for_merge(
        self,
        word: str,
        merge_pair: Tuple[int, int],
        new_token_id: int,
        word_counts: Dict[str, int],
        word_encodings: Dict[str, List[int]],
        pair_counts: Dict[Tuple[int, int], int],
        pair_to_words: Dict[Tuple[int, int], set],
        pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
        vocabulary: Dict[int, bytes]
    ):

        """
        Args:
            word: the word to be update its encoding
            pair_to_words: record the words that include the pair, using for query the words encodings need to be updated
        """
        encoding = word_encodings[word]
        count = word_counts[word]
        
        # minus count for these word include the pair
        for i in range(len(encoding) - 1):
            old_pair = (encoding[i], encoding[i+1])
            pair_counts[old_pair] -= count
            pair_to_words[old_pair].discard(word) # delete the word from  set
            # clear if count is 0(Optional,for memory saving)
            if pair_counts[old_pair] <= 0:
                del pair_counts[old_pair]

        # merge 
        new_encoding = []
        i = 0
        while i < len(encoding):
            if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
                new_encoding.append(new_token_id)
                i += 2
            else:
                new_encoding.append(encoding[i])
                i += 1
        
        # update the word encoding
        word_encodings[word] = new_encoding

        # loop the new encoding, add the count of the new pair and add the word 

        for i in range(len(new_encoding) - 1):
            new_pair = (new_encoding[i], new_encoding[i+1])
            pair_counts[new_pair] += count
            pair_to_words[new_pair].add(word) # add the word to  pair list
            
            # record the pair string (for debug or other use)
            if new_pair not in pair_strings:
                pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])

    def _pretokenize(
        self,
        input_path: str,
        special_tokens: List[str]
    ) -> Dict[str, int]:
        word_counts = defaultdict(int)
        pattern_compiled = re.compile(self.pattern)
        special_pattern = "|".join(re.escape(t) for t in special_tokens)
        

        for chunk in self._chunk_documents_streaming(input_path):
            if special_tokens:
                blocks = re.split(special_pattern, chunk)
            else:
                blocks = [chunk]
            for block in blocks:
                for match in re.finditer(pattern_compiled, block):
                    word_counts[match.group(0)] += 1

        return word_counts

    def _chunk_documents_streaming(
        self,
        input_path: str,
        chunk_size: int = CHUNK_SIZE,
        special_token: str = "<|endoftext|>"
    ):
        """
        Args:
            input_path: str, the path of the input file
            special_tokens: str
        Returns:
            chunks: Generator[str, None, None]
        """
        leftover = ""
        # special_token = special_tokens[0] if special_tokens else "<|endoftext|>"
        with open(input_path, 'r', encoding='utf-8') as f:
            while True:
                # Read a small block each time,such as 50KB here
                block = f.read(chunk_size)
                # reading complete,quit the loop
                if not block:
                    break
                # add the leftover last block to the current block
                block = leftover + block
                leftover = ""
                # find the split point through the special token
                # if not found,the whole block is the last block
                # rfind(): return the index of the last occurrence of the substring
                last_idx = block.rfind(special_token)
                if last_idx == -1:
                    leftover = block
                # if found, yield the block by the special tokens
                # before the special token, genetor
                # after the special token, to the next block by the leftover
                else:
                    yield block[:last_idx + len(special_token)]
                    leftover = block[last_idx + len(special_token):]
        # deal with the last content if has
        if leftover:
            yield leftover
相关推荐
nuowenyadelunwen8 天前
Stanford CS336 Assignment 1: BPE Tokenizer
llm·bpe tokenizer·stanford cs336
爱听歌的周童鞋22 天前
斯坦福大学 | CS336 | 从零开始构建语言模型 | Spring 2025 | 笔记 | Assignment 1: BPE Tokenizer
llm·assignment·cs336·bpe tokenizer