BPE 算法原理与训练实现

一、BPE 算法核心原理

1. 核心思想

BPE 的核心思想是从基础词汇单元(字符 / 字节)出发,反复迭代地合并出现频率最高的相邻字符对(字节对),将其作为新的子词单元,直到达到预设的词汇表大小或没有可合并的字符对为止。

这种思想既保留了字符级别的细粒度(解决 OOV 问题),又能通过合并高频子词形成更具语义的单元(如 "un-"、"happy"、"ing"),提升编码效率。

2. 关键概念铺垫

  • 基础单元:初始为文本中的单个字符(通常会在词尾添加特殊标记</w>,用于区分词内子词和词尾子词,如 "low"和"lower");

  • 频率统计:以 "词 - 出现次数" 的形式统计语料中所有词的频率;

  • 相邻字符对:单个词内的连续两个基础单元(或已合并的子词单元);

  • 合并停止条件:两种常见条件(满足其一即可):

    1. 词汇表大小达到预设阈值(如 30000、50000);
    2. 语料中不存在出现频率 > 1 的相邻字符对。

3. 算法执行步骤(原理层面)

  1. 数据预处理与初始化

    • 对原始语料进行分词、清洗,为每个词添加词尾标记</w>
    • 统计每个词的出现频率,形成「词:频率」字典;
    • 将每个词拆分为单个字符的序列,作为初始子词单元(如 "low" 拆分为l o w </w>)。
  2. 统计相邻字符对频率

    • 遍历所有词的字符序列,统计所有相邻字符对的全局出现频率;
    • 例如语料中有 "low":5、"lower":3,会统计到(l,o):8(o,w):8(w,</w>):5等。
  3. 合并最高频字符对

    • 找到全局频率最高的相邻字符对,将其合并为一个新的子词单元;
    • 遍历所有词的字符序列,将该字符对替换为新子词(如合并w </w>w</w>,则 "low" 变为l o w</w>)。
  4. 迭代合并

    • 重复步骤 2 和步骤 3,每次合并后都会生成新的子词单元,词汇表不断扩大;
    • 每次合并都会记录「合并规则」(即哪两个单元合并为新单元),用于后续的编码和解码。
  5. 停止迭代,生成最终词汇表

    • 当词汇表大小达到预设值或无高频字符对可合并时,停止迭代;
    • 最终词汇表包含初始字符单元和所有迭代过程中生成的合并子词单元。

二、BPE 算法训练实现(Python 手动实现)

下面通过一个极简示例,手动实现 BPE 的训练过程,清晰展示其核心逻辑(无第三方库依赖)。

步骤 1:准备初始数据(带频率的语料)

我们选用一个简单的模拟语料,包含 4 个词及其出现频率,方便观察合并过程:

python 复制代码
# 步骤1:初始化带频率的词表(已添加词尾标记</w>)
word_freqs = {
    "low</w>": 5,
    "lower</w>": 3,
    "newest</w>": 2,
    "widest</w>": 2
}

# 将每个词拆分为字符列表,形成初始的「词序列: 频率」字典
def init_word_sequences(word_freqs):
    word_seqs = {}
    for word, freq in word_freqs.items():
        # 拆分为单个字符(如"low</w>" -> ["l", "o", "w", "</w>"])
        char_seq = list(word)
        word_seqs[tuple(char_seq)] = freq  # 用tuple作为key(list不可哈希)
    return word_seqs

word_sequences = init_word_sequences(word_freqs)

步骤 2:定义核心辅助函数

包括「统计相邻字符对频率」、「合并最高频字符对」两个核心函数:

python 复制代码
from collections import defaultdict

def get_pair_freqs(word_sequences):
    """
    步骤2:统计所有相邻字符对的全局频率
    """
    pair_freqs = defaultdict(int)
    for char_seq, freq in word_sequences.items():
        # 遍历单个词的字符序列,统计相邻对
        for i in range(len(char_seq) - 1):
            pair = (char_seq[i], char_seq[i+1])
            pair_freqs[pair] += freq
    return pair_freqs

def merge_highest_freq_pair(word_sequences, best_pair):
    """
    步骤3:合并全局频率最高的字符对(best_pair)
    """
    new_word_sequences = {}
    for char_seq, freq in word_sequences.items():
        new_char_seq = []
        i = 0
        while i < len(char_seq):
            # 找到可合并的对,合并后跳过下一个字符
            if i < len(char_seq) - 1 and (char_seq[i], char_seq[i+1]) == best_pair:
                merged_token = char_seq[i] + char_seq[i+1]
                new_char_seq.append(merged_token)
                i += 2  # 跳过已合并的下一个字符
            else:
                new_char_seq.append(char_seq[i])
                i += 1
        # 更新新的词序列字典
        new_word_sequences[tuple(new_char_seq)] = freq
    return new_word_sequences

步骤 3:执行迭代合并(完整训练流程)

设置预设词汇表大小,执行迭代合并,记录合并规则和最终词汇表:

python 复制代码
def train_bpe(word_freqs, vocab_size=10):
    """
    完整BPE训练流程
    :param word_freqs: 初始词频字典
    :param vocab_size: 预设词汇表大小(需大于初始字符数)
    :return: 合并规则列表、最终词汇表
    """
    # 初始化
    word_sequences = init_word_sequences(word_freqs)
    merge_rules = []  # 记录所有合并规则([(a,b), (c,d), ...])
    # 提取初始字符词汇表(去重)
    vocab = set()
    for word in word_freqs.keys():
        vocab.update(list(word))
    vocab = list(vocab)
    
    # 迭代合并,直到达到词汇表大小
    while len(vocab) < vocab_size:
        # 步骤1:统计相邻对频率
        pair_freqs = get_pair_freqs(word_sequences)
        if not pair_freqs:  # 无可用合并对,提前终止
            break
        
        # 步骤2:找到频率最高的字符对
        best_pair = max(pair_freqs.items(), key=lambda x: x[1])[0]
        
        # 步骤3:合并最高频字符对
        word_sequences = merge_highest_freq_pair(word_sequences, best_pair)
        
        # 步骤4:记录合并规则,更新词汇表
        merge_rules.append(best_pair)
        new_token = best_pair[0] + best_pair[1]
        vocab.append(new_token)
        
        # 打印中间过程(可选,方便观察)
        print(f"合并 {best_pair} -> {new_token} | 当前词汇表大小:{len(vocab)}")
    
    return merge_rules, vocab, word_sequences

# 执行BPE训练,预设词汇表大小为15
merge_rules, final_vocab, final_word_sequences = train_bpe(word_freqs, vocab_size=15)

步骤 4:查看训练结果

python 复制代码
# 打印最终结果
print("\n=== 训练完成 ===")
print(f"合并规则列表(共{len(merge_rules)}条):")
for idx, rule in enumerate(merge_rules):
    print(f"  {idx+1}: {rule} -> {rule[0]+rule[1]}")

print(f"\n最终词汇表(共{len(final_vocab)}个单元):")
print(sorted(final_vocab))

print(f"\n最终词序列(合并后):")
for seq, freq in final_word_sequences.items():
    print(f"  {seq}: {freq}")

运行结果解读

运行上述代码后,会看到迭代合并的过程(部分输出如下):

plaintext 复制代码
合并 ('e', 's') -> es | 当前词汇表大小:9
合并 ('s', 't') -> st | 当前词汇表大小:10
合并 ('e', 'st') -> est | 当前词汇表大小:11
...
  1. 合并规则按迭代顺序记录,后续编码时需严格按照该顺序进行子词分割;
  2. 最终词汇表包含初始字符(low等)和合并生成的子词(esstest等);
  3. 最终词序列已被合并为更粗粒度的子词单元,减少了冗余,提升了编码效率。

三、关键补充说明

  1. BPE 的优势

    • 无监督训练,无需人工标注子词;
    • 有效解决未登录词(OOV)问题,即使遇到新词,也能拆分为基础字符单元;
    • 词汇表大小可控,平衡编码效率和模型复杂度。
  2. 实际应用中的优化

    • 上述实现为极简版本,实际工业界(如 Hugging Face)的 BPE 实现会优化存储和计算(如用哈希表加速查找);
    • 通常以「字节」而非「字符」作为初始单元(尤其针对多语言场景),避免字符编码(如 UTF-8)带来的问题;
    • 会添加特殊标记(如<unk>)处理罕见字符。
  3. 解码过程

    • 解码时只需反向应用合并规则,或将子词单元直接拼接(注意</w>标记需替换为空格或直接删除)。
相关推荐
胡萝卜不甜9 小时前
算法宗门---广度有优先搜索BFS
算法·宽度优先
独自破碎E10 小时前
【归并】数组中的逆序对
java·数据结构·算法
f***241110 小时前
MATLAB高效算法优化实战指南
开发语言·算法·matlab
Blossom.11810 小时前
大模型自动化压缩:基于权重共享的超网神经架构搜索实战
运维·人工智能·python·算法·chatgpt·架构·自动化
优选资源分享10 小时前
MD5 哈希值校验工具 v1.5.3 实用文件校验工具
算法·哈希算法
AI科技星10 小时前
能量绝对性与几何本源:统一场论能量方程的第一性原理推导、验证与范式革命
服务器·人工智能·科技·线性代数·算法·机器学习·生活
Coder_Boy_10 小时前
基于SpringAI的在线考试系统-数据库表设计
java·数据库·算法
散峰而望10 小时前
【算法竞赛】链表和 list
数据结构·c++·算法·链表·list·哈希算法·推荐算法
爱编程的小吴10 小时前
【力扣练习题】55. 跳跃游戏
算法·leetcode