大模型Tokenizer原理:BPE、WordPiece与子词编码的核心机制深度解析

大模型Tokenizer原理:BPE、WordPiece与子词编码的核心机制深度解析

在大型语言模型(LLM)的技术栈中,Tokenizer(分词器)是一个常被忽视却至关重要的组件。它位于文本输入模型的第一道关卡,直接决定了模型能够处理什么样的文本、以何种粒度理解语言,以及最终的性能表现。本文将深入剖析当前主流的子词分词算法------BPE、WordPiece与Unigram Language Model的核心原理,并通过代码示例揭示这些算法在工程实践中的实现细节。

为什么需要子词分词

传统的分词方式主要分为基于单词(Word-level)和基于字符(Character-level)两种。单词级分词将文本切分为完整词汇,如将"深度学习"切分为"深度"和"学习"。这种方法简单直接,但面临严重的词汇表膨胀问题------为保证模型能够处理各种词汇,需要构建包含数十万甚至上百万词条的词汇表,而许多低频词的存在造成严重的参数浪费。

字符级分词则将文本切分为单个字符,如"深"和"度"。这种方法的词汇表可以很小(通常只需覆盖所有Unicode字符),但每个token携带的语义信息极为有限,导致序列长度急剧膨胀,计算成本呈指数级上升。

子词分词(Subword Tokenization)应运而生,它在词级和字符级之间取得精妙平衡。通过将文本分解为更小的语义单元------既包含完整单词也包含词缀、词根等子词单元------子词分词能够在有限词汇表规模下高效处理开放词汇,同时保持合理的序列长度。

BPE:字节对编码的数学原理

BPE(Byte Pair Encoding)算法最初由Philip Gage在1994年提出,用于数据压缩领域。其核心思想简洁优雅:迭代合并频次最高的相邻字符对,逐步构建出能够高效表示训练语料的符号表。

算法流程

BPE的训练过程遵循以下步骤:首先将训练文本按字符级别拆分,每个字符作为独立的token;然后统计所有相邻token对的频次,选取出现最频繁的对进行合并;将选定的token对替换为新的合并token后更新语料;重复上述合并过程,直至达到预设的词汇表大小。

形式化地表述,假设当前词汇表为V,训练语料为D。算法维护一个合并规则集合M,初始为空。在每轮迭代中,算法扫描语料D,统计所有可应用的合并操作频次:

复制代码
score(pair) = count(merge(pair) in D)

选取score最高的pair作为本轮合并规则,加入M并更新语料。这个过程持续进行,直到|M|达到目标词汇表大小。

代码实现

以下是一个完整的BPE训练与分词实现:

python 复制代码
import re
from collections import Counter, defaultdict
from typing import List, Tuple, Dict

class BPETokenizer:
    def __init__(self, vocab_size: int = 8000):
            self.vocab_size = vocab_size
                    self.merges: List[Tuple[str, str]] = []
                            self.vocab: Dict[str, int] = {}
                                    self.reverse_vocab: Dict[int, str] = {}
                                            
                                                def get_stats(self, word_counts: Counter) -> Counter:
                                                        """统计所有相邻token对的频次"""
                                                                pairs = Counter()
                                                                        for word, freq in word_counts.items():
                                                                                    symbols = word.split()
                                                                                                for i in range(len(symbols) - 1):
                                                                                                                pairs[(symbols[i], symbols[i+1])] += freq
                                                                                                                        return pairs
                                                                                                                            
                                                                                                                                def merge_vocab(self, best_pair: Tuple[str, str], word_counts: Counter) -> Counter:
                                                                                                                                        """执行最高频token对的合并"""
                                                                                                                                                replacement = ''.join(best_pair)
                                                                                                                                                        new_word_counts = Counter()
                                                                                                                                                                pattern = re.compile(r'(?<!\S)' + re.escape(best_pair[0] + ' ' + best_pair[1]) + r'(?!\S)')
                                                                                                                                                                        
                                                                                                                                                                                for word, freq in word_counts.items():
                                                                                                                                                                                            new_word = pattern.sub(replacement, word)
                                                                                                                                                                                                        new_word_counts[new_word] = freq
                                                                                                                                                                                                                return new_word_counts
                                                                                                                                                                                                                    
                                                                                                                                                                                                                        def train(self, corpus: List[str]):
                                                                                                                                                                                                                                """BPE训练主流程"""
                                                                                                                                                                                                                                        # 字符级初始化:每个字符后添加</w>标记词边界
                                                                                                                                                                                                                                                word_counts = Counter([' '.join(list(text)) + ' </w>' for text in corpus])
                                                                                                                                                                                                                                                        self.vocab = {chr(i): i for i in range(256)}
                                                                                                                                                                                                                                                                self.vocab['</w>'] = 256
                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                for i in range(self.vocab_size - 257):
                                                                                                                                                                                                                                                                                            pairs = self.get_stats(word_counts)
                                                                                                                                                                                                                                                                                                        if not pairs:
                                                                                                                                                                                                                                                                                                                        break
                                                                                                                                                                                                                                                                                                                                    best_pair = pairs.most_common(1)[0][0]
                                                                                                                                                                                                                                                                                                                                                self.merges.append(best_pair)
                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                        # 更新词汇表
                                                                                                                                                                                                                                                                                                                                                                                    new_token = ''.join(best_pair)
                                                                                                                                                                                                                                                                                                                                                                                                self.vocab[new_token] = 256 + i + 1
                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                        # 合并语料中的token对
                                                                                                                                                                                                                                                                                                                                                                                                                                    word_counts = self.merge_vocab(best_pair, word_counts)
                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                    self.reverse_vocab = {v: k for k, v in self.vocab.items()}
                                                                                                                                                                                                                                                                                                                                                                                                                                                            print(f"训练完成,词汇表大小: {len(self.vocab)}")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                    def tokenize(self, text: str) -> List[str]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                            """对输入文本进行分词"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    tokens = list(text) + ['</w>']
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    for merge_rule in self.merges:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                # 贪心匹配所有可合并的位置
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            i = 0
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        while i < len(tokens) - 1:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        if tokens[i] == merge_rule[0] and tokens[i+1] == merge_rule[1]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            tokens = tokens[:i] + [''.join(merge_rule)] + tokens[i+2:]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            else:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                i += 1
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                return [t for t in tokens if t != '</w>']
# 训练示例
corpus = [
    "deep learning neural network",
        "deep neural network",
            "learning deep learning",
                "neural network architecture",
                    "deep learning models"
                    ]
tokenizer = BPETokenizer(vocab_size=100)
tokenizer.train(corpus)

test_text = "deeplearning"
print(f"分词结果: {tokenizer.tokenize(test_text)}")

WordPiece:基于语言模型的分词策略

WordPiece算法由Google在2012年提出,最初应用于语音搜索系统,后被BERT采用并广为人知。与BPE的频次驱动不同,WordPiece采用语言模型评估来确定最优的合并操作。

核心评估函数

WordPiece的核心在于评估每个候选合并的收益。给定相邻token A和B,算法计算合并前后的语言模型似然增益:

复制代码
score(A, B) = log(P(AB)) - log(P(A)) - log(P(B))
           = log(P(AB) / (P(A) * P(B)))
           ```
这个公式的直观含义是:只有当A和B合并后的联合概率显著高于独立概率时,这次合并才是有价值的。如果AB的组合在实际语料中经常出现(高P(AB)),而A和B单独出现的情况较少,这说明AB作为一个整体单元是合理的。

实际实现中,WordPiece通常基于字符级别的bigram语言模型:

```python
import math
from collections import Counter, defaultdict

class WordPieceTokenizer:
    def __init__(self, vocab_size: int = 8000):
            self.vocab_size = vocab_size
                    self.vocab: set = set()
                            self.special_tokens = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4}
                                    
                                        def get_subword_freq(self, corpus: List[str]) -> Counter:
                                                """统计所有子词单元的频次"""
                                                        freq = Counter()
                                                                for word in corpus:
                                                                            # 添加词边界标记
                                                                                        word = ' '.join(list(word)) + ' @@'
                                                                                                    freq[word] += 1
                                                                                                            return freq
                                                                                                                
                                                                                                                    def compute_score(self, pair: str, vocab: set, word_freq: Counter) -> float:
                                                                                                                            """计算WordPiece合并分数(简化版)"""
                                                                                                                                    freq_ab = 0
                                                                                                                                            freq_a = 0
                                                                                                                                                    freq_b = 0
                                                                                                                                                            
                                                                                                                                                                    for word, count in word_freq.items():
                                                                                                                                                                                symbols = word.split()
                                                                                                                                                                                            # 统计pair出现的频次
                                                                                                                                                                                                        for i in range(len(symbols) - 1):
                                                                                                                                                                                                                        if symbols[i] == pair[0] and symbols[i+1] == pair[1]:
                                                                                                                                                                                                                                            freq_ab += count
                                                                                                                                                                                                                                                            if symbols[i] == pair[0]:
                                                                                                                                                                                                                                                                                freq_a += count
                                                                                                                                                                                                                                                                                                if symbols[i] == pair[1]:
                                                                                                                                                                                                                                                                                                                    freq_b += count
                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                    if freq_a == 0 or freq_b == 0:
                                                                                                                                                                                                                                                                                                                                                return -float('inf')
                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                # 计算互信息作为分数
                                                                                                                                                                                                                                                                                                                                                                        return math.log((freq_ab + 1) / (freq_a * freq_b + 1))
                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                def train(self, corpus: List[str]):
                                                                                                                                                                                                                                                                                                                                                                                        """WordPiece训练流程"""
                                                                                                                                                                                                                                                                                                                                                                                                # 初始化词汇表:所有字符 + 特殊token
                                                                                                                                                                                                                                                                                                                                                                                                        vocab = set()
                                                                                                                                                                                                                                                                                                                                                                                                                for word in corpus:
                                                                                                                                                                                                                                                                                                                                                                                                                            vocab.update(list(word))
                                                                                                                                                                                                                                                                                                                                                                                                                                    vocab.update(self.special_tokens.keys())
                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                    word_freq = self.get_subword_freq(corpus)
                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                    while len(vocab) < self.vocab_size:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                # 计算所有可能合并的分数
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            best_score = -float('inf')
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        best_pair = None
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                for word in list(word_freq.keys()):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                symbols = word.split()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                if len(symbols) < 2:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    continue
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    for i in range(len(symbols) - 1):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        pair = (symbols[i], symbols[i+1])
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            if pair[0] in vocab and pair[1] in vocab:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    score = self.compute_score(pair, vocab, word_freq)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            if score > best_score:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        best_score = score
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    best_pair = pair
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            if best_pair is None or best_score == -float('inf'):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            break
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    # 合并最高分的pair
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                vocab.add(''.join(best_pair))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            # 更新语料频率
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        word_freq = self.merge_pair(best_pair, word_freq)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        self.vocab = vocab
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                print(f"WordPiece训练完成,词汇表大小: {len(self.vocab)}")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        def merge_pair(self, pair: Tuple[str, str], word_freq: Counter) -> Counter:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                """合并语料中的token对"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        new_freq = Counter()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                pair_str = ' '.join(pair)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        merged = ''.join(pair)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        for word, freq in word_freq.items():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    new_word = word.replace(pair_str, merged)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                new_freq[new_word] += freq
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                return new_freq
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        def tokenize(self, text: str) -> List[str]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                """WordPiece分词:最长匹配优先"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        tokens = []
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                i = 0
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                while i < len(text):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            longest_match = None
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        for j in range(len(text) - i, 0, -1):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        subword = text[i:i+j]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        if subword in self.vocab:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            longest_match = subword
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                break
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        if longest_match:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        tokens.append(longest_match)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        i += len(longest_match)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    else:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    tokens.append('[UNK]')
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    i += 1
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    return tokens
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    ```
### Unigram Language Model:SentencePiece的基石

Unigram Language Model(ULM)是Google开发SentencePiece工具时采用的算法,它采用概率论方法,从另一个角度解决子词分词问题。

#### EM算法的视角

与BPE和WordPiece从空词汇表逐步构建不同,ULM假设已知一个大的候选词汇表,然后通过期望最大化(EM)算法找出能够最佳解释训练语料的子集。

ULM的核心假设是:每个词可以由多个子词序列生成,序列的概率是各子词概率的乘积。训练目标是最大化训练语料的整体似然:

L = Σ log(P(x_i))

复制代码
其中P(x) = Σ P(s|x)·P(s),s为词x的某种子词分词方式。

实际实现中,ULM采用一种贪心的剪枝策略:从大词汇表开始,迭代移除对整体似然贡献最小的token,直到达到目标规模。

### 分词器的实际工程考量

在LLM生产环境中,Tokenizer的选择需要权衡多个维度。

**词汇表规模与OOV处理**是首要考量。GPT系列采用BPE变体,词汇表通常在5万到10万之间,配合字节级回退机制处理未知字符。BERT的WordPiece词汇表约为3万。这意味着当遇到未登录词时,模型能够通过子词重组保持一定的理解能力。

**序列长度与计算效率**直接关联推理成本。假设平均每词对应1.3个token,对于1000词的输入,序列长度约为1300。使用BPE和WordPiece的模型通常能在这一数量级保持良好的计算效率。

**多语言支持**是另一个关键维度。中文的Tokenization存在特殊挑战------汉字不像英文那样天然由空格分隔。SentencePiece等工具通过无监督学习能够自动发现适合特定语言的子词边界,为多语言模型提供统一处理框架。

### 分词器与模型性能的关联

Tokenizer的设计决策深刻影响模型能力。词汇表过小会导致高频子词重复计算,浪费模型容量;词汇表过大则增加嵌入层参数,影响训练效率。

更关键的是,Tokenizer决定了模型处理特定文本的能力边界。在代码任务中,数字和特殊符号的处理尤为重要------BPE倾向于将连续数字作为整体保留,而某些变体会将每个数字单独token化,这对数学计算任务的性能有显著影响。

中文处理方面存在特殊考量:常用汉字约3500个,但专业领域可能需要覆盖更多生僻字。实验表明,中文BERT使用约21000个WordPiece单元,其中单字符占大多数,这保证了基本的中文理解能力。

### 代码实战:BPE训练完整流程

以下代码展示从语料预处理到分词器使用的完整流程:

```python
import json
import os
from pathlib import Path

class ProductionBPETokenizer:
    def __init__(self, vocab_size: int = 50000, min_frequency: int = 2):
            self.vocab_size = vocab_size
                    self.min_frequency = min_frequency
                            self.merges = {}
                                    self.vocab = {}
                                            self.inverse_vocab = {}
                                                    
                                                        def pretokenize(self, text: str) -> List[str]:
                                                                """预分词:处理标点、空白等"""
                                                                        # 保持标点符号独立
                                                                                text = re.sub(r'([.,!?;:()\[\]{}""\'\'《》【】])', r' \1 ', text)
                                                                                        # 规范化空白
                                                                                                text = re.sub(r'\s+', ' ', text).strip()
                                                                                                        return text.split()
                                                                                                            
                                                                                                                def build_vocab_from_corpus(self, corpus_path: Path) -> Counter:
                                                                                                                        """从语料库构建基础词汇频次"""
                                                                                                                                vocab = Counter()
                                                                                                                                        for file_path in corpus_path.rglob('*.txt'):
                                                                                                                                                    with open(file_path, 'r', encoding='utf-8') as f:
                                                                                                                                                                    for line in f:
                                                                                                                                                                                        tokens = self.pretokenize(line)
                                                                                                                                                                                                            # 字符级拆分
                                                                                                                                                                                                                                word_chars = []
                                                                                                                                                                                                                                                    for token in tokens:
                                                                                                                                                                                                                                                                            word_chars.extend(list(token))
                                                                                                                                                                                                                                                                                                    word_chars.append('<w>')
                                                                                                                                                                                                                                                                                                                        vocab[' '.join(word_chars)] += 1
                                                                                                                                                                                                                                                                                                                                return vocab
                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                        def train(self, corpus_path: str):
                                                                                                                                                                                                                                                                                                                                                """完整训练流程"""
                                                                                                                                                                                                                                                                                                                                                        vocab = self.build_vocab_from_corpus(Path(corpus_path))
                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                        # 初始化词汇表:所有单字符 + 边界标记
                                                                                                                                                                                                                                                                                                                                                                                self.vocab = {chr(i): i for i in range(256)}
                                                                                                                                                                                                                                                                                                                                                                                        self.vocab['<w>'] = 256
                                                                                                                                                                                                                                                                                                                                                                                                self.vocab['<unk>'] = 257
                                                                                                                                                                                                                                                                                                                                                                                                        self.vocab['<s>'] = 258
                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                        # 迭代合并
                                                                                                                                                                                                                                                                                                                                                                                                                                for i in range(self.vocab_size - 259):
                                                                                                                                                                                                                                                                                                                                                                                                                                            pairs = self._get_pair_counts(vocab)
                                                                                                                                                                                                                                                                                                                                                                                                                                                        if not pairs:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                        break
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                best_pair = max(pairs.items(), key=lambda x: x[1])[0]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        # 检查最小频次要求
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    if pairs[best_pair] < self.min_frequency:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    break
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            # 执行合并
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        vocab = self._merge_pair(best_pair, vocab)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    new_token = ''.join(best_pair)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                self.merges[best_pair] = new_token
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            self.vocab[new_token] = len(self.vocab)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            self.inverse_vocab = {v: k for k, v in self.vocab.items()}
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    print(f"训练完成: 词汇表大小={len(self.vocab)}, 合并规则数={len(self.merges)}")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            def _get_pair_counts(self, vocab: Counter) -> Dict[Tuple[str, str], int]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    """统计token对频次"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            pairs = defaultdict(int)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    for word, freq in vocab.items():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                symbols = word.split()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            for i in range(len(symbols) - 1):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            pairs[(symbols[i], symbols[i+1])] += freq
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    return pairs
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            def _merge_pair(self, pair: Tuple[str, str], vocab: Counter) -> Counter:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    """合并token对并更新词汇"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            merged = Counter()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    for word, freq in vocab.items():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                new_word = pattern.sub(''.join(pair), word)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            merged[new_word] += freq
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            return merged
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    def encode(self, text: str) -> List[int]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            """编码文本为token IDs"""
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    tokens = []
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            for char in text:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        tokens.append(char)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                tokens.append('<w>')
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                for merge_rule, new_token in self.merges.items():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            i = 0
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        while i < len(tokens) - 1:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        if tokens[i] == merge_rule[0] and tokens[i+1] == merge_rule[1]:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            tokens = tokens[:i] + [new_token] + tokens[i+2
相关推荐
阿Y加油吧1 小时前
二刷 LeetCode:300. 最长递增子序列 & 152. 乘积最大子数组 复盘笔记
笔记·算法·leetcode
自我意识的多元宇宙1 小时前
数据结构----希尔排序
数据结构·算法·排序算法
hhhhhh_we1 小时前
再定义“皮肤人格”:从Baumann 16型分型到预颜美历的AI时序人格
前端·图像处理·人工智能·python·aigc
石榴树下的七彩鱼1 小时前
OCR API价格对比2026:身份证/发票/医疗票据识别哪家性价比最高?含Python对接+成本公式
开发语言·人工智能·python·ocr·图像识别·文字识别·api接口
AI自动化工坊1 小时前
Claude Mythos技术解析:AI自主发现零日漏洞的安全实践
人工智能·安全·ai agent
威尔逊·柏斯科·希伯理1 小时前
机器学习-特征工程
人工智能·机器学习
eastyuxiao1 小时前
OpenClaw的PDF处理Skill收费吗?
人工智能·pdf
惊鸿一博1 小时前
深度学习特征匹配算法 LoFTR、DKM、RoMa 介绍
人工智能·深度学习·算法
eBest数字化转型方案1 小时前
基于AI的食品行业零售执行系统架构设计与实践 eBest
人工智能·系统架构·零售