Stanford CS336 Assignment 1: BPE Tokenizer

Why do we need a tokenizer?

Problem: computers only understand numbers, but we want to process text.

"Hello world" → must be converted into → [123, 456] (a sequence of numbers)

Solution: the tokenizer is exactly that converter.

1.1 Unicode Basics (Section 2.1)

What is Unicode?

  • It is simply a "numbering table for every character in the world"
  • Every character has a number (called a code point)

Example

>>> ord('s')        # the code point of 's' is 115
115
>>> ord('牛')       # the code point of '牛' is 29275
29275
>>> chr(115)        # and the other way around: 115 maps to 's'
's'

Problem:

  • (a): What Unicode character does chr(0) return?
    A: chr(0) returns the null character (often referred to as NULL), which corresponds to the Unicode code point U+0000.
    Key point: chr(0) corresponds to the Unicode code point U+0000, known as the null character.
  • (b): How does the character's string representation (repr) differ from its printed representation?
    A: Its string representation (repr) displays the escape sequence '\x00' for debugging purposes, whereas its printed representation renders the actual invisible control character, resulting in no visible output.
    Key point: repr shows the escape sequence \x00 for debugging, while print renders the actual (invisible) character.
  • (c): What happens when this character occurs in text?
    A: It behaves as a valid, invisible character within the string and does not terminate the string (unlike in C), allowing subsequent text to be processed and displayed normally.
    Key point (important): in C, \0 is the string terminator (null terminator); if you write "test\0string" in C, only "test" is printed and the rest is dropped. In Python, strings store their length explicitly (length-prefixed), so chr(0) is just an ordinary character: print("..." + chr(0) + "...") does not truncate the string, and everything after it is still printed. A small demo follows below.
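
A minimal demo of the three answers above, runnable in any Python REPL:

c = chr(0)                      # the null character, U+0000
print(repr(c))                  # shows the escape sequence: '\x00'
print("before" + c + "after")   # prints "beforeafter" with an invisible NUL in between; nothing is truncated
print(len("a" + c + "b"))       # 3 -- the null character counts as one ordinary character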

1.2 UTF-8 Encoding (Section 2.2)

Why not use the Unicode code points directly?

Problem: Unicode has 150,000+ characters, so the range of code points is very large.

Solution: UTF-8 encoding

  • Convert each character into a "byte sequence" (bytes)
  • Each byte is a number in 0-255
  • Variable-length encoding: an English letter takes 1 byte, a Chinese character 3-4 bytes

Example

>>> "hello".encode("utf-8")
b'hello'                        # 5个字节:[104, 101, 108, 108, 111]

>>> "你好".encode("utf-8")  
b'\xe4\xbd\xa0\xe5\xa5\xbd'    # 6个字节(每个中文字 3 字节)

Key idea

the string "hello"
   ↓ UTF-8 encoding
byte sequence [104, 101, 108, 108, 111]  ← this is what we actually operate on

Problems

  • (a): What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings. UTF-8 is preferred because it is variable-length and space-efficient, representing common ASCII characters with a single byte, whereas UTF-16 uses at least 2 bytes per character and UTF-32 a fixed 4 bytes, which pads the same text with many zero bytes and inflates the sequence length. A quick comparison follows below.
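
A quick comparison of the three encodings on the same string (note that Python's "utf-16"/"utf-32" codecs also prepend a byte-order mark):

s = "hello"
print(len(s.encode("utf-8")))    # 5  -- one byte per ASCII character
print(len(s.encode("utf-16")))   # 12 -- 2-byte BOM + 2 bytes per character
print(len(s.encode("utf-32")))   # 24 -- 4-byte BOM + 4 bytes per character
print(s.encode("utf-16"))        # b'\xff\xfeh\x00e\x00l\x00l\x00o\x00' -- full of zero bytes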

  • (b): Why is this function incorrect?

  • Example function:

    def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
        return "".join([bytes([b]).decode("utf-8") for b in bytestring])

Example

decode_utf8_bytes_to_str_wrong("你".encode("utf-8"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 2, in decode_utf8_bytes_to_str_wrong
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

>>> test_string = "你"
>>> print(test_string.encode("utf-8"))
b'\xe4\xbd\xa0'

>>> b'\xe4\xbd\xa0'.decode("utf-8")
'你'
>>> b'\xe4'.decode("utf-8")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

We can infer from the above:
Example input: b'\xe4\xbd\xa0' (the UTF-8 encoding of the character "你"). Explanation: the function is incorrect because it attempts to decode each byte individually; in UTF-8, non-ASCII characters (like "你") are encoded as multi-byte sequences whose bytes cannot be decoded in isolation. A corrected version is shown below.
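
By contrast, a minimal corrected version simply decodes the whole byte string in one call, so multi-byte characters are interpreted as a unit:

def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    # decode the complete byte sequence at once instead of byte by byte
    return bytestring.decode("utf-8")

assert decode_utf8_bytes_to_str("你".encode("utf-8")) == "你"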

(c): Give a two-byte sequence that does not decode to any Unicode character(s).
Example: b'\xff\xff' (or b'\xc3\x20'). Explanation: the first sequence is invalid because 0xff is never a valid byte in UTF-8; in the second, the leading byte \xc3 announces a two-byte character, but \x20 is not a valid continuation byte.

In a bit more detail:

Not every byte combination is valid UTF-8. UTF-8 has strict formatting rules (for example, a multi-byte character must start with a byte of the form 11xxxxxx, followed by continuation bytes of the form 10xxxxxx, and so on). You can construct sequences that violate these rules (each is demonstrated in the snippet after this list):

  1. Illegal start byte: 0xFF is never legal in UTF-8, so b'\xff\xff' is guaranteed to fail.
  2. Truncated sequence: 0xC3 announces "I am the start of a two-byte character"; if it is followed by a space 0x20 instead of a continuation byte, the decoder raises an error.
  3. Stray continuation byte: 0x80 is a continuation byte and cannot appear on its own at the start.
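
A small demo of the three failure modes (each decode call raises UnicodeDecodeError):

bad_sequences = [
    b'\xff\xff',  # 1. illegal start byte: 0xFF never appears in valid UTF-8
    b'\xc3\x20',  # 2. broken sequence: 0xC3 expects a continuation byte, but 0x20 is not one
    b'\x80',      # 3. stray continuation byte: 0x80 cannot start a character
]
for seq in bad_sequences:
    try:
        seq.decode("utf-8")
    except UnicodeDecodeError as e:
        print(f"{seq!r}: {e}")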

1.3 BPE Training (Section 2.4)

Problem: if we use raw bytes directly as tokens, a sentence becomes very long.

"the" → [116, 104, 101]  (3 tokens)

The idea behind BPE

  • If "the" occurs frequently, give it its own single ID

  • "the" → [256] (1 token). How do we get there? Following the handout, it breaks down into these steps:

    1. Initial vocabulary: the 256 bytes (0-255)

    2. Repeat the following N times:
      a. Count the occurrences of every pair of adjacent bytes
      b. Find the most frequent pair, e.g. ('t', 'h')
      c. Merge that pair into a new token 'th'
      d. Vocabulary size +1

    3. Final vocabulary size = 256 + N

Let's walk through a concrete example:
Concrete example (from the handout)

Training corpus:
"low low low low low"
"lower lower"
"widest widest widest"
"newest newest newest newest newest newest"

Step 1: pre-tokenize (split on spaces)
{"low": 5, "lower": 2, "widest": 3, "newest": 6}

Step 2: represent each word as a byte sequence
{('l','o','w'): 5, ('l','o','w','e','r'): 2, ...}

Step 3: count byte pairs
{'lo': 7, 'ow': 7, 'we': 8, 'er': 2, ...}

Step 4: the most frequent pairs are 'es' and 'st' (tied)
→ take the lexicographically larger 'st'
→ merge: every 'st' → new token [256]

Step 5: update the corpus
{('l','o','w'): 5, ('l','o','w','e','r'): 2,
 ('w','i','d','e',[256]): 3, ('n','e','w','e',[256]): 6}

Repeat ...

Our assignment task: Problem (train_bpe): BPE Tokenizer Training

From the hints in the handout, we can infer the interface and data structures for BPE tokenizer training:

   def train(
        self, 
        input_path: str, 
        vocab_size: int,
        special_tokens: List[str]
        ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
        """
        Args:
            input_path: str, path to the text file used to train the BPE
            vocab_size: int, the target size of the vocabulary
            special_tokens: List[str], special tokens added to the vocabulary but never merged or trained on
        Returns:
            vocabulary: Dict[int, bytes], the vocabulary of the BPE
            merges: List[Tuple[bytes, bytes]], the merge rules of the BPE
        Process:
            1. Initialize the vocabulary
            2. Add special tokens to the vocabulary
            3. Pre-tokenize the text
            4. Initialize word encodings
            5. BPE training loop
        """

Implementation

1. Framework design and data-structure analysis

We will build this up step by step. To verify the overall framework logic incrementally, I first train directly on the string from the handout, and the only special token is <|endoftext|>.

test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest<|endoftext|>"

So how do we implement this step by step? Summarizing the breakdown of the handout example above:

The whole BPE training process can be abstracted into three phases:

  1. Preparation: initialize the vocabulary and turn the text into initial ID lists.
  2. Iteration (loop): keep finding the most frequent pair and merging it, until the vocabulary is full.
  3. Wrap-up: assemble the outputs.
    We need a train function that acts as the foreman of the whole project: it does not do the work itself, but calls the concrete "workers" (_pretokenize, _count_pairs, _merge_encoding), manages the data flow, and controls the pace of the iterations.

Pseudocode logic (Blueprint)

Before writing the code, let's lay out the logic in pseudocode:

function train(text, target vocab size, special tokens):

    1. [Initialization]
       vocabulary = {0: b'\x00', ..., 255: b'\xff'} + special tokens
       current ID = 256 + len(special tokens)

    2. [Pre-processing] (the key to compressing the data)
       word count table (word_counts) = _pretokenize(text)
       word encoding table (word_encodings) = each word turned into [byte_id, byte_id, ...]

    3. [Main loop] while (current vocab size < target vocab size):

       a. Gather statistics:
          pair_counts = _count_pairs(word count table, word encoding table)
          if there is no pair left to merge -> break out of the loop

       b. Decide:
          best_pair = the most frequent pair in pair_counts (on a tie, pick the lexicographically larger one)

       c. Record:
          new ID = current ID
          vocabulary[new ID] = the two parts of best_pair concatenated
          record the merge rule (merges)

       d. Apply:
          _merge_encoding(word encoding table, best_pair, new ID)

       e. Advance:
          current ID += 1

    4. [Return] vocabulary, merge rules
2. Code framework (Python)

This is a skeleton you could use directly in the assignment. Note how it delegates to the small helper "tools" we implement next.

class BPETrainer_Framework:
    def train(
        self, 
        text: str, 
        vocab_size: int, 
        special_tokens: List[str]
    ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
        
        # ==========================================
        # Phase 1: Initialization
        # ==========================================
        # 1. Base vocabulary (0-255)
        vocabulary = {i: bytes([i]) for i in range(256)}
        next_id = 256
        
        # 2. Add the special tokens (they are not BPE-merged; they simply take up slots)
        for token in special_tokens:
            vocabulary[next_id] = token.encode('utf-8')
            next_id += 1
            
        # ==========================================
        # Phase 2: Pre-processing
        # ==========================================
        print("1. Pre-tokenizing text...")
        # Call the pre-tokenization helper we implement below
        # Example output: {'hello': 5, 'world': 3}
        word_counts = self._pretokenize(text, special_tokens) 
        
        # Initialize the encoding state: turn each word into a list of IDs
        # Example: {'hello': [104, 101, ...]}
        word_encodings = {
            word: list(word.encode('utf-8')) 
            for word in word_counts
        }
        
        # ==========================================
        # Phase 3: The training loop
        # ==========================================
        merges = [] # records merge rules, e.g. (b'e', b's') -> 256
        
        print(f"2. Starting BPE loop. Target size: {vocab_size}")
        
        while len(vocabulary) < vocab_size:
            
            # --- Step A: Statistics ---
            # counts provide the weights; encodings show the current segmentation of each word
            pair_counts = self._count_pairs(word_counts, word_encodings)
            
            if not pair_counts:
                print("No more pairs to merge. Stopping early.")
                break
                
            # --- Step B: Selection ---
            # The key here relies on Python's tuple comparison rules:
            # compare the count first (pair_counts[p]); on a tie, compare the bytes lexicographically (vocab[...])
            # Note: does the assignment break ties toward the lexicographically larger or smaller pair?
            # If it is "lexicographically larger", plain tuple comparison with max works as-is
            best_pair = max(
                pair_counts,
                key=lambda p: (pair_counts[p], vocabulary[p[0]] + vocabulary[p[1]])
            )
            
            # --- Step C: Vocabulary update ---
            new_token_bytes = vocabulary[best_pair[0]] + vocabulary[best_pair[1]]
            vocabulary[next_id] = new_token_bytes
            merges.append((vocabulary[best_pair[0]], vocabulary[best_pair[1]]))
            
            # --- Step D: Apply the merge ---
            # This is the most expensive part: it sweeps over the whole dictionary
            # Optimization: only visit words that contain best_pair (an advanced optimization in the assignment)
            for word in word_encodings:
                # call the merge helper we implement below
                new_encoding = self._merge_encoding(
                    word_encodings[word], 
                    best_pair, 
                    next_id
                )
                word_encodings[word] = new_encoding
                
            # --- Step E: Advance ---
            next_id += 1
            
            # print progress (optional)
            if (len(vocabulary) - 256) % 10 == 0:
                print(f"Created token {next_id-1}: {new_token_bytes} (Freq: {pair_counts[best_pair]})")
                
        return vocabulary, merges

    # --- Helper function placeholders ---
    def _pretokenize(self, text, special_tokens):
        pass 
        
    def _count_pairs(self, word_counts, word_encodings):
        pass 
        
    def _merge_encoding(self, encoding, pair, new_id):
        pass 

The pseudocode logic and code framework above were summarized and extracted by Gemini from my implementation. I noticed they differ slightly from what I actually wrote, although the logic is the same; I am not sure whether it pulled that from some reference or improvised on its own. Next, let's implement the helper functions.

I did the step-by-step verification in Colab, purely to check that the results match expectations, so I will just paste the outputs below (the project's .toml has no notebook dependency, my machine is limited and I did not want to install it separately; I am on WSL with a 2060 GPU and 6 GB of VRAM, and I plan to rent compute for the later tasks).

Step 1: Pre-tokenization and counting

BPE is not run over the whole text as one stream. We first use a regex to split the text into "words" and count how often each word appears; then we only have to process the unique words, which is far more efficient.

The handout mentions this as well: when counting adjacent pairs, iterating over every character is too slow. Instead we count word frequencies and, for each pair, add the frequency of every word that contains it.

  1. Implementation

    import regex as re
    from collections import Counter, defaultdict
    from typing import Dict, List, Tuple

    # The regex pattern used by GPT-2/4
    PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    def _pretokenize(text: str, special_tokens: list[str]) -> Dict[str, int]:
        # 1. Remove the special tokens (they do not take part in BPE training)
        for token in special_tokens:
            text = text.replace(token, '')

        # 2. Split with the regex
        words = re.findall(PATTERN, text)

        # 3. Count frequencies
        return dict(Counter(words))
  2. Unit test and verification

    test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest"
    special_tokens = ["<|endoftext|>"]
    word_counts = _pretokenize(test_text, special_tokens)
    print(word_counts)

👀 Observed output:

Pre-tokenization result (5 unique words):
{'low': 1, ' low': 4, ' lower': 2, ' widest': 3, ' newest': 6}

Note one detail here: 'low': 1 and ' low': 4 are distinct entries; one has no leading space and the other does.

  • In the regex alternative | ?\p{L}+, the ? means "optional leading space" (a quick check follows after this list)

    • 'low' matched the case with no leading space (start of the text)
    • ' low' matched the cases with a leading space (everywhere else)
  • After removing the special tokens, the test text begins with

    "low low low low low lower lower"
    ^   ^   ^   ^   ^   ^     ^
    |   └───┴───┴───┴───┘     └─── all have a leading space
    └── start of the text, no leading space
  • Why distinguish them?

    • So the original text (including the spaces) can be reconstructed perfectly
    • GPT-2/GPT-3 are designed this way
    • It makes the tokenizer invertible: text → tokens → text
  • 'low' and ' low' are treated as different pre-tokens
  • They have different UTF-8 encodings
  • BPE processes them independently
  • The final tokenizer may therefore learn:
    • merges for 'low' (used at the start of the text)

    • merges for ' low' (used everywhere else)

      These two are different!

      'low'.encode('utf-8') # b'low' [108, 111, 119]
      ' low'.encode('utf-8') # b' low' [32, 108, 111, 119]

      ^

      the byte value of the space is 32
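
A quick check of that leading-space behaviour with the GPT-2 pattern (using the regex package, as above):

import regex as re

PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
print(re.findall(PATTERN, "low low lower"))
# ['low', ' low', ' lower'] -- only the first word has no leading space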

Step 2: Initialize the word encodings

BPE here is byte-level, so we need to convert each word into a list of integers: its UTF-8 bytes. This is the "foundation" that every merge operation acts on.

  1. Implementation

    # Convert each string into a list of its UTF-8 byte values
    word_encodings = {}
    for word in word_counts:
        word_encodings[word] = list(word.encode("utf-8"))

  2. Unit test and verification

    print("Initial encoding state (Word Encodings):")
    print("word_encodings:", word_encodings)

  3. 👀 Observed output:

    Initial encoding state (Word Encodings):
    word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 101, 115, 116], ' newest': [32, 110, 101, 119, 101, 115, 116]}

✅ As expected

  • l is 108, o is 111, w is 119.
  • The space is encoded as 32.
  • Every word is now a List[int], ready to be merged.

Step 3: Count byte-pair frequencies

This is the heart of the BPE algorithm. We scan every word and count how many times each pair of adjacent IDs occurs in total. Key point: if the word 'newest' has frequency 6, then the pair ('e', 's') inside it also contributes 6 to the count.

  1. Implementation

    def _count_pairs(word_counts: Dict[str, int], word_encodings: Dict[str, List[int]]) -> Dict[Tuple[int, int], int]:
        pair_counts = defaultdict(int)
        for word, count in word_counts.items():
            encoding = word_encodings[word]
            for i in range(len(encoding) - 1):
                pair = (encoding[i], encoding[i+1])
                pair_counts[pair] += count
        return dict(pair_counts)

  2. Unit test and verification

    pair_counts = _count_pairs(word_counts, word_encodings)
    print("pair_counts:", pair_counts)

👀 Observed output:

  • Output:

    pair_counts: {(108, 111): 7, (111, 119): 7, (32, 108): 6, (119, 101): 8, (101, 114): 2, (32, 119): 3, (119, 105): 3, (105, 100): 3, (100, 101): 3, (101, 115): 9, (115, 116): 9, (32, 110): 6, (110, 101): 6, (101, 119): 6}

Checking the result

Top byte-pair counts:
(101, 115): 9   <-- 'e', 's'
(115, 116): 9   <-- 's', 't'
(119, 101): 8   <-- 'w', 'e'
(108, 111): 7   <-- 'l', 'o'
(111, 119): 7   <-- 'o', 'w'

✅ As expected

  • Inside 'est', e(101)-s(115) and s(115)-t(116) are tied for first place (9 occurrences each).
  • Source: widest (3) + newest (6) = 9. The logic is exactly right.

Step 4: Select the best merge

From the counts above we pick the most frequent pair. When several pairs tie, the tie is usually broken by lexicographical order, or simply by taking the first one; to keep the demo simple, we just use max here (see the note below). Given the best pair (101, 115), we create a new ID (say 256) and replace every adjacent 101, 115 in every word with 256.
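
As an aside, breaking a tie toward the lexicographically larger pair falls straight out of Python's tuple comparison, as this tiny check shows:

# equal counts, so the byte strings decide; b'st' > b'es'
print(max([(9, b'es'), (9, b'st')]))   # (9, b'st')
print(b'st' > b'es')                   # True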

  1. Implementation

    def _merge_encoding(encoding: List[int], pair: Tuple[int, int], new_id: int) -> List[int]:
        new_encoding = []
        i = 0
        while i < len(encoding):
            if i < len(encoding) - 1 and encoding[i] == pair[0] and encoding[i+1] == pair[1]:
                new_encoding.append(new_id)
                i += 2
            else:
                new_encoding.append(encoding[i])
                i += 1
        return new_encoding

  2. Test and verification

    # Find the pair with the highest frequency.
    # Note: the real code still has to handle the tie between (101, 115) and (115, 116).

    merge_pair = max(pair_counts, key=pair_counts.get) # highest-frequency pair; this first version ignores the lexicographic tie-break
    print("Max pair:", merge_pair)
    max_count = pair_counts[merge_pair]
    print("Max count:", max_count)

    Here I assign the new token ID 256 by hand, just to verify the merge logic; in the real loop the value is determined by the earlier steps.

    for word in word_encodings:
        new_encoding = _merge_encoding(word_encodings[word], merge_pair, 256)
        word_encodings[word] = new_encoding
    print("After once merge word_encodings:", word_encodings)

👀 Observed output:

  • Output

    Max pair: (101, 115) # ignoring lexicographic order, the code merges 'es'; with the tie-break it should merge 'st'
    Max count: 9
    After once merge word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 256, 116], ' newest': [32, 110, 101, 119, 256, 116]}

  • Before/after comparison:

    before the merge, 'newest': [32, 110, 101, 119, 101, 115, 116]
    after the merge,  'newest': [32, 110, 101, 119, 256, 116]

✅ As expected

  • The original ..., 101, 115, ... (e, s) has indeed become ..., 256, ... (es).
  • The list is one element shorter: compression achieved!

Step 5: Putting it together: record the vocabulary and merge rules

Finally, we record what happened in this step so that it can be used later for tokenizer inference.

# Initialize the base vocabulary (0-255)
vocabulary = {i: bytes([i]) for i in range(256)}
# Record this merge
merges = []
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
print("merge bytes:", merge_bytes)
vocabulary[256] = merge_bytes
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
print("merges:", merges)

👀 Observed output:

merge bytes: b'es'
merges: [(b'e', b's')]

Summary

Stepping through it like this makes it clear that BPE is, at its core, a "count → replace → repeat" loop:

  1. Data state: [e, s, t] becomes [es, t].
  2. Vocabulary state: grows from the 256 base bytes to also include 'es'.
  3. Loop: in the next iteration, _count_pairs will pick up new combinations such as (256, 116) (i.e. es + t).

All that remains is to put the steps above inside a while loop and we have a complete BPE trainer!

Optimizations
  • _count_pairs optimization
  1. When frequencies are tied, take the lexicographically larger pair

  2. Added a cache of each pair's byte strings to speed up lookups

  3. Implementation

    def _count_pairs_v2(
        word_counts: Dict[str, int],
        word_encodings: Dict[str, List[int]],
        pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
        vocabulary: Dict[int, bytes],
    ) -> Dict[Tuple[int, int], int]:
        pair_counts = defaultdict(int)
        for word, count in word_counts.items():
            encoding = word_encodings[word]
            for i in range(len(encoding) - 1):
                pair = (encoding[i], encoding[i+1])
                pair_counts[pair] += count
                # cache pair byte strings to save lookup time
                if pair not in pair_strings:
                    pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
        return pair_counts

  4. Test and verification:

    test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest"
    special_tokens = ["<|endoftext|>"]
    vocabulary = {i: bytes([i]) for i in range(256)}
    size = 256
    new_token = size
    word_counts = _pretokenize(test_text, special_tokens)
    size += 1          # one slot reserved for the special token
    new_token = size
    print(word_counts)
    word_encodings = {}
    for word in word_counts:
        word_encodings[word] = list(word.encode("utf-8"))
    print("word_encodings:", word_encodings)

    pair_counts = _count_pairs(word_counts, word_encodings)

    pair_strings = {}
    pair_counts = _count_pairs_v2(word_counts, word_encodings, pair_strings, vocabulary)
    print("pair_counts:", pair_counts)

    merge_pair = max(pair_counts, key=pair_counts.get) # first version: highest frequency only, ignoring the tie-break

    merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings[x])) # break ties by lexicographic order
    print("Max pair:", merge_pair)
    max_count = pair_counts[merge_pair]
    print("Max count:", max_count)

    for word in word_encodings:
        new_encoding = _merge_encoding(word_encodings[word], merge_pair, new_token)
        word_encodings[word] = new_encoding
    print("After once merge word_encodings:", word_encodings)

    vocabulary = {i: bytes([i]) for i in range(256)}

    merges = []
    merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
    print("merge bytes:", merge_bytes)
    vocabulary[size] = merge_bytes
    merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
    print("merges:", merges)

  5. Output

    {'low': 1, ' low': 4, ' lower': 2, ' widest': 3, ' newest': 6}
    word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 101, 115, 116], ' newest': [32, 110, 101, 119, 101, 115, 116]}
    pair_counts: defaultdict(<class 'int'>, {(108, 111): 7, (111, 119): 7, (32, 108): 6, (119, 101): 8, (101, 114): 2, (32, 119): 3, (119, 105): 3, (105, 100): 3, (100, 101): 3, (101, 115): 9, (115, 116): 9, (32, 110): 6, (110, 101): 6, (101, 119): 6})
    Max pair: (115, 116)
    Max count: 9
    After once merge word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 101, 257], ' newest': [32, 110, 101, 119, 101, 257]}
    merge bytes: b'st'
    merges: [(b's', b't')]

👀 Observed output:

  1. The merged pair has become (115, 116)
  2. As analysed earlier, the merge recorded is (b's', b't')
  3. The new token ID is 257: the initial vocabulary has 256 entries plus one special token, so new merge tokens start at 257. One more small trick for building the pairs: you can use zip(), as in pairs = zip(encoding, encoding[1:]); see the sketch below.
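
A minimal sketch of that zip() variant of the pair counter (same behaviour as _count_pairs, just a different way of walking the adjacent pairs):

    from collections import defaultdict
    from typing import Dict, List, Tuple

    def _count_pairs_zip(word_counts: Dict[str, int],
                         word_encodings: Dict[str, List[int]]) -> Dict[Tuple[int, int], int]:
        pair_counts = defaultdict(int)
        for word, count in word_counts.items():
            encoding = word_encodings[word]
            # zip(encoding, encoding[1:]) yields every adjacent (left, right) pair
            for pair in zip(encoding, encoding[1:]):
                pair_counts[pair] += count
        return dict(pair_counts)
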
  • _merge_encoding optimization
    Have the merge report whether the encoding actually changed, and only write the word back (and count it) when it did.

  • _merge_encoding

      def _merge_encoding(
          encoding: List[int],
          merge_pair: Tuple[int, int],
          new_token_id: int
      ) -> Tuple[List[int], bool]:
          """
          return word encoding and has changed
          """
          new_encoding = []
          has_changed = False
          i = 0
          while i < len(encoding):
              if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
                  new_encoding.append(new_token_id)
                  has_changed = True
                  i += 2
              else:
                  new_encoding.append(encoding[i])
                  i += 1
          return new_encoding, has_changed
  • The training function train

      def train(
          self,
          text: str,
          vocab_size: int,
          special_tokens: List[str]
      ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
          # Implementation similar to V2, with the has_changed optimization
          vocabulary = {i: bytes([i]) for i in range(256)}
          size = 256
          for token in special_tokens:
              vocabulary[size] = token.encode('utf-8')
              size += 1
    
          word_counts = self._pretokenize(text, special_tokens)
          word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}
    
          merges = []
          pair_strings = {}
          num_merges = vocab_size - size
          for merge_idx in range(num_merges):
              pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary)
              if not pair_counts:
                  break
              merge_pair = max(
                  pair_counts,
                  key=lambda x: (pair_counts[x], pair_strings[x])
              )
              max_count = pair_counts[merge_pair]
              merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
              new_id = size
              vocabulary[new_id] = merge_bytes   # register the new token in the vocabulary
              size += 1
    
              # optimization
              updated_count = 0
              for word in word_encodings:
                  new_encoding, has_changed = self._merge_encoding(word_encodings[word], merge_pair, new_id)
                  if has_changed:
                      word_encodings[word] = new_encoding
                      updated_count += 1
              if merge_idx < 5:
                  print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_id}, updated {updated_count} words")
              merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
    
          return vocabulary, merges

See train_bpe.py for the full implementation.

The next step is to read the training text from a file: the text in memory has to be loaded from disk, and for large corpora we must account for data that cannot all be loaded into memory at once and read it in chunks. The handout also mentions parallelism. A rough sketch of chunked reading follows below.
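
A rough, hedged sketch of what chunked reading could look like (my own illustration, not the handout's helper; the chunk size and the idea of cutting only at the special token are assumptions):

from typing import Iterator

def iter_chunks(path: str, special_token: str = "<|endoftext|>",
                chunk_size: int = 1 << 20) -> Iterator[str]:
    """Yield pieces of the file that end on a special-token boundary,
    so that no pre-token is split across two chunks."""
    buffer = ""
    with open(path, "r", encoding="utf-8") as f:
        while True:
            block = f.read(chunk_size)   # read roughly this many characters at a time
            if not block:
                break
            buffer += block
            # keep everything after the last complete special token for the next round
            cut = buffer.rfind(special_token)
            if cut != -1:
                yield buffer[: cut + len(special_token)]
                buffer = buffer[cut + len(special_token):]
    if buffer:
        yield buffer

Each chunk can then be pre-tokenized independently (for example in a multiprocessing pool) and the resulting word counts summed before the merge loop.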

Full code of the simple version of train_bpe.py:

"""
BPE training algorithm
"""
import regex as re
from collections import Counter, defaultdict
from typing import List, Tuple, Dict


class BPETrainer_V1:
    """
    V1: simple version, for testing only
    """
    def __init__(self):
        self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    def train(
        self, 
        text: str, 
        vocab_size: int,
        special_tokens: List[str]
        ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
        """
        Args:
            text: str, text for used to train the BPE
            vocab_size: int, the size of the vocabulary
            special_tokens: List[str], the special tokens to be added to the vocabulary, but not be merged and trained
        Returns:
            vocabulary: Dict[int, bytes], the vocabulary of the BPE
            merges: List[Tuple[bytes, bytes]], the merge rules of the BPE
        Process:
            1. Initialize the vocabulary
            2. Add special tokens to the vocabulary
            3. Pre-tokenize the text
            4. Initialize word encodings
            5. BPE training loop
        """
        # 1. Initialize the vocabulary with 256  bytes
        vocabulary = {i: bytes([i]) for i in range(256)}
        size = 256

        # 2. Add special tokens to the vocabulary
        for token in special_tokens:
            vocabulary[size] = token.encode('utf-8')
            size += 1
        
        # 3. Pre-tokenize the text by word level and count the frequency of each word
        print("📝 Pre-tokenizing...")
        word_counts = self._pretokenize(text, special_tokens)
        print(f"Found {len(word_counts)} different word in text after remove special tokens")

        # 4. Initialize word encodings
        print("Initializing word encodings...")
        word_encodings = {}
        for word in word_counts:
            word_encodings[word] = list(word.encode('utf-8'))

        # 5. BPE training loop
        print(f"Starting BPE training...")
        num_merges = vocab_size - size
        print(f"Need {num_merges} loop to reach the largest vocabulary size {vocab_size}")
        merges = []
        for merge_idx in range(num_merges):
            # a. Count the frequency of each pair of adjacent tokens
            pair_counts = self._count_pairs(word_counts, word_encodings)
            if not pair_counts:
                print("No more pair to be merges, quit")
                break
            # b. Find the max count pair
            merge_pair = max(pair_counts, key=pair_counts.get)
            max_count = pair_counts[merge_pair]

            # c. Create new token
            new_token_id = size
            vocabulary[new_token_id] = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
            size += 1
            
            # d. Update word encodings
            for word in word_encodings:
                new_encoding = self._merge_encoding(word_encodings[word], merge_pair, new_token_id)
                word_encodings[word] = new_encoding

            # e. Record merges
            merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
            # f. Print the merge process
            if merge_idx < 5 or merge_idx % 10 == 0:
                print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_token_id}")

        return vocabulary, merges

    def _pretokenize(self, text: str, special_tokens: List[str]) -> Dict[str, int]:
        # Remove special tokens from text
        for token in special_tokens:
            text = text.replace(token, '')
        # Split text into words using regex
        words = re.findall(self.pattern, text)
        # Count the frequency of each word
        return dict(Counter(words))

    def _count_pairs(
        self,
        word_counts: Dict[str, int],
        word_encodings: Dict[str, List[int]]
    ) -> Dict[Tuple[int, int], int]:
        pair_counts = defaultdict(int)
        for word, count in word_counts.items():
            encoding = word_encodings[word]
            for i in range(len(encoding) - 1):
                pair = (encoding[i], encoding[i+1])
                pair_counts[pair] += count

        return pair_counts

    def _merge_encoding(
        self,
        encoding: List[int],
        merge_pair: Tuple[int, int],
        new_token_id: int
    ) -> List[int]:
        new_encoding = []
        p0, p1 = merge_pair
        i = 0
        while i < len(encoding):
            if i < len(encoding) -1 and encoding[i] == p0 and encoding[i+1] == p1:
                new_encoding.append(new_token_id)
                i += 2
            else:
                new_encoding.append(encoding[i])
                i += 1

        return new_encoding


class BPETrainer_V2:
    """
    Added:
    - In case of tie in pair frequency, choose the lexicographically larger pair
    - pair strings buffer
    - max key function optimization
    """
    def __init__(self):
        self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    def train(
        self,
        text: str,
        vocab_size: int,
        special_tokens: List[str]
    ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
        # Implementation similar to V1 with optimizations and tie-breaking logic
        vocabulary = {i: bytes([i]) for i in range(256)}
        size = 256

        for token in special_tokens:
            vocabulary[size] = token.encode('utf-8')
            size += 1

        print("📝 Pre-tokenizing...")
        word_counts = self._pretokenize(text, special_tokens)
        print(f"Found {len(word_counts)} different word in text after remove special tokens")

        print("Initializing word encodings...")
        word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}
        print(f"Starting BPE training...")

        merges = []
        pair_strings = {} # new buffer for pair strings
        num_merges = vocab_size - size
        for merge_idx in range(num_merges):
            pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary)
            if not pair_counts:
                break
            # break ties in frequency by lexicographic order
            merge_pair = max(
                pair_counts,
                key=lambda x: (pair_counts[x], pair_strings[x])
            )
            max_count = pair_counts[merge_pair]

            merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
            new_token_id = size
            vocabulary[new_token_id] = merge_bytes
            size += 1

            if merge_idx < 5 or merge_idx % 10 == 0:
                print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_token_id}")

            for word in word_encodings:
                new_encoding = self._merge_encoding(word_encodings[word], merge_pair, new_token_id)
                word_encodings[word] = new_encoding
            merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))

        return vocabulary, merges
    def _pretokenize(self, text: str, special_tokens: List[str]) -> Dict[str, int]:
        for token in special_tokens:
            text = text.replace(token, '')

        words = re.findall(self.pattern, text)
        return dict(Counter(words))

    def _count_pairs(
        self,
        word_counts: Dict[str, int],
        word_encodings: Dict[str, List[int]],
        pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
        vocabulary: Dict[int, bytes],
    ) -> Dict[Tuple[int, int], int]:
        pair_counts = defaultdict(int)
        for word, count in word_counts.items():
            encoding = word_encodings[word]
            for i in range(len(encoding) - 1):
                pair = (encoding[i], encoding[i+1])
                pair_counts[pair] += count
                # cache pair byte strings to save lookup time
                if pair not in pair_strings:
                    pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
        return pair_counts

    def _merge_encoding(
        self,
        encoding: List[int],
        merge_pair: Tuple[int, int],
        new_token_id: int
    ) -> List[int]:
        new_encoding = []
        i = 0
        while i < len(encoding):
            if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
                new_encoding.append(new_token_id)
                i += 2
            else:
                new_encoding.append(encoding[i])
                i += 1
        return new_encoding

# ============================================
# Version 3: add the has_changed optimization
# ============================================
class BPETrainer_V3(BPETrainer_V2):
    # Inherits _pretokenize and _count_pairs from BPETrainer_V2;
    # only train and _merge_encoding are overridden here.
    def __init__(self):
        self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    def train(
        self,
        text: str,
        vocab_size: int,
        special_tokens: List[str]
    ) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
        # Implementation similar to V2, with the has_changed optimization
        vocabulary = {i: bytes([i]) for i in range(256)}
        size = 256
        for token in special_tokens:
            vocabulary[size] = token.encode('utf-8')
            size += 1

        word_counts = self._pretokenize(text, special_tokens)
        word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}

        merges = []
        pair_strings = {}
        num_merges = vocab_size - size
        for merge_idx in range(num_merges):
            pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary)
            if not pair_counts:
                break
            merge_pair = max(
                pair_counts,
                key=lambda x: (pair_counts[x], pair_strings[x])
            )
            max_count = pair_counts[merge_pair]
            merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
            new_id = size
            vocabulary[new_id] = merge_bytes   # register the new token in the vocabulary
            size += 1

            # optimization
            updated_count = 0
            for word in word_encodings:
                new_encoding, has_changed = self._merge_encoding(word_encodings[word], merge_pair, new_id)
                if has_changed:
                    word_encodings[word] = new_encoding
                    updated_count += 1
            if merge_idx < 5:
                print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_id}, updated {updated_count} words")
            merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))

        return vocabulary, merges

    def _merge_encoding(
        self,
        encoding: List[int],
        merge_pair: Tuple[int, int],
        new_token_id: int
    ) -> Tuple[List[int], bool]:
        """
        return word encoding and has changed
        """
        new_encoding = []
        has_changed = False
        i = 0
        while i < len(encoding):
            if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
                new_encoding.append(new_token_id)
                has_changed = True
                i += 2
            else:
                new_encoding.append(encoding[i])
                i += 1
        return new_encoding, has_changed
# ============================================
# Test Code
# ============================================

def test_versions():
    test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest<|endoftext|>"
    
    print("="*60)
    print("Test data:", test_text[:50], "...")
    print("="*60)
    print()
    
    # Test V1
    print("【Version 1】Sample BPE trainer sample implementation")
    print("-"*60)
    trainer_v1 = BPETrainer_V1()
    vocab_v1, merges_v1 = trainer_v1.train(test_text, 262, [])
    print(f"✅ Training complete,vocab_size: {len(vocab_v1)}")
    print() 


    # Test V2
    print("【Version 2】add lexicographical order")
    print("-"*60)
    trainer_v2 = BPETrainer_V2()
    vocab_v2, merges_v2 = trainer_v2.train(test_text, 262, [])
    print(f"✅ Training complete,vocab_size: {len(vocab_v2)}")
    print()
if __name__ == "__main__":
    test_versions()
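
For completeness, BPETrainer_V3 can be exercised in the same way; a small usage sketch (the vocab_size of 262 and the single special token are just illustrative values mirroring the tests above):

test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest<|endoftext|>"
trainer_v3 = BPETrainer_V3()
vocab_v3, merges_v3 = trainer_v3.train(test_text, 262, ["<|endoftext|>"])
print(f"V3 training complete, vocab_size: {len(vocab_v3)}, merges: {len(merges_v3)}")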