BPE 词表构建与编解码说明
一、BPE 背景
BPE(Byte Pair Encoding,字节对编码) 是一种数据压缩与分词算法,后被广泛用于 NLP 的词表构建。其核心思想是:从字符(或字节)级别出发,反复将出现频率最高的相邻二元组合并成一个新符号,直到词表大小达到设定值。
- 起源:早期用于文本压缩;在 NLP 中由 Sennrich 等人引入,用于机器翻译等任务的子词分词。
- 特点:词表在 256(单字节)基础上扩展,能平衡字符级与词级表示,对未登录词、多语言更友好。
- 与 LLM 的关系:GPT、LLaMA 等大模型都使用 BPE 或类似子词算法(如 SentencePiece),将文本切分为 token 序列再输入模型。
二、任务意义
本任务实现一个简化版 BPE,完成三件事:
- 训练/构建词表 :在语料上统计相邻字节对频次,按频次从高到低依次合并,得到合并规则表
merges与 id→字节 映射vocab。 - 编码 :给定字符串与
merges,将字符串转为 UTF-8 字节后,按训练时的合并顺序不断合并,得到 token id 列表。 - 解码 :给定 token id 列表与
vocab,将每个 id 还原为字节并拼接,再按 UTF-8 解码为字符串。
意义在于理解:子词词表如何从语料中学习得到 ,以及编码/解码如何与合并顺序一致,为后续学习 Transformer、Embedding 等打基础。
三、如何解决:整体思路与代码作用
| 步骤 | 做什么 | 代码/数据结构 |
|---|---|---|
| 语料准备 | 读入多文件,拼成一个大字符串 | corpus,build_vocab(corpus) 的输入 |
| 统计 | 对当前 id 序列统计所有相邻二元组出现次数 | get_stats(ids) → { (id_i, id_{i+1}): count } |
| 选 pair | 训练时选出现最多 的 pair 合并;编码时选在 merges 里编号最小的 pair 合并 | 训练:max(stats, key=stats.get);编码:min(stats, key=merges.get(..., inf)) |
| 合并 | 把序列里所有该 pair 替换成一个新 id | merge(ids, pair, idx) |
| 记录规则 | 训练时记 (p0,p1)→新 id;并维护 id→字节 | merges[pair]=idx,vocab[idx]=vocab[p0]+vocab[p1] |
| 编码 | 字符串→字节 list,再按 merges 顺序反复合并 | encode(text, merges) → list[int] |
| 解码 | id 列表→按 vocab 取字节→拼接→UTF-8 解码 | decode(ids, vocab) → str |
核心约束:编码时合并顺序必须与训练时一致,因此用「在 merges 中的编号」表示顺序,编码时每次选编号最小的 pair 合并(即最先被学到的规则)。
四、代码思路与模块作用
-
get_stats(ids)
统计
ids中所有相邻二元组出现次数,返回dict[(int,int), int]。训练时用来找「当前频次最高的 pair」;编码时用来找「当前序列里存在、且出现在 merges 里的 pair」。 -
merge(ids, pair, idx)
在
ids中把所有连续的(pair[0], pair[1])替换成一个idx,返回新列表。训练和编码都会反复调用。 -
build_vocab(text)
- 把 text 转为 UTF-8 字节再转为 0--255 的 id 列表。
- 循环若干轮(由
vocab_size - 256决定):每轮get_stats→ 选频次最高的 pair →merge→ 把该 pair→新 id 记入merges,新 id 从 256 递增。 - 用 merges 构建
vocab:0--255 为单字节;256 及以上为对应两个子 token 的字节拼接。 - 返回
merges, vocab,供编码和解码使用。
-
encode(text, merges)
把 text 转为字节 id 列表后,只要长度≥2 就:
get_stats→ 在 stats 的键中选「merges 中编号最小」的 pair(min(..., key=merges.get(..., inf)))→ 若在 merges 中则merge,否则退出。保证合并顺序与训练一致。 -
decode(ids, vocab)
按 ids 顺序用
vocab[id]取字节并拼接成一条 bytes,再decode("utf-8", errors="replace")得到字符串。
五、代码讲解(按执行顺序)
-
主流程 :读目录下所有文件 → 拼成
corpus→build_vocab(corpus)得到merges, vocabs→ 对示例字符串encode再decode验证无损。 -
get_stats :
zip(ids, ids[1:])得到所有相邻对,对每个 pair 计数;返回的 key 是 tuple (int,int),value 是出现次数。 -
merge:顺序扫描 ids,若当前与下一项等于 pair 则压入 idx 并跳过两项,否则压入当前项并跳过一项。
-
build_vocab 中的关键:
pair = max(stats, key=stats.get):训练时选频次最高的 pair。merges[pair] = idx:记录 (p0,p1)→新 id,新 id 从 256 起递增,即「合并顺序」。vocab[idx] = vocab[p0] + vocab[p1]:新 token 的字节 = 两子 token 字节拼接,用于解码。
-
encode 中的关键:
pair = min(stats, key=lambda p: merges.get(p, float("inf"))):在当前出现的 pair 里,选在 merges 中编号最小的(即最先被学到的),保证与训练顺序一致;不在 merges 的 pair 用 inf 避免被选到。- 若
pair not in merges则 break,否则按该规则做一次 merge,循环直到无法再合并。
-
decode :
b"".join(vocab[idx] for idx in ids)拼接字节,再 UTF-8 解码。
五、各代码块输出样式与数据示例
以下用具体输入/输出说明每个步骤的数据形式(示例中数字与中文仅为说明,实际以运行结果为准)。
1. get_stats(ids)
输入:id 列表(整数序列)。
输出 :字典,键为相邻二元组 (int, int),值为出现次数。
# 输入
ids = [97, 98, 98, 97, 98]
# 输出(样式)
get_stats(ids)
# => {(97, 98): 2, (98, 98): 1, (98, 97): 1}
2. merge(ids, pair, idx)
输入 :ids 列表、要合并的 pair 元组、新 token 的 idx。
输出:新 id 列表(所有该 pair 被替换为 idx)。
# 输入
ids = [97, 98, 98, 97, 98]
pair = (97, 98)
idx = 256
# 输出(样式)
merge(ids, pair, idx)
# => [256, 98, 97, 98]
3. build_vocab(text) 的 merges / vocab
输入 :语料字符串 text。
输出 :merges 与 vocab。
merges 输出样式 :键为 (p0, p1),值为新 id(从 256 递增)。
# merges 示例(前几条)
merges = {
(228, 184): 256,
(184, 187): 257,
(230, 136): 258,
...
}
vocab 输出样式 :键为 id(0~255 为单字节,256 起为合并得到的 id),值为对应字节串 bytes。
# vocab 示例(片段)
vocab = {
0: b'\x00',
1: b'\x01',
...
97: b'a',
98: b'b',
...
256: b'\xe4\xb8\xad', # 例如「中」的 UTF-8 两字节合并后
257: b'...',
...
}
4. encode(text, merges)
输入 :字符串 text,合并表 merges。
输出 :token id 列表 list[int]。
# 输入
text = "亚索(托儿索)"
merges = { ... } # 由 build_vocab 得到
# 输出(样式)
encode(text, merges)
# => [256, 258, 260, 261, 259, 262, 263]
# 实际长度与数值依词表而定,此处仅为示例
5. decode(ids, vocab)
输入 :token id 列表 ids,词表 vocab。
输出:解码后的字符串。
# 输入
ids = [256, 258, 260, 261, 259, 262, 263]
vocab = { ... } # 由 build_vocab 得到
# 输出(样式)
decode(ids, vocab)
# => "亚索(托儿索)"
6. 主流程:语料 → 编解码
corpus 片段(样式):
# 读入目录下所有文件拼接后,corpus 为一大段字符串,例如:
corpus = "英雄名:亚索\n背景故事:亚索是一名来自艾欧尼亚的剑客...\n技能1:斩钢闪..."
构建词表后:
merges, vocabs = build_vocab(corpus)
# merges: 244 条 (pair -> idx),idx 从 256 到 499
# vocabs: 500 个 id -> bytes
编码结果(样式):
string = "亚索(托儿索)"
encode_ids = encode(string, merges)
# => [256, 258, 260, 261, 259, 262, 263] # 示例,实际由词表决定
解码结果(样式):
decode_string = decode(encode_ids, vocabs)
# => "亚索(托儿索)"
六、总结与思路总结
- BPE 做了什么:从字节序列出发,按「频次最高的相邻对优先合并」的规则,得到合并表与扩展词表,从而在固定词表大小下得到有意义的子词单元。
- 本实现的思路 :
1)训练阶段:语料→字节 id→多轮「统计→选最高频 pair→合并→记录 (pair→新 id)、构建 id→字节」→得到 merges 与 vocab。
2)编码阶段:字符串→字节 id→按 merges 中编号从小到大 的顺序反复合并→得到 token id 列表。
3)解码阶段:id 列表→按 vocab 还原字节→拼接→UTF-8 解码。 - 关键点 :编码时必须按「训练时的合并顺序」进行,因此用
min(stats, key=merges.get(..., inf))每次只做「最先被学到」的合并。
七、语料
英雄名:亚索(托儿索)
背景故事:亚索是一名来自艾欧尼亚的剑客,也是同门中唯一能掌握传奇风之剑术的弟子。当他被指控谋杀长老时,他被迫挥剑自保,杀死自己的兄长以证清白。长老之死真相大白后,亚索踏上了赎罪之路,在故乡的土地上流浪,只有疾风指引着他的剑刃。
<br><br>亚索的剑术迅捷如风,他能够斩钢闪突刺敌人,第三次施放时更会释放一道击飞敌人的旋风。风之障壁能格挡一切飞行道具,踏前斩则让他穿梭于敌阵之中。当敌人被击飞至空中,亚索可施放狂风绝息斩,瞬移至目标身旁给予致命一击。
技能1:斩钢闪, 技能描述:向前出剑,对直线上的敌人造成物理伤害。若在突进过程中施放,斩钢闪会呈环形出剑。在短时间内连续命中两次后,第三次斩钢闪会吹出一道击飞敌人的旋风。
技能2:风之障壁, 技能描述:形成一堵风墙,持续数秒。风墙会阻挡敌方的所有飞行道具(包括普攻弹道、技能弹道等)。
技能3:踏前斩, 技能描述:向目标敌人突进,造成魔法伤害。每次施放都会在短时间内提升下次突进的基础伤害。同一目标在短时间内无法被重复突进。
技能4:狂风绝息斩, 技能描述:闪烁至一名被击飞的敌方英雄身旁,造成物理伤害并使范围内所有被击飞的敌人在空中多停留一段时间。获得满额穿甲加成,持续数秒。
八、完整代码
python
import os
def get_stats(ids):
counts = {}
for pair in zip(ids, ids[1:]):
counts[pair] = counts.get(pair, 0) + 1
return counts
def merge(ids, pair, idx):
newids = []
i = 0
while i < len(ids):
if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
newids.append(idx)
i += 2
else:
newids.append(ids[i])
i += 1
return newids
def build_vocab(text):
vocab_size = 500
num_merges = vocab_size - 256
tokens = text.encode("utf-8")
tokens = list(map(int, tokens))
ids = list(tokens)
merges = {}
for i in range(num_merges):
stats = get_stats(ids)
pair = max(stats, key=stats.get)
idx = 256 + i
ids = merge(ids, pair, idx)
merges[pair] = idx
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
return merges, vocab
def encode(text, merges):
tokens = list(text.encode("utf-8"))
while len(tokens) >= 2:
stats = get_stats(tokens)
pair = min(stats, key=lambda p: merges.get(p, float("inf")))
if pair not in merges:
break
idx = merges[pair]
tokens = merge(tokens, pair, idx)
return tokens
def decode(ids, vocab):
tokens = b"".join(vocab[idx] for idx in ids)
text = tokens.decode("utf-8", errors="replace")
return text
if __name__ == "__main__":
dir_path = r"/Users/tripleh/Heroes"
corpus = ""
for path in os.listdir(dir_path):
path = os.path.join(dir_path, path)
with open(path, encoding="utf-8", errors="replace") as f:
text = f.read()
corpus += text + '\n'
merges, vocabs = build_vocab(corpus)
string = "亚索(托儿索)"
encode_ids = encode(string, merges)
decode_string = decode(encode_ids, vocabs)