大模型Tokenizer原理:深入理解BPE与WordPiece子词编码技术

大模型Tokenizer原理:深入理解BPE与WordPiece子词编码技术

在大型语言模型的技术架构中,Tokenizer(分词器)是连接原始文本与模型输入的关键桥梁。不同于简单的按空格或标点分割,一个优秀的分词器需要将文本切分为模型能够高效处理的Token序列,同时尽可能保留语义信息。本文深入剖析当前大模型中最常用的两种子词分词算法------Byte Pair Encoding(BPE)和WordPiece,从底层原理到代码实现进行全面讲解。

BPE算法原理与训练过程

BPE最初由Philip Gage在1994年提出,用于数据压缩领域。其核心思想是通过迭代合并高频出现的字节对来构建符号表。这一思想被迁移到NLP领域后,成为构建子词词汇表的标准方法。

BPE训练的核心流程如下。首先,将训练语料中的每个单词拆分为字符序列,并在每个单词末尾添加特殊分隔符 。同时统计每个单词出现的频率。例如单词"higher"会变为"h i g h e r ",单词"low"变为"l o w ``"。初始词汇表包含所有单个字符和特殊分隔符。

接下来进入迭代合并阶段。在每次迭代中,算法遍历所有相邻字符对,统计它们在语料库中共同出现的总频率。选择频率最高的字符对作为合并规则,加入词汇表,并将语料库中所有该字符对合并为一个新符号。这个过程重复进行,直到词汇表达到预设大小。

假设语料库中有单词"low"出现5次,"lower"出现2次,"new"出现3次。在初始状态下,字符序列分别为"l o w"、"l o w e r"和"n e w"。经过若干次迭代后,可能形成"lo"、"wer"等子词单元,这些子词能够组合表示原单词,同时在统计意义上具有更高的出现频率。

BPE的最终分词过程是确定性的。对于任意输入单词,首先拆分为字符序列,然后从左到右遍历,贪心地应用已学到的合并规则。每次检查当前位置是否存在可合并的字符对,如果存在则合并,否则保持原样并移动到下一个位置。这种方法保证每个单词都能被分解为词汇表中的子词组合。

WordPiece算法深度解析

WordPiece是Google为语音搜索系统开发的分词技术,后被BERT采用并广为人知。与BPE基于频率的贪心合并不同,WordPiece采用基于概率的训练目标,这导致了本质性的差异。

WordPiece的训练目标是最大化训练语料的语言模型概率。给定一个单词序列,完整的分词方案是将其切分为若干子词单元。设分词结果为(t1, t2, ..., tn),则该分词方案的语言模型概率为各个子词条件概率的乘积:

P(分词) = P(``|t1) × P(t1|t2) × P(t2|t3) × ... × P(tn-1|tn)

每个条件概率P(ti|ti+1)可以通过统计训练语料中子词对的出现频率计算得到:P(ti|ti+1) = Count(ti, ti+1) / Count(ti+1)。

在训练过程中,WordPiece需要决定哪两个子词应该合并。不同于BPE直接选择最高频的字符对,WordPiece评估的是合并后对语言模型概率的提升。具体来说,对于候选合并对(A, B),计算合并前的联合概率贡献与合并后的联合概率贡献之差,选择使整体似然提升最大的对。

这个差异在实际应用中产生了明显区别。考虑单词"unsupervised",BPE可能优先合并出现频率最高的字符对,而WordPiece会考虑合并后对整体句子概率的影响。如果"un"和"super"在语料中有明确且独立的语义作用,WordPiece可能选择保留它们而非强行合并。

分词阶段也存在差异。BPE采用确定性的贪心匹配,而WordPiece通常采用动态规划或类似Viterbi算法来寻找最优分词路径。具体而言,对于输入单词,从右到左(或从左到右)遍历所有可能的分词位置,计算每种分词方案的概率,选择概率最高的方案。

BPE代码实现详解

理解算法原理后,通过代码实现可以更深入地掌握细节。以下是一个完整的BPE训练和分词实现。

python 复制代码
from collections import Counter, defaultdict
import re

class BPE:
    def __init__(self, vocab_size=10000):
            self.vocab_size = vocab_size
                    self.vocab = {}
                            self.merges = {}
                                
                                    def get_stats(self, vocab):
                                            """统计所有字符对的频率"""
                                                    pairs = Counter()
                                                            for word, freq in vocab.items():
                                                                        symbols = word.split()
                                                                                    for i in range(len(symbols) - 1):
                                                                                                    pairs[(symbols[i], symbols[i+1])] += freq
                                                                                                            return pairs
                                                                                                                
                                                                                                                    def merge_vocab(self, pair, vocab):
                                                                                                                            """合并所有词汇中的指定字符对"""
                                                                                                                                    v_out = {}
                                                                                                                                            bigram = re.escape(pair[0] + ' ' + pair[1])
                                                                                                                                                    pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
                                                                                                                                                            
                                                                                                                                                                    for word in vocab:
                                                                                                                                                                                w_out = pattern.sub(''.join(pair), word)
                                                                                                                                                                                            v_out[w_out] = vocab[word]
                                                                                                                                                                                                    return v_out
                                                                                                                                                                                                        
                                                                                                                                                                                                            def train(self, corpus):
                                                                                                                                                                                                                    """
                                                                                                                                                                                                                            训练BPE模型
                                                                                                                                                                                                                                    corpus: 单词列表
                                                                                                                                                                                                                                            """
                                                                                                                                                                                                                                                    # 初始化词汇表:每个单词拆分为单字符
                                                                                                                                                                                                                                                            vocab = Counter()
                                                                                                                                                                                                                                                                    for word in corpus:
                                                                                                                                                                                                                                                                                word_tokens = list(word) + ['</w>']
                                                                                                                                                                                                                                                                                            vocab[' '.join(word_tokens)] += 1
                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                            # 迭代合并
                                                                                                                                                                                                                                                                                                                    while len(vocab) < self.vocab_size:
                                                                                                                                                                                                                                                                                                                                pairs = self.get_stats(vocab)
                                                                                                                                                                                                                                                                                                                                            if not pairs:
                                                                                                                                                                                                                                                                                                                                                            break
                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                    best_pair = max(pairs, key=pairs.get)
                                                                                                                                                                                                                                                                                                                                                                                                vocab = self.merge_vocab(best_pair, vocab)
                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                        self.merges[best_pair] = True
                                                                                                                                                                                                                                                                                                                                                                                                                                    self.vocab[best_pair] = len(self.vocab)
                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                            print(f"合并 {best_pair},词汇表大小: {len(vocab)}")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                            # 添加单字符到词汇表
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    for char in set(''.join(corpus))):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                if char not in self.vocab:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                self.vocab[char] = len(self.vocab)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        def tokenize(self, text):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                """对输入文本进行分词"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        tokens = list(text) + ['</w>']
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        while len(tokens) > 1:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    # 寻找第一个可合并的位置
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                pairs = [(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        # 找优先级最高的合并
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    min_rank = None
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                min_pair = None
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        for pair in pairs:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        if pair in self.merges:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            rank = self.merges[pair]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                if min_rank is None or rank < min_rank:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        min_rank = rank
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                min_pair = pair
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        if min_pair is None:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        break
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                # 执行合并
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            new_tokens = []
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        i = 0
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    while i < len(tokens):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == min_pair:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        new_tokens.append(''.join(min_pair))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            i += 2
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            else:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                new_tokens.append(tokens[i])
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    i += 1
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                tokens = new_tokens
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                return [t for t in tokens if t != '</w>']
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                ```
这段实现展示了BPE的核心机制:词汇表构建阶段的迭代合并和分词阶段的贪心应用。关键点在于使用空格分隔符来标记字符边界,以及通过re.escape处理可能包含特殊字符的合并对。

### WordPiece代码实现

WordPiece的实现更加复杂,因为它需要维护完整的词汇表并使用动态规划进行最优分词。

```python
class WordPiece:
    def __init__(self, vocab=None):
            self.vocab = vocab or {}
                    self.unk_token = '[UNK]'
                            self.unk_id = 0
                                
                                    def tokenize(self, text):
                                            """对文本进行分词,返回子词序列"""
                                                    output_tokens = []
                                                            
                                                                    for token in self._basic_tokenize(text):
                                                                                chars = list(token)
                                                                                            
                                                                                                        if token in self.vocab:
                                                                                                                        output_tokens.append(token)
                                                                                                                                        continue
                                                                                                                                                    
                                                                                                                                                                # 尝试将单词切分为子词
                                                                                                                                                                            tokens = []
                                                                                                                                                                                        start = 0
                                                                                                                                                                                                    
                                                                                                                                                                                                                while start < len(chars):
                                                                                                                                                                                                                                end = len(chars)
                                                                                                                                                                                                                                                cur_substr = None
                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                # 从后向前寻找最长匹配
                                                                                                                                                                                                                                                                                                while start < end:
                                                                                                                                                                                                                                                                                                                    substr = ''.join(chars[start:end])
                                                                                                                                                                                                                                                                                                                                        if start > 0:
                                                                                                                                                                                                                                                                                                                                                                substr = '##' + substr
                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                        if substr in self.vocab:
                                                                                                                                                                                                                                                                                                                                                                                                                                cur_substr = substr
                                                                                                                                                                                                                                                                                                                                                                                                                                                        break
                                                                                                                                                                                                                                                                                                                                                                                                                                                                            end -= 1
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            if cur_substr is None:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                # 没有找到匹配,返回UNK
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    output_tokens.append(self.unk_token)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        break
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        tokens.append(cur_substr)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        start = end
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                output_tokens.extend(tokens)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                return output_tokens
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        def _basic_tokenize(self, text):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                """基础分词:处理标点和空格"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        import re
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                # 简单实现:按空格分词,保留标点
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        tokens = re.findall(r'\w+|[^\s\w]', text)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                return [t.lower() for t in tokens]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                ```
实际应用中,WordPiece词汇表通常由专门的工具(如Google的sentencepiece或BERT的tokenization工具)生成。词汇表中的子词带有特定前缀标记(``##`)表示这是词内延续。

### 大模型中的Tokenizer选择

不同大模型选择了不同的Tokenizer策略,这些选择深刻影响了模型的能力边界。

GPT系列采用BPE的变体------ByteLevelBPE。关键改进是使用UTF-8字节而非Unicode字符作为初始单位。UTF-8中任何字符都可以表示为1-4个字节,这意味着词汇表可以从256个基础字节开始训练。这种方法有两个优势:理论上有无限的"字符"词汇表,以及能够处理任意Unicode字符串而不会出现未知字符问题。GPT-4的词汇表包含超过10万个Token,反映了其处理多语言和特殊符号的能力。

BERT采用WordPiece,这与其预训练任务设计密切相关。BERT使用Masked Language Modeling,需要将输入的一部分Token替换为[MASK],WordPiece的概率优化目标与此高度一致。此外,BERT的词汇表中包含丰富的词根和词缀,这有助于模型学习形态学特征。

SentencePiece是另一个值得了解的框架,由Google开发并被T5等模型采用。SentencePiece将输入视为原始字节流,可以直接训练BPE或Unigram模型,无需预分词步骤。这种端到端的处理方式避免了不同语言的分词规则差异,更适合多语言模型。

### 分词对模型的影响

理解Tokenizer的选择对实际应用至关重要。相同的文本经过不同Tokenizer会产生显著差异的Token序列长度。以"神经网络Transformer大模型"为例:

使用较小词汇表的Tokenizer可能产生15-20个Token,而使用更大词汇表的Tokenizer可能只需要8-12个Token。Token数量的差异直接影响计算成本(与Token数的平方成正比)、显存占用和推理延迟。

此外,分词粒度影响模型对语义的理解能力。过粗的分词(如按空格分词)会导致严重的OOV问题;过细的分词(如纯字符级)会显著增加序列长度,削弱模型捕捉长距离依赖的能力。子词分词在两者之间取得了平衡,通过数据驱动的方式学习适合目标语料的词汇表。

---

标签:大模型、Tokenization、BPE、WordPiece、NLP
相关推荐
qq_411262422 小时前
四博 AI 智能音箱 4G S3 版本工程落地方案:三模联网、远场唤醒、打断播放与 AI 会话框架
人工智能·智能音箱
薛定猫AI2 小时前
【深度解析】Gemma Chat 本地 AI 编程 Agent:Electron + MLX + 开源模型的离线 Vibe Coding 实战
javascript·人工智能·electron
txg6662 小时前
MDVul:用语义路径重塑漏洞检测的图模型能力
人工智能·安全·网络安全
人工智能培训2 小时前
工程科研中的AI应用:结构力学分析技巧
人工智能·深度学习·机器学习·docker·容器
qq_411262422 小时前
四博 AI 智能音箱 4G S3 版本工程方案:三模联网、远场唤醒、AI 会话与打断架构设计
人工智能·智能音箱
风落无尘2 小时前
Claude Code 常用命令速查手册
人工智能
努力努力再努力FFF2 小时前
律师想了解AI法律咨询工具,能否用它提升案件检索效率?
大数据·人工智能
极智视界2 小时前
分类数据集 - 自然灾害场景飓风野火洪水地震分类数据集下载
人工智能·yolo·数据集·图像分类·算法训练·自然灾害检测
GlobalInfo2 小时前
全球人工智能停车机器人市场份额、规模、技术研究报告2026
人工智能·机器人