Why do we need a Tokenizer?
Problem: computers only understand numbers, but we want to process text.
"Hello world" → needs to be converted into → [123, 456] (a sequence of numbers)
Solution: the tokenizer is exactly this converter.
1.1 Unicode Basics (Section 2.1)
What is Unicode?
- It is simply a "numbering table for every character in the world".
- Every character has a number (called a code point).
Example:
>>> ord('s') # the code point of 's' is 115
115
>>> ord('牛') # the code point of '牛' is 29275
29275
>>> chr(115) # in reverse, 115 maps back to 's'
's'
Problem:
- (a): What Unicode character does chr(0) return?
A: chr(0) returns the null character (often referred to as NULL), which corresponds to the Unicode code point U+0000.
Key point: chr(0) corresponds to Unicode code point U+0000, known as the null character.
- (b): How does this character's string representation (repr) differ from its printed representation?
A: Its string representation (repr) displays the escape sequence '\x00' for debugging purposes, whereas its printed representation renders the actual invisible control character, resulting in no visible output.
Key point: repr shows the escape sequence \x00 for debugging, while print renders the actual (invisible) character.
- (c): What happens when this character occurs in text?
A: It behaves as a valid, invisible character within the string and does not terminate it (unlike in C), so the text after it is still processed and displayed normally.
Key point (important): In C, \0 is the string terminator (null terminator). If you write "test\0string" in C, printing it shows only "test"; everything after it is dropped. In Python, strings carry an explicit length (length-prefixed), and chr(0) is just an ordinary character, so print("..." + chr(0) + "...") does not truncate the string; the content after it is still printed.
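A quick check of the three answers above, using nothing beyond standard Python:
```python
s = "test" + chr(0) + "string"

print(repr(s))  # 'test\x00string'  <- repr shows the escape sequence
print(s)        # looks like "teststring"; the NUL is there but invisible in most terminals
print(len(s))   # 11 <- the NUL still counts as one character, nothing is truncated
```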
1.2 UTF-8 Encoding (Section 2.2)
Why not use Unicode code points directly?
Problem: Unicode has 150,000+ characters, so the range of code points is very large.
Solution: UTF-8 encoding
- Converts characters into a sequence of bytes.
- Each byte is a number from 0 to 255.
- Variable-length encoding: an English letter takes 1 byte, a Chinese character takes 3-4 bytes.
Example:
>>> "hello".encode("utf-8")
b'hello' # 5 bytes: [104, 101, 108, 108, 111]
>>> "你好".encode("utf-8")
b'\xe4\xbd\xa0\xe5\xa5\xbd' # 6 bytes (3 bytes per Chinese character)
Key insight:
String "hello"
↓ UTF-8 encoding
Byte sequence [104, 101, 108, 108, 111] ← this is what we actually work with
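To make this concrete, here is a minimal round trip (plain Python) from a string to the list of byte values the tokenizer operates on, and back:
```python
text = "hello你好"

byte_values = list(text.encode("utf-8"))       # string -> list of ints in 0-255
print(byte_values)
# [104, 101, 108, 108, 111, 228, 189, 160, 229, 165, 189]

restored = bytes(byte_values).decode("utf-8")  # list of ints -> bytes -> string
print(restored)                                # hello你好
```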
Problems
- (a): What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings.
A: UTF-8 is preferred because it is variable-length and space-efficient, representing common ASCII characters with just one byte, whereas UTF-16 and UTF-32 use larger code units (2 or 4 bytes) that introduce excessive null-byte padding for the same text (see the comparison snippet below).
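A quick comparison (plain Python; the -le codec names just pin the byte order) that makes the padding point concrete:
```python
s = "hello"
print(len(s.encode("utf-8")), list(s.encode("utf-8")))
# 5 [104, 101, 108, 108, 111]
print(len(s.encode("utf-16-le")), list(s.encode("utf-16-le")))
# 10 [104, 0, 101, 0, 108, 0, 108, 0, 111, 0]
print(len(s.encode("utf-32-le")), list(s.encode("utf-32-le")))
# 20 [104, 0, 0, 0, 101, 0, 0, 0, 108, 0, 0, 0, 108, 0, 0, 0, 111, 0, 0, 0]
```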
- (b): Why is this function incorrect?
Example function:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])
Example:
>>> decode_utf8_bytes_to_str_wrong("你".encode("utf-8"))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 2, in decode_utf8_bytes_to_str_wrong
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data
>>> test_string = "你"
>>> print(test_string.encode("utf-8"))
b'\xe4\xbd\xa0'
>>> b'\xe4\xbd\xa0'.decode("utf-8")
'你'
>>> b'\xe4'.decode("utf-8")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data
We can infer the following from the above:
Example input: b'\xe4\xbd\xa0' (which corresponds to the character "你"). Explanation: the function is incorrect because it attempts to decode each byte individually; in UTF-8, however, non-ASCII characters (like "你") are encoded as multi-byte sequences that cannot be decoded in isolation.
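For contrast, a minimal corrected version (a sketch, not required by the assignment) decodes the whole byte string in one call, so multi-byte sequences stay intact:
```python
def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    # Decode the entire sequence at once so multi-byte characters are not split.
    return bytestring.decode("utf-8")

print(decode_utf8_bytes_to_str("你".encode("utf-8")))  # 你
```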
- (c): Give a two-byte sequence that does not decode to any Unicode character(s).
Example: b'\xff\xff' (or b'\xc3\x20'). Explanation: the first sequence is invalid because 0xff is never a valid byte in UTF-8; in the second, the leading byte \xc3 announces a two-byte character, but \x20 is not a valid continuation byte.
In a bit more detail:
Not every byte combination is valid UTF-8. UTF-8 has strict structural rules (for example, a multi-byte character must start with a byte of the form 11xxxxxx, followed by continuation bytes of the form 10xxxxxx). You can construct sequences that break these rules (checked in the snippet below):
- Illegal start byte: 0xFF is never valid in UTF-8, so b'\xff\xff' is guaranteed to fail.
- Truncated sequence: 0xC3 announces "I am the first byte of a two-byte character"; if it is followed by a space 0x20 instead of a continuation byte, the decoder raises an error.
- Stray continuation byte: 0x80 is a continuation byte and cannot appear on its own at the start.
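All three failure modes can be checked directly (the error reasons in the comments are what CPython reports; the exact wording may vary slightly):
```python
for bad in (b'\xff\xff', b'\xc3\x20', b'\x80'):
    try:
        bad.decode("utf-8")
    except UnicodeDecodeError as e:
        print(bad, "->", e.reason)
# b'\xff\xff' -> invalid start byte
# b'\xc3 '    -> invalid continuation byte
# b'\x80'     -> invalid start byte
```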
1.3 BPE Training (Section 2.4)
Problem: if we use raw bytes as tokens directly, a sentence becomes very long.
"the" → [116, 104, 101] (3 tokens)
The idea behind BPE:
- If "the" appears often, give it its own single ID.
- "the" → [256] (1 token)
How do we implement this? Following the handout, it breaks down into these steps:
- Initial vocabulary: the 256 bytes (0-255).
- Repeat the following N times:
  a. Count the occurrences of every adjacent byte pair.
  b. Find the most frequent pair, e.g. ('t', 'h').
  c. Merge that pair into a new token 'th'.
  d. Vocabulary size += 1.
- Final vocabulary size = 256 + N
Let's walk through a concrete example.
Concrete example (from the handout):
Training corpus:
"low low low low low"
"lower lower"
"widest widest widest"
"newest newest newest newest newest newest"
Step 1: pre-tokenize (split on spaces)
{"low": 5, "lower": 2, "widest": 3, "newest": 6}
Step 2: represent each word as a byte sequence
{('l','o','w'): 5, ('l','o','w','e','r'): 2, ...}
Step 3: count adjacent byte pairs
{'lo': 7, 'ow': 7, 'we': 8, 'er': 2, ...}
Step 4: the most frequent pairs are 'es' and 'st' (tied)
→ take the lexicographically larger one: 'st'
→ merge: every 'st' → new token [256]
Step 5: update the corpus
{('l','o','w'): 5, ('l','o','w','e','r'): 2,
('w','i','d','e',[256]): 3, ('n','e','w','e',[256]): 6}
Repeat...
Our assignment task: Problem (train_bpe): BPE Tokenizer Training
From the handout's hints, we can infer the interface and data structures for BPE tokenizer training:
def train(
self,
input_path: str,
vocab_size: int,
special_tokens: List[str]
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
"""
Args:
input_path: str, path to the text used to train the BPE tokenizer
vocab_size: int, the size of the vocabulary
special_tokens: List[str], special tokens added to the vocabulary but never merged or trained on
Returns:
vocabulary: Dict[int, bytes], the vocabulary of the BPE
merges: List[Tuple[bytes, bytes]], the merge rules of the BPE
Process:
1. Initialize the vocabulary
2. Add special tokens to the vocabulary
3. Pre-tokenize the text
4. Initialize word encodings
5. BPE training loop
"""
Code Implementation
1. Framework design and data-structure analysis
We implement this step by step. To verify the overall framework logic incrementally, I first use the string from the handout directly as the training text, and <|endoftext|> as the only special token:
test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest<|endoftext|>"
How exactly do we build this step by step? Summarizing the earlier breakdown of the handout's example:
The whole BPE training process can be abstracted into three phases:
- Preparation: initialize the vocabulary and turn the text into the initial ID lists.
- Iteration (loop): repeatedly find the most frequent pair and merge it, until the vocabulary is full.
- Wrap-up: assemble and return the results.
We need a train function that acts as the general contractor of the whole project: it does not do the heavy lifting itself, but manages the data flow and controls the pace of iteration by calling the concrete "workers" _pre_tokenize, _count_pairs and _merge_encoding.
Pseudocode Logic (Blueprint)
Before writing real code, let's lay out the logic as pseudocode:
function train(text, target_vocab_size, special_tokens):
1. [Initialize]
   vocab = {0: b'\x00', ..., 255: b'\xff'} + special tokens
   current_id = 256 + len(special_tokens)
2. [Pre-process] (the key to compressing the data)
   word_counts = _pre_tokenize(text)
   word_encodings = convert each word into [byte_id, byte_id, ...]
3. [Main loop] while (current vocab size < target vocab size):
   a. Gather statistics:
      pair_counts = _count_pairs(word_counts, word_encodings)
      if there is no pair left to merge -> break out of the loop
   b. Decide:
      best_pair = the most frequent pair in pair_counts (on a tie, the lexicographically larger one)
   c. Record:
      new_id = current_id
      vocab[new_id] = concatenation of the two parts of best_pair
      record the merge rule (merges)
   d. Execute:
      _merge_encoding(word_encodings, best_pair, new_id)
   e. Advance:
      current_id += 1
4. [Return] vocab, merges
2. Code Framework (Python)
This is a skeleton you can use directly in the assignment. Note how it calls the small helper "tools" we implement below.
from typing import Dict, List, Tuple

class BPETrainer_Framework:
def train(
self,
text: str,
vocab_size: int,
special_tokens: List[str]
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
# ==========================================
# Phase 1: Initialization
# ==========================================
# 1. Base vocabulary (0-255)
vocabulary = {i: bytes([i]) for i in range(256)}
next_id = 256
# 2. Add special tokens (they are not BPE-merged; they just take up IDs)
for token in special_tokens:
vocabulary[next_id] = token.encode('utf-8')
next_id += 1
# ==========================================
# Phase 2: Pre-processing
# ==========================================
print("1. Pre-tokenizing text...")
# Call the pre-tokenization helper we implement below
# Example output: {'hello': 5, 'world': 3}
word_counts = self._pretokenize(text, special_tokens)
# Initialize the encoding state: turn each word into a list of byte IDs
# Example: {'hello': [104, 101, ...]}
word_encodings = {
word: list(word.encode('utf-8'))
for word in word_counts
}
# ==========================================
# Phase 3: The Training Loop
# ==========================================
merges = []  # record merge rules, e.g. (b'e', b's') -> 256
print(f"2. Starting BPE loop. Target size: {vocab_size}")
while len(vocabulary) < vocab_size:
# --- Step A: Statistics ---
# word_counts is passed for weighting; word_encodings gives the current segmentation state
pair_counts = self._count_pairs(word_counts, word_encodings)
if not pair_counts:
print("No more pairs to merge. Stopping early.")
break
# --- Step B: Selection ---
# The key relies on Python's tuple comparison rules:
# compare the frequency count first; on a tie, compare the bytes lexicographically (vocab[...])
# Note: does the assignment break ties toward the lexicographically larger or smaller pair?
# If it is "lexicographically larger", plain tuple comparison does the job
best_pair = max(
pair_counts,
key=lambda p: (pair_counts[p], vocabulary[p[0]] + vocabulary[p[1]])
)
# --- Step C: Vocabulary Update ---
new_token_bytes = vocabulary[best_pair[0]] + vocabulary[best_pair[1]]
vocabulary[next_id] = new_token_bytes
merges.append((vocabulary[best_pair[0]], vocabulary[best_pair[1]]))
# --- Step D: Apply Merge ---
# This is the most time-consuming part; it iterates over the whole dictionary
# Optimization idea: only visit words that contain best_pair (advanced assignment optimization)
for word in word_encodings:
# Call the merge helper we implement below
new_encoding = self._merge_encoding(
word_encodings[word],
best_pair,
next_id
)
word_encodings[word] = new_encoding
# --- Step E: Advance ---
next_id += 1
# Print progress (optional)
if (len(vocabulary) - 256) % 10 == 0:
print(f"Created token {next_id-1}: {new_token_bytes} (Freq: {pair_counts[best_pair]})")
return vocabulary, merges
# --- Helper function placeholders ---
def _pretokenize(self, text, special_tokens):
pass
def _count_pairs(self, word_counts, word_encodings):
pass
def _merge_encoding(self, encoding, pair, new_id):
pass
The pseudocode and code framework above were summarized and extracted by Gemini from my implementation. I noticed a few differences from what I actually wrote, but the logic is the same; I'm not sure whether it pulled that from some reference or improvised it. Next, let's implement the helper functions.
I did the step-by-step verification in Colab, purely to check that the results match expectations, so I'll just paste the outputs. (The project's .toml has no notebook dependency and my machine is limited, so I didn't want to install one separately; I'm on WSL with a 2060 GPU and 6 GB of VRAM, and I plan to rent compute for the later tasks.)
Step 1: Pre-tokenization and counting
BPE is not run over the raw text as one long sequence. We first use a regex to split the text into "words" and count how often each word occurs; then we only need to process the unique words, which is far more efficient.
The handout notes that when counting adjacent pairs, iterating over every character is too slow; instead we can count word frequencies and, for each pair, add the frequency of every word containing that pair to the pair's count.
- Code implementation
import regex as re
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

# The regex pattern used by GPT-2/4
PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

def _pretokenize(text: str, special_tokens: list[str]) -> Dict[str, int]:
    # 1. Remove special tokens (they do not take part in BPE training)
    for token in special_tokens:
        text = text.replace(token, '')
    # 2. Split the text with the regex
    words = re.findall(PATTERN, text)
    # 3. Count word frequencies
    return dict(Counter(words))
- Unit test and verification
test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest"
special_tokens = ["<|endoftext|>"]
word_counts = _pretokenize(test_text, special_tokens)
print(word_counts)
👀 Observed output:
Pre-tokenization result (5 unique words):
{'low': 1, ' low': 4, ' lower': 2, ' widest': 3, ' newest': 6}
One thing to note: 'low': 1 and ' low': 4 are different entries; one has no leading space and the other does.
- The regex
In the alternative " ?\p{L}+", the "?" marks an optional leading space: 'low' matches the case without a leading space (start of the text), while ' low' matches the case with a leading space (everywhere else).
- After pre-processing, the test text becomes:
"low low low low low lower lower"
(the first "low" sits at the start of the text with no leading space; every following word keeps its leading space)
- Why distinguish them?
- So the original text (including spaces) can be reconstructed perfectly.
- GPT-2/GPT-3 are designed the same way.
- It keeps the tokenizer reversible:
text → tokens → text
'low' and ' low' are treated as different pre-tokens:
- They have different UTF-8 encodings.
- BPE processes them independently.
- The final tokenizer may end up learning:
  - a merge for 'low' (used at the start of the text)
  - a merge for ' low' (used everywhere else)
  These two are different!
'low'.encode('utf-8')   # b'low'  -> [108, 111, 119]
' low'.encode('utf-8')  # b' low' -> [32, 108, 111, 119]
# the leading space is byte value 32
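A quick check (reusing the PATTERN defined above) that the regex really splits leading spaces this way:
```python
import regex as re
from collections import Counter

PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

words = re.findall(PATTERN, "low low lower")
print(words)                  # ['low', ' low', ' lower']
print(dict(Counter(words)))   # {'low': 1, ' low': 1, ' lower': 1}
```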
Step 2: Initializing word encodings
BPE is byte-level: we need to convert every word into a list of UTF-8 byte integers. This is the "foundation" on which all merge operations run.
- Code implementation
# Convert each string into a list of UTF-8 byte values
word_encodings = {}
for word in word_counts:
    word_encodings[word] = list(word.encode("utf-8"))
- Unit test and verification
print("初始编码状态 (Word Encodings):")
print("word_encodings:",word_encodings) -
👀 观察输出:
初始编码状态 (Word Encodings):
word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 101, 115, 116], ' newest': [32, 110, 101, 119, 101, 115, 116]}
✅ As expected:
- 'l' is 108, 'o' is 111, 'w' is 119.
- The space is encoded as 32.
- Each word is now a List[int], ready to be merged.
Step 3: Counting byte-pair frequencies (Count Pairs)
This is the heart of the BPE algorithm. We scan every word and count how many times each pair of adjacent IDs occurs in total. Key point: if the word 'newest' has frequency 6, then the pair ('e', 's') inside it also contributes 6 to that pair's count.
- Code implementation
def _count_pairs(word_counts: Dict[str, int], word_encodings: Dict[str, List[int]]) -> Dict[Tuple[int, int], int]:
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
for i in range(len(encoding) - 1):
pair = (encoding[i], encoding[i+1])
pair_counts[pair] += count
    return dict(pair_counts)
- Unit test and verification
pair_counts = _count_pairs(word_counts, word_encodings)
print("pair_counts:",pair_counts)
👀 Observed output:
- Output:
pair_counts: {(108, 111): 7, (111, 119): 7, (32, 108): 6, (119, 101): 8, (101, 114): 2, (32, 119): 3, (119, 105): 3, (105, 100): 3, (100, 101): 3, (101, 115): 9, (115, 116): 9, (32, 110): 6, (110, 101): 6, (101, 119): 6}
Cross-checking the result
Byte-pair count results (top entries):
(101, 115): 9 <-- 'e', 's'
(115, 116): 9 <-- 's', 't'
(119, 101): 8 <-- 'w', 'e'
(108, 111): 7 <-- 'l', 'o'
(111, 119): 7 <-- 'o', 'w'
✅ As expected:
- Inside 'est', e(101)-s(115) and s(115)-t(116) are tied for first place (9 occurrences each).
- Source: widest (3) + newest (6) = 9. The logic checks out.
Step 4: Selecting the best merge (Find Best Merge)
From the statistics above we pick the most frequent pair. If several pairs are tied, the tie is normally broken by lexicographical order (or by simply taking the first one). To keep the demonstration simple we just use max here. Once we have the best pair (101, 115), we create a new ID (say 256) and replace every adjacent 101, 115 inside the words with 256.
- Code implementation
def _merge_encoding(encoding: List[int], pair: Tuple[int, int], new_id: int) -> List[int]:
new_encoding = []
i = 0
while i < len(encoding):
if i < len(encoding) - 1 and encoding[i] == pair[0] and encoding[i+1] == pair[1]:
new_encoding.append(new_id)
i += 2
else:
new_encoding.append(encoding[i])
i += 1
    return new_encoding
- Test and verification
# Find the most frequent pair
# Note: the final code must handle the tie between (101, 115) and (115, 116)
merge_pair = max(pair_counts, key=pair_counts.get)  # most frequent pair; this first version ignores the lexicographical tie-break
print("Max pair:", merge_pair)
max_count = pair_counts[merge_pair]
print("Max count:", max_count)
Here I manually assign 256 as the new token ID, just to verify the merge logic; in the real training this value is determined by the earlier steps.
for word in word_encodings:
new_encoding = _merge_encoding(word_encodings[word], merge_pair, 256)
word_encodings[word] = new_encoding
print("After once merge word_encodings:",word_encodings)
👀 Observed output:
- Output:
Max pair: (101, 115) # with the lexicographical tie-break ignored, 'es' gets merged automatically; under the tie-break rule it should be 'st'
Max count: 9
After one merge, word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 256, 116], ' newest': [32, 110, 101, 119, 256, 116]}
- Observations
Before the merge, 'newest': [32, 110, 101, 119, 101, 115, 116]
After the merge, 'newest': [32, 110, 101, 119, 256, 116]
✅ As expected:
- The original ..., 101, 115, ... (e, s) successfully became ..., 256, ... (es).
- The list got shorter by one element; the compression works!
Step 5: Putting it together: recording the vocabulary and merge rules
Finally, we record what happened in this step so it can later be used for tokenizer inference.
# Initialize the base vocabulary (0-255)
vocabulary = {i: bytes([i]) for i in range(256)}
# Record this merge
merges = []
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
print("merge bytes:", merge_bytes)
vocabulary[256] = merge_bytes
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
print("merges:", merges)
👀 Observed output:
merge bytes: b'es'
merges: [(b'e', b's')]
Summary
Walking through these single steps, we can clearly see that BPE is essentially a "count -> replace -> repeat" process:
- Data state: [e, s, t] becomes [es, t].
- Vocabulary state: grows from the 256 base bytes to also include 'es'.
- Loop: in the next iteration, _count_pairs will pick up new combinations such as (256, 116) (i.e. es+t).
Next, we only need to put the steps above inside a while loop to get a complete BPE trainer!
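Before moving on to the optimizations, a quick sanity check of that last bullet, reusing the helper and the post-merge state from above:
```python
# Re-count pairs after the first merge: (256, 116), i.e. 'es' + 't', now shows up.
pair_counts = _count_pairs(word_counts, word_encodings)
print(pair_counts[(256, 116)])  # 9  (3 from ' widest' + 6 from ' newest')
```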
Optimizations
- _count_pairs optimization
  - When frequencies are tied, take the lexicographically larger pair.
  - Cache the byte strings of each pair to improve efficiency.
- Code implementation
def _count_pairs_v2(
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]],
pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
vocabulary: Dict[int, bytes],
) -> Dict[Tuple[int, int], int]:
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
for i in range(len(encoding) - 1):
pair = (encoding[i], encoding[i+1])
pair_counts[pair] += count
# cache the pair's byte strings to save lookup time
if pair not in pair_strings:
pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
    return pair_counts
- Test and verification:
test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest"
special_tokens = ["<|endoftext|>"]
vocabulary = {i: bytes([i]) for i in range(256)}
size = 256
new_token = size
word_counts = _pretokenize(test_text, special_tokens)
size += 1
new_token = size
print(word_counts)
word_encodings = {}
for word in word_counts:
word_encodings[word] = list(word.encode("utf-8"))
print("word_encodings:",word_encodings)pair_counts = _count_pairs(word_counts, word_encodings)
pair_strings = {}
pair_counts = _count_pairs_v2(word_counts, word_encodings, pair_strings, vocabulary)
print("pair_counts:",pair_counts)merge_pair = max(pair_counts, key=pair_counts.get) # 找出频率最大的pair,初版忽略相等需要字典序的情况
merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings[x])) # 频率相同时取字典序
print("Max pair:",merge_pair)
max_count = pair_counts[merge_pair]
print("Max count:",max_count)for word in word_encodings:
new_encoding = _merge_encoding(word_encodings[word], merge_pair, new_token)
word_encodings[word] = new_encoding
print("After once merge word_encodings:",word_encodings)vocabulary = {i: bytes([i]) for i in range(256)}
merges = []
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
print("merge bytes:", merge_bytes)
vocabulary[size] = merge_bytes
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
print("merges:", merges) -
输出
{'low': 1, ' low': 4, ' lower': 2, ' widest': 3, ' newest': 6}
word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 101, 115, 116], ' newest': [32, 110, 101, 119, 101, 115, 116]}
pair_counts: defaultdict(<class 'int'>, {(108, 111): 7, (111, 119): 7, (32, 108): 6, (119, 101): 8, (101, 114): 2, (32, 119): 3, (119, 105): 3, (105, 100): 3, (100, 101): 3, (101, 115): 9, (115, 116): 9, (32, 110): 6, (110, 101): 6, (101, 119): 6})
Max pair: (115, 116)
Max count: 9
After one merge, word_encodings: {'low': [108, 111, 119], ' low': [32, 108, 111, 119], ' lower': [32, 108, 111, 119, 101, 114], ' widest': [32, 119, 105, 100, 101, 257], ' newest': [32, 110, 101, 119, 101, 257]}
merge bytes: b'st'
merges: [(b's', b't')]
👀 Observed output:
- The merged pair is now (115, 116).
- The merge is (b's', b't'), exactly as analyzed earlier.
- new_token_id is 257: the initial vocabulary holds 256 byte tokens plus one special token, so newly merged tokens start at 257.
There is also a small trick for building the pairs: use zip(), as in pairs = zip(encoding, encoding[1:]) (see the sketch below).
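A minimal sketch of that zip() variant of the pair-counting loop; it is equivalent to the index-based version above:
```python
from collections import defaultdict

def _count_pairs_zip(word_counts, word_encodings):
    pair_counts = defaultdict(int)
    for word, count in word_counts.items():
        encoding = word_encodings[word]
        # zip pairs each element with its successor: (e[0], e[1]), (e[1], e[2]), ...
        for pair in zip(encoding, encoding[1:]):
            pair_counts[pair] += count
    return pair_counts
```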
- _merge_encoding optimization
Before writing an encoding back, check whether the merge actually changed it, and only update it if it did.
- _merge_encoding
def _merge_encoding(
    encoding: List[int],
    merge_pair: Tuple[int, int],
    new_token_id: int
) -> Tuple[List[int], bool]:
    """
    Return the new word encoding and whether it changed.
    """
    new_encoding = []
    has_changed = False
    i = 0
    while i < len(encoding):
        if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
            new_encoding.append(new_token_id)
            has_changed = True
            i += 2
        else:
            new_encoding.append(encoding[i])
            i += 1
    return new_encoding, has_changed
- Training function: train
def train(
    self,
    text: str,
    vocab_size: int,
    special_tokens: List[str]
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    # Implementation similar to V2, with the has_changed optimization
    vocabulary = {i: bytes([i]) for i in range(256)}
    size = 256
    for token in special_tokens:
        vocabulary[size] = token.encode('utf-8')
        size += 1
    word_counts = self._pretokenize(text, special_tokens)
    word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}
    merges = []
    pair_strings = {}
    num_merges = vocab_size - size
    for merge_idx in range(num_merges):
        pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary)
        if not pair_counts:
            break
        merge_pair = max(
            pair_counts,
            key=lambda x: (pair_counts[x], pair_strings[x])
        )
        max_count = pair_counts[merge_pair]
        merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
        new_id = size
        vocabulary[new_id] = merge_bytes  # register the new token in the vocabulary
        size += 1
        # optimization: only write back encodings that actually changed
        updated_count = 0
        for word in word_encodings:
            new_encoding, has_changed = self._merge_encoding(word_encodings[word], merge_pair, new_id)
            if has_changed:
                word_encodings[word] = new_encoding
                updated_count += 1
        if merge_idx < 5:
            print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_id}, updated {updated_count} words")
        merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
    return vocabulary, merges
The complete implementation is in train_bpe.py.
The next step is to read the training text into memory from disk; for large corpora we must also assume the data cannot all fit into memory at once, so it needs to be loaded in chunks. The handout also mentions parallelization.
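A rough sketch of what chunked loading could look like; this is my own assumption of one possible approach (the chunk size and cutting at the special token are illustrative, not the handout's prescribed method):
```python
def iter_text_chunks(path: str, chunk_bytes: int = 1 << 20, delimiter: str = "<|endoftext|>"):
    """Yield pieces of a large file, cutting only at special-token boundaries."""
    buffer = ""
    with open(path, "r", encoding="utf-8") as f:
        while True:
            block = f.read(chunk_bytes)
            if not block:
                break
            buffer += block
            cut = buffer.rfind(delimiter)
            if cut != -1:
                # emit everything up to (and including) the last complete document
                yield buffer[: cut + len(delimiter)]
                buffer = buffer[cut + len(delimiter):]
    if buffer:
        yield buffer
```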
Appendix: complete code of the simple version, train_bpe.py:
"""
BPE training algorithm
"""
import regex as re
from collections import Counter, defaultdict
from typing import List, Tuple, Dict
class BPETrainer_V1:
"""
V1: simple version, for testing only
"""
def __init__(self):
self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def train(
self,
text: str,
vocab_size: int,
special_tokens: List[str]
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
"""
Args:
text: str, the text used to train the BPE tokenizer
vocab_size: int, the size of the vocabulary
special_tokens: List[str], special tokens added to the vocabulary but never merged or trained on
Returns:
vocabulary: Dict[int, bytes], the vocabulary of the BPE
merges: List[Tuple[bytes, bytes]], the merge rules of the BPE
Process:
1. Initialize the vocabulary
2. Add special tokens to the vocabulary
3. Pre-tokenize the text
4. Initialize word encodings
5. BPE training loop
"""
# 1. Initialize the vocabulary with 256 bytes
vocabulary = {i: bytes([i]) for i in range(256)}
size = 256
# 2. Add special tokens to the vocabulary
for token in special_tokens:
vocabulary[size] = token.encode('utf-8')
size += 1
# 3. Pre-tokenize the text by word level and count the frequency of each word
print("📝 Pre-tokenizing...")
word_counts = self._pretokenize(text, special_tokens)
print(f"Found {len(word_counts)} different word in text after remove special tokens")
# 4. Initialize word encodings
print("Initializing word encodings...")
word_encodings = {}
for word in word_counts:
word_encodings[word] = list(word.encode('utf-8'))
# 5. BPE training loop
print(f"Starting BPE training...")
num_merges = vocab_size - size
print(f"Need {num_merges} loop to reach the largest vocabulary size {vocab_size}")
merges = []
for merge_idx in range(num_merges):
# a. Count the frequency of each pair of adjacent tokens
pair_counts = self._count_pairs(word_counts, word_encodings)
if not pair_counts:
print("No more pair to be merges, quit")
break
# b. Find the max count pair
merge_pair = max(pair_counts, key=pair_counts.get)
max_count = pair_counts[merge_pair]
# c. Create new token
new_token_id = size
vocabulary[new_token_id] = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
size += 1
# d. Update word encodings
for word in word_encodings:
new_encoding = self._merge_encoding(word_encodings[word], merge_pair, new_token_id)
word_encodings[word] = new_encoding
# e. Record the merge
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
# f. Print the merge process
if merge_idx < 5 or merge_idx % 10 == 0:
print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_token_id}")
return vocabulary, merges
def _pretokenize(self, text: str, special_tokens: List[str]) -> Dict[str, int]:
# Remove special tokens from text
for token in special_tokens:
text = text.replace(token, '')
# Split text into words using regex
words = re.findall(self.pattern, text)
# Count the frequency of each word
return dict(Counter(words))
def _count_pairs(
self,
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]]
) -> Dict[Tuple[int, int], int]:
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
for i in range(len(encoding) - 1):
pair = (encoding[i], encoding[i+1])
pair_counts[pair] += count
return pair_counts
def _merge_encoding(
self,
encoding: List[int],
merge_pair: Tuple[int, int],
new_token_id: int
) -> List[int]:
new_encoding = []
p0, p1 = merge_pair
i = 0
while i < len(encoding):
if i < len(encoding) -1 and encoding[i] == p0 and encoding[i+1] == p1:
new_encoding.append(new_token_id)
i += 2
else:
new_encoding.append(encoding[i])
i += 1
return new_encoding
class BPETrainer_V2:
"""
Added:
- In case of tie in pair frequency, choose the lexicographically larger pair
- pair strings buffer
- max key function optimization
"""
def __init__(self):
self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def train(
self,
text: str,
vocab_size: int,
special_tokens: List[str]
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
# Implementation similar to V1 with optimizations and tie-breaking logic
vocabulary = {i: bytes([i]) for i in range(256)}
size = 256
for token in special_tokens:
vocabulary[size] = token.encode('utf-8')
size += 1
print("📝 Pre-tokenizing...")
word_counts = self._pretokenize(text, special_tokens)
print(f"Found {len(word_counts)} different word in text after remove special tokens")
print("Initializing word encodings...")
word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}
print(f"Starting BPE training...")
merges = []
pair_strings = {} # new buffer for pair strings
num_merges = vocab_size - size
for merge_idx in range(num_merges):
pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary)
if not pair_counts:
break
# break frequency ties by lexicographical order
merge_pair = max(
pair_counts,
key=lambda x: (pair_counts[x], pair_strings[x])
)
max_count = pair_counts[merge_pair]
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
new_token_id = size
vocabulary[new_token_id] = merge_bytes
size += 1
if merge_idx < 5 or merge_idx % 10 == 0:
print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_token_id}")
for word in word_encodings:
new_encoding = self._merge_encoding(word_encodings[word], merge_pair, new_token_id)
word_encodings[word] = new_encoding
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
return vocabulary, merges
def _pretokenize(self, text: str, special_tokens: List[str]) -> Dict[str, int]:
for token in special_tokens:
text = text.replace(token, '')
words = re.findall(self.pattern, text)
return dict(Counter(words))
def _count_pairs(
self,
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]],
pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
vocabulary: Dict[int, bytes],
) -> Dict[Tuple[int, int], int]:
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
for i in range(len(encoding) - 1):
pair = (encoding[i], encoding[i+1])
pair_counts[pair] += count
# cache the pair's byte strings to save lookup time
if pair not in pair_strings:
pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
return pair_counts
def _merge_encoding(
self,
encoding: List[int],
merge_pair: Tuple[int, int],
new_token_id: int
) -> List[int]:
new_encoding = []
i = 0
while i < len(encoding):
if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
new_encoding.append(new_token_id)
i += 2
else:
new_encoding.append(encoding[i])
i += 1
return new_encoding
# ============================================
# Version 3: add the has_changed optimization
# ============================================
class BPETrainer_V3(BPETrainer_V2):  # inherit _pretokenize and _count_pairs from V2
def __init__(self):
self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def train(
self,
text: str,
vocab_size: int,
special_tokens: List[str]
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
# Implementation similar to V2, with the has_changed optimization
vocabulary = {i: bytes([i]) for i in range(256)}
size = 256
for token in special_tokens:
vocabulary[size] = token.encode('utf-8')
size += 1
word_counts = self._pretokenize(text, special_tokens)
word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}
merges = []
pair_strings = {}
num_merges = vocab_size - size
for merge_idx in range(num_merges):
pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary)
if not pair_counts:
break
merge_pair = max(
pair_counts,
key=lambda x: (pair_counts[x], pair_strings[x])
)
max_count = pair_counts[merge_pair]
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
new_id = size
vocabulary[new_id] = merge_bytes  # register the new token so later vocabulary lookups work
size += 1
# optimization
updated_count = 0
for word in word_encodings:
new_encoding, has_changed = self._merge_encoding(word_encodings[word], merge_pair, new_id)
if has_changed:
word_encodings[word] = new_encoding
updated_count += 1
if merge_idx < 5:
print(f"Merge #{merge_idx+1}: {merge_pair} (freq: {max_count}) -> ID {new_id}, updated {updated_count} words")
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
return vocabulary, merges
def _merge_encoding(
self,
encoding: List[int],
merge_pair: Tuple[int, int],
new_token_id: int
) -> Tuple[List[int], bool]:
"""
return word encoding and has changed
"""
new_encoding = []
has_changed = False
i = 0
while i < len(encoding):
if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
new_encoding.append(new_token_id)
has_changed = True
i += 2
else:
new_encoding.append(encoding[i])
i += 1
return new_encoding, has_changed
# ============================================
# Test Code
# ============================================
def test_versions():
test_text = "low low low low low<|endoftext|> lower lower<|endoftext|> widest widest widest<|endoftext|> newest newest newest newest newest newest<|endoftext|>"
print("="*60)
print("Test data:", test_text[:50], "...")
print("="*60)
print()
# Test V1
print("【Version 1】Sample BPE trainer sample implementation")
print("-"*60)
trainer_v1 = BPETrainer_V1()
vocab_v1, merges_v1 = trainer_v1.train(test_text, 262, [])
print(f"✅ Training complete,vocab_size: {len(vocab_v1)}")
print()
# Test V2
print("【Version 2】add lexicographical order")
print("-"*60)
trainer_v2 = BPETrainer_V2()
vocab_v2, merges_v2 = trainer_v2.train(test_text, 262, [])
print(f"✅ Training complete,vocab_size: {len(vocab_v2)}")
print()
if __name__ == "__main__":
test_versions()