我们直接按照目前主流 LLM(如 GPT-2/3/4, Llama)的处理方式来写:基于 Unicode 编码的 BPE。
我们将完全使用 Python 原生代码,不依赖任何第三方库。核心逻辑其实非常短,我们分三步走:数据准备 -> 统计与合并 -> 封装训练。
第一步:数据准备与预处理
在计算机眼中,文本本质上是一串数字。标准的 BPE(Byte Pair Encoding)实际上是对字节(Bytes)或 Unicode 码点(Code Points)进行操作。
为了演示方便,我们不直接操作二进制字节,而是操作 Unicode 整数。
代码逻辑:
-
定义一段训练文本。
-
将其转换为 Unicode 码点列表(Integers)。
python
# 1. 这是一个很小的语料库,用于演示
text = "hug pug pun bun hugs"
# 2. 将文本转换为 Unicode 码点列表 (Pre-tokenization 的最原始形态)
#ord('h') -> 104, ord('u') -> 117 ...
tokens = [ord(c) for c in text]
print(f"原始文本: {text}")
print(f"初始 Tokens (Unicode): {tokens}")
print(f"当前序列长度: {len(tokens)}")
-
ord(c)会把字符c转成它对应的 Unicode 整数值:-
ord('h')→ 104 -
ord('u')→ 117 -
ord('g')→ 103 -
空格
' '→ 32
-
-
列表推导式会对文本中每个字符都做此操作,因此输出是一个整数序列。
这就是最原始的"token":一个字符对应一个整数。
第二步:核心功能函数
BPE 需要两个核心动作:
-
get_stats: 找出当前序列中,哪两个相邻的 Token 出现频率最高。 -
merge: 将序列中所有指定的"对子"替换成一个新的 ID。
我们一边写函数,一边看它的作用。
2.1 统计频率 (get_stats)
我们需要遍历整个 token 列表,查看相邻的两个数字 (ids[i], ids[i+1]),并用字典统计它们出现的次数。
python
def get_stats(ids):
counts = {}
# zip(ids, ids[1:]) 是个很 Pythonic 的写法,它可以错位取出相邻对
# 例如 ids=[1, 2, 3], zip后得到 [(1, 2), (2, 3)]
for pair in zip(ids, ids[1:]):
counts[pair] = counts.get(pair, 0) + 1
return counts
# 测试一下
stats = get_stats(tokens)
# 找出出现次数最多的对子
most_common_pair = max(stats, key=stats.get)
print(f"\n统计结果: {stats}")
print(f"频率最高的字符对: {most_common_pair} (出现 {stats[most_common_pair]} 次)")
这是 BPE 核心算法的第一步 :
统计序列中所有相邻 token 对(pair)的出现频率。
例如在原序列:
python
['h','u','g',' ','p','u','g', ...]
我们要统计相邻的 pair,比如:
python
(h,u), (u,g), (g,' '), (' ',p), (p,u), ...
出现次数越多的 pair,越可能被 BPE"合并"成一个新的 token。
-
ids是一个整数列表(字符转 Unicode 后的 tokens)。 -
counts字典用于记录每个相邻 token 对的出现次数。
zip(ids, ids[1:]) 怎么工作?
python
ids = [1, 2, 3, 4]
ids[1:] = [2, 3, 4]
zip 后得到:
python
[(1,2), (2,3), (3,4)]
| ids (原列表) | ids[1:] (错位列表) | zip 组合出的 pair | 含义 |
|---|---|---|---|
| 104 | 117 | (104, 117) |
第1、2个字符 |
| 117 | 103 | (117, 103) |
第2、3个字符 |
| 103 | (无) | (停止) | 结束 |
循环内部逻辑:
-
第一次循环拿到
(104, 117):字典里记一笔{(104, 117): 1}。 -
第二次循环拿到
(117, 103):字典里记一笔{(104, 117): 1, (117, 103): 1}。
这样我们就知道哪两个数字挨在一起了。
在循环中:
-
每遇到一个 pair(例如
(104,117)对应h,u),就在字典中计数 +1。 -
用
counts.get(pair, 0)是防止键不存在时返回默认 0。
2.2 合并操作 (merge)
python
def merge(ids, pair, idx):
"""
ids: 当前的 token 列表
pair: 要合并的对子,例如 (104, 117)
idx: 新的 token ID,例如 256
"""
newids = []
i = 0
while i < len(ids):
# 如果不是最后一位,并且当前位置和下一位置匹配我们要找的 pair
if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
newids.append(idx) # 替换为新 ID
i += 2 # 跳过这两个被合并的元素
else:
newids.append(ids[i]) # 照常添加
i += 1
return newids
# 测试一下:假设我们将刚才找到的最高频对 (117, 103) -> ('u', 'g') 合并为 ID 256
# 注意:上面的 most_common_pair 可能是 (117, 103) 即 'u','g',出现 3 次
new_tokens = merge(tokens, most_common_pair, 256)
print(f"合并后的序列: {new_tokens}")
print(f"合并后长度: {len(new_tokens)} (变短了)")
这是 BPE 的"压缩"步骤。假设我们要把 (104, 117) (即 'h', 'u') 合并成一个新的 ID(比如 256)。我们需要生成一个新的列表,扫描旧列表,一旦发现 104 跟着 117,就用 256 替换掉它们。
假设我们要把 (104, 117) 合并成 256。 旧列表 ids = [104, 117, 103]。
第一轮循环 (i=0):
-
检查:
ids[0]是 104 吗?是。ids[1]是 117 吗?是。 -
匹配成功!
-
操作:把
256放进newids。 -
指针移动:
i增加 2(因为 104 和 117 已经被 256 替代了,我们不需要再处理它们)。现在i = 2。
第二轮循环 (i=2):
-
检查:
ids[2]是 103。跟(104, 117)没关系。 -
匹配失败。
-
操作:把
103原样放进newids。 -
指针移动:
i增加 1。现在i = 3。
第三轮循环 (i=3):
i不再小于len(ids)(也就是 3),循环结束。
结果: newids 变成了 [256, 103]。 (逻辑上就是:[h, u, g] 变成了 [hu, g])。
第三步:训练循环 (The Training Loop)
现在我们将上面的步骤串联起来。我们需要设定一个目标词表大小(vocab_size) 。 假设我们要进行 num_merges 次合并。每次合并都会产生一个新的 Token ID。
python
# --- 参数设置 ---
vocab_size = 260 # 假设我们要把词表扩充到 260 (原 Unicode 基础是 256 以内,这里只做几次合并演示)
num_merges = vocab_size - 256 # 我们需要做 4 次合并
ids = list(tokens) # 复制一份原始数据
# 记录我们将要学习到的"合并规则"
merges = {} # (id1, id2) -> new_id
print("\n=== 开始 BPE 训练循环 ===")
for i in range(num_merges):
# 1. 统计频率
stats = get_stats(ids)
# 2. 找到最高频的对子
# 如果所有对子都只出现1次,或者没有对子了,就提前结束
if not stats:
break
pair = max(stats, key=stats.get)
# 3. 分配新的 Token ID
# 通常新的 ID 是从 256 (字节最大值+1) 开始递增
idx = 256 + i
# 4. 执行合并
ids = merge(ids, pair, idx)
# 5. 记录规则
merges[pair] = idx
print(f"Step {i+1}: 合并 {pair} -> {idx}. 序列长度变为: {len(ids)}")
print("\n=== 训练结束 ===")
print(f"最终压缩后的 Tokens: {ids}")
print(f"学习到的合并规则 (Merges): {merges}")
merges 字典:这就是我们的"训练结果"。
-
它就像一本字典:
{(104, 117): 256, (256, 103): 257, ...} -
以后推理(Inference)的时候,我们就不需要重新统计频率了,直接查这本字典,把文本里的字符对换成 ID 即可。
总结一下代码流
-
变成数字 :
[h, u, g]->[104, 117, 103] -
get_stats:用zip错位技术,发现(104, 117)出现了。 -
决策 :决定把
(104, 117)变成256。 -
merge:用while循环和i指针,扫描整个列表:-
遇到
104, 117-> 写入256,跳过两步。 -
遇到其他 -> 照抄,跳过一步。
-
-
结果 :列表变短了
[256, 103]。