BPE编码从零开始实现pytorch

我们直接按照目前主流 LLM(如 GPT-2/3/4, Llama)的处理方式来写:基于 Unicode 编码的 BPE

我们将完全使用 Python 原生代码,不依赖任何第三方库。核心逻辑其实非常短,我们分三步走:数据准备 -> 统计与合并 -> 封装训练

第一步:数据准备与预处理

在计算机眼中,文本本质上是一串数字。标准的 BPE(Byte Pair Encoding)实际上是对字节(Bytes)或 Unicode 码点(Code Points)进行操作。

为了演示方便,我们不直接操作二进制字节,而是操作 Unicode 整数。

代码逻辑:

  1. 定义一段训练文本。

  2. 将其转换为 Unicode 码点列表(Integers)。

python 复制代码
# 1. 这是一个很小的语料库,用于演示
text = "hug pug pun bun hugs"

# 2. 将文本转换为 Unicode 码点列表 (Pre-tokenization 的最原始形态)
#ord('h') -> 104, ord('u') -> 117 ...
tokens = [ord(c) for c in text]

print(f"原始文本: {text}")
print(f"初始 Tokens (Unicode): {tokens}")
print(f"当前序列长度: {len(tokens)}")
  • ord(c) 会把字符 c 转成它对应的 Unicode 整数值:

    • ord('h') → 104

    • ord('u') → 117

    • ord('g') → 103

    • 空格 ' ' → 32

  • 列表推导式会对文本中每个字符都做此操作,因此输出是一个整数序列。

这就是最原始的"token":一个字符对应一个整数

第二步:核心功能函数

BPE 需要两个核心动作:

  1. get_stats: 找出当前序列中,哪两个相邻的 Token 出现频率最高。

  2. merge: 将序列中所有指定的"对子"替换成一个新的 ID。

我们一边写函数,一边看它的作用。

2.1 统计频率 (get_stats)

我们需要遍历整个 token 列表,查看相邻的两个数字 (ids[i], ids[i+1]),并用字典统计它们出现的次数。

python 复制代码
def get_stats(ids):
    counts = {}
    # zip(ids, ids[1:]) 是个很 Pythonic 的写法,它可以错位取出相邻对
    # 例如 ids=[1, 2, 3], zip后得到 [(1, 2), (2, 3)]
    for pair in zip(ids, ids[1:]): 
        counts[pair] = counts.get(pair, 0) + 1
    return counts

# 测试一下
stats = get_stats(tokens)
# 找出出现次数最多的对子
most_common_pair = max(stats, key=stats.get)
print(f"\n统计结果: {stats}")
print(f"频率最高的字符对: {most_common_pair} (出现 {stats[most_common_pair]} 次)")

这是 BPE 核心算法的第一步
统计序列中所有相邻 token 对(pair)的出现频率

例如在原序列:

python 复制代码
['h','u','g',' ','p','u','g', ...]

我们要统计相邻的 pair,比如:

python 复制代码
(h,u), (u,g), (g,' '), (' ',p), (p,u), ...

出现次数越多的 pair,越可能被 BPE"合并"成一个新的 token。

  • ids 是一个整数列表(字符转 Unicode 后的 tokens)。

  • counts 字典用于记录每个相邻 token 对的出现次数。

zip(ids, ids[1:]) 怎么工作?

python 复制代码
ids = [1, 2, 3, 4]
ids[1:] = [2, 3, 4]

zip 后得到:

python 复制代码
[(1,2), (2,3), (3,4)]
ids (原列表) ids[1:] (错位列表) zip 组合出的 pair 含义
104 117 (104, 117) 第1、2个字符
117 103 (117, 103) 第2、3个字符
103 (无) (停止) 结束

循环内部逻辑:

  • 第一次循环拿到 (104, 117):字典里记一笔 {(104, 117): 1}

  • 第二次循环拿到 (117, 103):字典里记一笔 {(104, 117): 1, (117, 103): 1}

这样我们就知道哪两个数字挨在一起了。

在循环中:

  • 每遇到一个 pair(例如 (104,117) 对应 h,u),就在字典中计数 +1。

  • counts.get(pair, 0) 是防止键不存在时返回默认 0。

2.2 合并操作 (merge)
python 复制代码
def merge(ids, pair, idx):
    """
    ids: 当前的 token 列表
    pair: 要合并的对子,例如 (104, 117)
    idx: 新的 token ID,例如 256
    """
    newids = []
    i = 0
    while i < len(ids):
        # 如果不是最后一位,并且当前位置和下一位置匹配我们要找的 pair
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx) # 替换为新 ID
            i += 2             # 跳过这两个被合并的元素
        else:
            newids.append(ids[i]) # 照常添加
            i += 1
    return newids

# 测试一下:假设我们将刚才找到的最高频对 (117, 103) -> ('u', 'g') 合并为 ID 256
# 注意:上面的 most_common_pair 可能是 (117, 103) 即 'u','g',出现 3 次
new_tokens = merge(tokens, most_common_pair, 256)
print(f"合并后的序列: {new_tokens}")
print(f"合并后长度: {len(new_tokens)} (变短了)")

这是 BPE 的"压缩"步骤。假设我们要把 (104, 117) (即 'h', 'u') 合并成一个新的 ID(比如 256)。我们需要生成一个新的列表,扫描旧列表,一旦发现 104 跟着 117,就用 256 替换掉它们。

假设我们要把 (104, 117) 合并成 256。 旧列表 ids = [104, 117, 103]

第一轮循环 (i=0):

  • 检查:ids[0] 是 104 吗?是。ids[1] 是 117 吗?是。

  • 匹配成功!

  • 操作:把 256 放进 newids

  • 指针移动:i 增加 2(因为 104 和 117 已经被 256 替代了,我们不需要再处理它们)。现在 i = 2

第二轮循环 (i=2):

  • 检查:ids[2] 是 103。跟 (104, 117) 没关系。

  • 匹配失败。

  • 操作:把 103 原样放进 newids

  • 指针移动:i 增加 1。现在 i = 3

第三轮循环 (i=3):

  • i 不再小于 len(ids)(也就是 3),循环结束。

结果: newids 变成了 [256, 103]。 (逻辑上就是:[h, u, g] 变成了 [hu, g])。

第三步:训练循环 (The Training Loop)

现在我们将上面的步骤串联起来。我们需要设定一个目标词表大小(vocab_size) 。 假设我们要进行 num_merges 次合并。每次合并都会产生一个新的 Token ID。

python 复制代码
# --- 参数设置 ---
vocab_size = 260 # 假设我们要把词表扩充到 260 (原 Unicode 基础是 256 以内,这里只做几次合并演示)
num_merges = vocab_size - 256 # 我们需要做 4 次合并
ids = list(tokens) # 复制一份原始数据

# 记录我们将要学习到的"合并规则"
merges = {} # (id1, id2) -> new_id

print("\n=== 开始 BPE 训练循环 ===")

for i in range(num_merges):
    # 1. 统计频率
    stats = get_stats(ids)
    
    # 2. 找到最高频的对子
    # 如果所有对子都只出现1次,或者没有对子了,就提前结束
    if not stats:
        break
        
    pair = max(stats, key=stats.get)
    
    # 3. 分配新的 Token ID
    # 通常新的 ID 是从 256 (字节最大值+1) 开始递增
    idx = 256 + i
    
    # 4. 执行合并
    ids = merge(ids, pair, idx)
    
    # 5. 记录规则
    merges[pair] = idx
    
    print(f"Step {i+1}: 合并 {pair} -> {idx}. 序列长度变为: {len(ids)}")

print("\n=== 训练结束 ===")
print(f"最终压缩后的 Tokens: {ids}")
print(f"学习到的合并规则 (Merges): {merges}")

merges 字典:这就是我们的"训练结果"。

  • 它就像一本字典:{(104, 117): 256, (256, 103): 257, ...}

  • 以后推理(Inference)的时候,我们就不需要重新统计频率了,直接查这本字典,把文本里的字符对换成 ID 即可。

总结一下代码流

  1. 变成数字[h, u, g] -> [104, 117, 103]

  2. get_stats :用 zip 错位技术,发现 (104, 117) 出现了。

  3. 决策 :决定把 (104, 117) 变成 256

  4. merge :用 while 循环和 i 指针,扫描整个列表:

    • 遇到 104, 117 -> 写入 256,跳过两步。

    • 遇到其他 -> 照抄,跳过一步。

  5. 结果 :列表变短了 [256, 103]

相关推荐
星释1 小时前
Rust 练习册 32:二分查找与算法实现艺术
开发语言·算法·rust
lisw051 小时前
边缘计算与云计算!
大数据·人工智能·机器学习·云计算·边缘计算
G***技1 小时前
杰和 DN84 AI边缘计算盒:工业质检的“精准快”引擎
人工智能·边缘计算
zenRRan1 小时前
英伟达提出“思考用扩散,说话用自回归”:实现语言模型效率与质量的双赢!
人工智能·机器学习·语言模型·数据挖掘·回归
zl_vslam2 小时前
SLAM中的非线性优-3D图优化之四元数在Opencv-PNP中的应用(五)
人工智能·算法·计算机视觉
林炳然2 小时前
Python-Basic Day-4 函数-基础知识
python
EAIReport2 小时前
企业人力资源管理数据分析:离职因素与群体特征研究
人工智能·数据挖掘·数据分析
FreeCode2 小时前
LangSmith Studio 调试智能体
python·langchain·agent
Paraverse_徐志斌2 小时前
基于 PyTorch + BERT 意图识别与模型微调
人工智能·pytorch·python·bert·transformer