1. 瓶颈在哪里
回顾下 BPETrainer_Simple 中,循环体内部是这样的:
for merge_idx in range(num_merges):
# 1. 极其耗时:每次都要遍历所有单词,重新计算所有 Pair 的频率
pair_counts = self._count_pairs(word_counts, word_encodings, ...)
# 2. 找到频率最高的 pair
merge_pair = max(pair_counts, ...)
# 3. 极其耗时:再次遍历所有单词,进行字符串替换
for word in word_encodings:
self._merge_encoding(...)
可以发现
- 在计算
pair的频率时每次都要遍历所有 word,并找到频率最大的pair; - 对
pairmerge 操作后更新包含该 pair 的word 的 encoding 时又是遍历所有 word。
假设有 10,000 个单词,要进行 1,000 次合并。你的代码执行了 1,000×10,000≈1071,000×10,000≈107 次扫描操作。
2. 如何优化
- 全局 Pair 计数缓存 (
pair_counts):不再每次重新数,而是维护一个全局的计数器。 - 倒排索引 (
pair_to_words):记录"哪个 Pair 出现在了哪些单词里"。当合并一个pair时,只去更新那些包含这个pair的单词,而不是扫描所有单词。
引入**倒排索引(Inverted Index)**的思想是:
- pair_to_words: Dict[Tuple, Set[str]]
- key: 是一个 pair, 如
('e', 's')。 - value 是一个 set, 包含所有包含这个 pair 的word,如
{'newest', 'widest', ...}。
- key: 是一个 pair, 如
2.1 优化后的流程
- 初始化(只做一次) ,扫描所有
word,建立全局的pair_counts和 pair_to_words。 - 循环合并
- 从
pair_counts中直接拿到最大频率的pair,时间复杂度(O(1)或O(K))。 - 从 pair_to_words 找到只受影响的 words(Affected Words)。
- 只对这些受影响的 words 进行更新。
- 减去这些 words 中旧 pair 的计数。
- 在 words 中执行合并。
- 加上这些 words 新产生的 pair 的计数。
这种方式下,如果合并('e', 's')只影响了字典里 1% 的单词,你就只需要处理这 1% 的数据,速度提升是数量级的。
2.2 举例说明
用文档中方的语料
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
在经过 preTokenize 后得到的词频统计如下:
{low: 5, lower: 2, newest: 6, widest: 3}
为了方便演示,我这里假设此时字符已经编码为 ID,但这里为了直观,我直接用字符表示。
第一步:初始化 (Pre-tokenize & Count)
优化代码中的 _pretokenize 和初始 _count_pairs 执行后,内存里的状态如下:
1. word_encodings (当前单词的状态):
low:['l', 'o', 'w']lower:['l', 'o', 'w', 'e', 'r']newest:['n', 'e', 'w', 'e', 's', 't']widest:['w', 'i', 'd', 'e', 's', 't']
2. pair_counts (全局计数器):
('l', 'o'): 5 (from low) + 2 (from lower) = 7('o', 'w'): 5 (from low) + 2 (from lower) = 7('e', 's'): 6 (from newest) + 3 (from widest) = 9('s', 't'): 6 (from newest) + 3 (from widest) = 9('w', 'e'): 2 (from lower) = 2 ... 等等
3. pair_to_words (倒排索引 - 关键优化):
('e', 's')->{'newest', 'widest'}<-- 注意这里,只有这两个词包含 e,s('l', 'o')->{'low', 'lower'}- ...
第二步:开始训练循环
Iteration 1:
- Find Max: 在
pair_counts中找到频率最高的。假设是('e', 's'),计数为 9。 - Merge: 我们决定将
e和s合并为新 Tokenes(假设 ID 为 256)。 - Find Affected Words:
- 查
pair_to_words[('e', 's')]。 - 得到受影响的单词列表:
{'newest', 'widest'}。 - 重点: 我们完全忽略
low和lower,因为它们不包含('e', 's')。
- Update Affected Words (函数
_updated_affected_word_count): 我们以newest(频率 6) 为例,演示如何增量更新:
- 旧状态 :
['n', 'e', 'w', 'e', 's', 't'] - 待合并 Pair :
('e', 's')
(A) 扣除旧 Pair 的计数 (Decrement): 在进行合并前,我们要把 newest 里将会被破坏的 Pair 的计数减掉。
('w', 'e'): 它的右边e要被合并了,这个 Pair 即将消失。pair_counts[('w', 'e')] -= 6。('e', 's'): 这是我们要合并的,计数减掉。pair_counts[('e', 's')] -= 6。('s', 't'): 它的左边s要被合并了,这个 Pair 即将消失。pair_counts[('s', 't')] -= 6。- 同时,从
pair_to_words对应的 Set 中移除newest。
(B) 执行合并 (Merge):
newest变为:['n', 'e', 'w', 'es', 't'](注意:这里假设第一个 e 不受影响,只合并后面的 es)
(C) 增加新 Pair 的计数 (Increment): 新产生了哪些 Pair?
('w', 'es'): 新产生的。pair_counts[('w', 'es')] += 6。并将newest加入pair_to_words[('w', 'es')]。('es', 't'): 新产生的。pair_counts[('es', 't')] += 6。并将newest加入pair_to_words[('es', 't')]。
对 widest (频率 3) 重复上述 A, B, C 步骤。
Iteration 1 结束:
pair_counts已经是最新的了(不需要重新扫描全库)。word_encodings里的newest和widest已经更新。low和lower保持原样,甚至没有被读取过。
总结
之前的 BPETrainer_Sampler 和当前优化版的区别就像是:
- 简单版(老师点名) :每次有人迟到,老师就重新从花名册第一个人点到最后一个人,数一遍谁没来。
- 当前优化版(请假条制度) :老师手里有个总人数。每次有人迟到(合并),只需要处理那几个迟到的人 ,总人数减去请假的人,加上新来的人(如果有),不需要管那些正常坐在座位上的同学。
我写完,让 Gemini 帮我review我的笔记和代码,并告诉它背景,我打算以博客形式分享到网上的,有这么一句,我觉得它说的挺好的
我们将像修房子一样,先打地基(数据结构),再修墙(辅助函数),最后封顶(主循环)。
而且它给的代码加的注释也挺好的。
2.3 代码实现
第一步:引入倒排索引(改造"地基"-来自Gemini)
需要在代码中引入一个新的字典: pair_to_words。
旧逻辑:只有 pair_counts(记分牌)
新逻辑:增加 pair_to_words(索引表),用来记录某个 pair 出现在哪些 words 中。
修改 _count_pairs 函数,不仅返回计数 pair_counts,还需要返回谁拥有这些 pair 的记录表 pair_to_words
实现如下:
# before:just return pair_counts
# now:add pair_to_words
def _count_pairs(
self,
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]],
pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
vocabulary: Dict[int, bytes],
pair_to_words: Dict[Tuple[int, int], set] # <--- new patameter added
) -> Dict[Tuple[int, int], int]:
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
# loop all Pairs
for i in range(len(encoding) - 1):
pair = (encoding[i], encoding[i+1])
# 1. Count (same to last version)
pair_counts[pair] += count
# 2. [added] record inverted index:this pair included in this word
pair_to_words[pair].add(word)
# 3. record the bytes of the pair(same to last version)
if pair not in pair_strings:
pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
return pair_counts
第二步:编写核心"手术刀"函数 (增量更新)
这是最难的一步。以前是"全量扫描",现在我们要写一个函数,专门处理单个单词的更新。
想象这个函数是一个会计,它在修改单词结构时,必须保证账目(pair_counts)平配。 逻辑是:先退款(减旧计数),再合并,最后重新下单(加新计数)。
在类中添加下面新函数:
def _update_word_for_merge(
self,
word: str,
merge_pair: Tuple[int, int],
new_token_id: int,
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]],
pair_counts: Dict[Tuple[int, int], int],
pair_to_words: Dict[Tuple[int, int], set],
pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
vocabulary: Dict[int, bytes]
):
encoding = word_encodings[word]
count = word_counts[word]
# --- 阶段 1: "退款" (Decrement) ---
# 在修改编码前,先把它内部所有旧 Pair 的计数减掉,并从索引中移除
# minus count for these word include the pair
for i in range(len(encoding) - 1):
old_pair = (encoding[i], encoding[i+1])
pair_counts[old_pair] -= count
pair_to_words[old_pair].discard(word) # 从该 pair 的名单里把这个词删掉
# 如果计数归零,可以清理(可选,为了省内存)
if pair_counts[old_pair] <= 0:
del pair_counts[old_pair]
# --- 阶段 2: "合并" (Merge) ---
# merge
new_encoding = []
i = 0
while i < len(encoding):
if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
new_encoding.append(new_token_id)
i += 2
else:
new_encoding.append(encoding[i])
i += 1
# 更新单词的编码
word_encodings[word] = new_encoding
# --- 阶段 3: "重新下单" (Increment) ---
# 遍历新的编码,把新产生的 Pair 计数加上,并加入索引
# add the count of the new pair
for i in range(len(new_encoding) - 1):
new_pair = (new_encoding[i], new_encoding[i+1])
pair_counts[new_pair] += count
pair_to_words[new_pair].add(word) # 把这个词加入新 pair 的名单
# 顺便记录一下 pair 的 string (画图用或调试用)
if new_pair not in pair_strings:
pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])
第三步:重组"主循环" (Train Function)
现在我们要把 train 函数里的旧循环拆掉,换成新的逻辑。
主要改动点:
-
初始化时创建
pair_to_words。 -
循环内不再 调用
_count_pairs。 -
循环内通过
pair_to_words[merge_pair]找到受影响的词。def train(self, input_path, vocab_size, special_tokens, **kwargs): # ... (前面的代码不变:初始化 vocabulary, pre-tokenize, word_encodings) ... # 1. Initialize the vocabulary vocabulary = {i: bytes([i]) for i in range(N_BYTES)} size = N_BYTES for token in special_tokens: vocabulary[size] = token.encode('utf-8') size += 1 # 2. Pre-tokenize word_counts = self._pretokenize(input_path, special_tokens) # 3. Initialize the word encodings word_encodings = {word: list(word.encode('utf-8')) for word in word_counts} print(f"Starting BPE training...") merges = [] # --- 【改动开始】 --- # 准备数据结构 pair_strings = {} pair_to_words = defaultdict(set) # <--- 新增 # 4. 初始全量统计 (只做这一次!) pair_counts = self._count_pairs( word_counts, word_encodings, pair_strings, vocabulary, pair_to_words ) num_merges = vocab_size - size for merge_idx in range(num_merges): if not pair_counts: break # a. 找到频率最高的 pair # 这里的逻辑不变,只是数据源 pair_counts 是实时维护的 # 注意:这里加了 .get 防止 pair_strings 某些极端情况缺key报错 merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, (b'', b'')))) # b. 更新 vocabulary 和 merges new_token_id = size vocabulary[new_token_id] = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]] merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]])) size += 1 # c. 【核心优化】只更新受影响的单词 # 注意:我们要把 set 转成 list,因为我们在遍历过程中会修改 set (discard操作) affected_words = list(pair_to_words[merge_pair]) for word in affected_words: self._update_word_for_merge( word, merge_pair, new_token_id, word_counts, word_encodings, pair_counts, pair_to_words, pair_strings, vocabulary ) # 打印进度 (可选) if merge_idx % 100 == 0: print(f"Merge {merge_idx}/{num_merges}: {merges[-1]}") # --- 【改动结束】 --- return vocabulary, merges
测试验证
修改 addpters.py 文件,执行测试命令 uv run pytest tests/test_train_bpe.py 测试结果
================================================================================ test session starts =================================================================================
platform linux -- Python 3.11.12, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/fdq/cources/cs336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 3 items
tests/test_train_bpe.py::test_train_bpe_speed Starting BPE training (Optimized)...
Merge 0/243: (b' ', b't')
Merge 100/243: (b'i', b'd')
Merge 200/243: (b' ', b'T')
PASSED
tests/test_train_bpe.py::test_train_bpe Starting BPE training (Optimized)...
Merge 0/243: (b' ', b't')
Merge 100/243: (b'i', b'd')
Merge 200/243: (b' ', b'T')
PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens Starting BPE training (Optimized)...
Merge 0/743: (b'h', b'e')
Merge 100/743: (b'a', b'll')
Merge 200/743: (b' c', b'at')
Merge 300/743: (b' ', b'other')
Merge 400/743: (b' say', b's')
Merge 500/743: (b' e', b'ach')
Merge 600/743: (b' h', b'ard')
Merge 700/743: (b'b', b'y')
PASSED
================================================================================= 3 passed in 1.99s =================================================================================
可以看到 三个都通过了。
现在可以在tinystories 数据集上训练了。
附完整代码
python
"""
BPE Trainer Optimized Version:
save time throug record the affected words when merge a pair
"""
import regex as re
from collections import defaultdict, Counter
from typing import Dict, List, Tuple
N_BYTES = 256
CHUNK_SIZE = 1024 * 50
class BPETrainer_Optimized:
def __init__(self):
self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def train(
self,
input_path: str,
vocab_size: int,
special_tokens: List[str],
**kwargs
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
# like simple version, initialize the vocabulary, pre-tokenize
# and initialize the word encodings
vocabulary = {i: bytes([i]) for i in range(N_BYTES)}
size = N_BYTES
for token in special_tokens:
vocabulary[size] = token.encode('utf-8')
size += 1
word_counts = self._pretokenize(input_path, special_tokens)
word_encodings = {word: list(word.encode('utf-8')) for word in word_counts}
print(f"Starting BPE training (Optimized)...")
merges = []
num_merges = vocab_size - size
# Initialize the pair_strings and pair_to_words
# pair_to_words: Dict[Tuple[int, int], Set[str]],
# new data structure store the affected words encodings need to be updates
pair_strings = {}
pair_to_words = defaultdict(set)
# Initialize the all pair_counts
pair_counts = self._count_pairs(word_counts, word_encodings, pair_strings, vocabulary, pair_to_words)
# Training loop
for merge_idx in range(num_merges):
if not pair_counts:
break
# a. find the max count pair to be merged
merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, ('b', 'b'))))
# b. update the vocabulary and merges
new_token_id = size
vocabulary[new_token_id] = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))
size += 1
# c. get the affected words
affected_words = list(pair_to_words[merge_pair])
# update encodings just for the affected words
for word in affected_words:
self._update_word_for_merge(
word, merge_pair, new_token_id,
word_counts, word_encodings, pair_counts,
pair_to_words, pair_strings, vocabulary
)
if merge_idx % 100 == 0:
print(f"Merge {merge_idx}/{num_merges}: {merges[-1]}")
return vocabulary, merges
def _count_pairs(
self,
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]],
pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
vocabulary: Dict[int, bytes],
pair_to_words: Dict[Tuple[int, int], set] # <--- new patameter added
) -> Dict[Tuple[int, int], int]:
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
# loop all Pairs
for i in range(len(encoding) - 1):
pair = (encoding[i], encoding[i+1])
# 1. Count (same to last version)
pair_counts[pair] += count
# 2. [added] record inverted index:this pair included in this word
pair_to_words[pair].add(word)
# 3. record the bytes of the pair(same to last version)
if pair not in pair_strings:
pair_strings[pair] = (vocabulary[pair[0]], vocabulary[pair[1]])
return pair_counts
def _update_word_for_merge(
self,
word: str,
merge_pair: Tuple[int, int],
new_token_id: int,
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]],
pair_counts: Dict[Tuple[int, int], int],
pair_to_words: Dict[Tuple[int, int], set],
pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
vocabulary: Dict[int, bytes]
):
"""
Args:
word: the word to be update its encoding
pair_to_words: record the words that include the pair, using for query the words encodings need to be updated
"""
encoding = word_encodings[word]
count = word_counts[word]
# 1. Decrement old pairs
for i in range(len(encoding) - 1):
old_pair = (encoding[i], encoding[i+1])
pair_count -= count
pair_to_words[old_pair].discard(word)
if pair_counts[old_pair] <= 0:
del pair_counts[old_pair]
# 2. Merge encoding
new_encoding = []
i = 0
while i < len(encoding):
if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
new_encoding.append(new_token_id)
i += 2
else:
new_encoding.append(encoding[i])
i += 1
# 3. Increment new pairs
for i in range(len(new_encoding) - 1):
new_pair = (new_encoding[i], new_encoding[i+1])
pair_counts[new_pair] += count
pair_to_words[new_pair].add(word)
if new_pair not in pair_strings:
pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])
word_encodings[word] = new_encoding
def _update_word_for_merge(
self,
word: str,
merge_pair: Tuple[int, int],
new_token_id: int,
word_counts: Dict[str, int],
word_encodings: Dict[str, List[int]],
pair_counts: Dict[Tuple[int, int], int],
pair_to_words: Dict[Tuple[int, int], set],
pair_strings: Dict[Tuple[int, int], Tuple[bytes, bytes]],
vocabulary: Dict[int, bytes]
):
"""
Args:
word: the word to be update its encoding
pair_to_words: record the words that include the pair, using for query the words encodings need to be updated
"""
encoding = word_encodings[word]
count = word_counts[word]
# minus count for these word include the pair
for i in range(len(encoding) - 1):
old_pair = (encoding[i], encoding[i+1])
pair_counts[old_pair] -= count
pair_to_words[old_pair].discard(word) # delete the word from set
# clear if count is 0(Optional,for memory saving)
if pair_counts[old_pair] <= 0:
del pair_counts[old_pair]
# merge
new_encoding = []
i = 0
while i < len(encoding):
if i < len(encoding) - 1 and (encoding[i], encoding[i+1]) == merge_pair:
new_encoding.append(new_token_id)
i += 2
else:
new_encoding.append(encoding[i])
i += 1
# update the word encoding
word_encodings[word] = new_encoding
# loop the new encoding, add the count of the new pair and add the word
for i in range(len(new_encoding) - 1):
new_pair = (new_encoding[i], new_encoding[i+1])
pair_counts[new_pair] += count
pair_to_words[new_pair].add(word) # add the word to pair list
# record the pair string (for debug or other use)
if new_pair not in pair_strings:
pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])
def _pretokenize(
self,
input_path: str,
special_tokens: List[str]
) -> Dict[str, int]:
word_counts = defaultdict(int)
pattern_compiled = re.compile(self.pattern)
special_pattern = "|".join(re.escape(t) for t in special_tokens)
for chunk in self._chunk_documents_streaming(input_path):
if special_tokens:
blocks = re.split(special_pattern, chunk)
else:
blocks = [chunk]
for block in blocks:
for match in re.finditer(pattern_compiled, block):
word_counts[match.group(0)] += 1
return word_counts
def _chunk_documents_streaming(
self,
input_path: str,
chunk_size: int = CHUNK_SIZE,
special_token: str = "<|endoftext|>"
):
"""
Args:
input_path: str, the path of the input file
special_tokens: str
Returns:
chunks: Generator[str, None, None]
"""
leftover = ""
# special_token = special_tokens[0] if special_tokens else "<|endoftext|>"
with open(input_path, 'r', encoding='utf-8') as f:
while True:
# Read a small block each time,such as 50KB here
block = f.read(chunk_size)
# reading complete,quit the loop
if not block:
break
# add the leftover last block to the current block
block = leftover + block
leftover = ""
# find the split point through the special token
# if not found,the whole block is the last block
# rfind(): return the index of the last occurrence of the substring
last_idx = block.rfind(special_token)
if last_idx == -1:
leftover = block
# if found, yield the block by the special tokens
# before the special token, genetor
# after the special token, to the next block by the leftover
else:
yield block[:last_idx + len(special_token)]
leftover = block[last_idx + len(special_token):]
# deal with the last content if has
if leftover:
yield leftover