【CS336】Transformer|2-BPE算法 -> Tokenizer封装

2 Byte-Pair Encoding (BPE)算法 -> Tokenizer封装

😄学习目标

  • 为什么需要Tokenizer(为什么需要分词)?Tokenizer与BPE的关系?
  • 为什么需要预分词(Pre-tokenization)?
  • BPE算法的原理(理解迭代过程)?
  • 封装好的简易 BPE Tokenizer 类,包含训练、编码和解码功能
  • BPE 和 Transformer 的关系?

2.1 Tokenizer与BPE

为什么不直接把UTF-8字节(0-255)丢给模型?

  • 计算成本爆炸:transformer的注意力机制计算复杂度是序列长度的NNN的平方,即O(N2)O(N^2)O(N2)
  • 例子:单词"transformer"占用11个字节
    • 如果以字节为单位,序列长度就是11
    • 如果分词后将其作为一个整体(Token),序列长度就是1

(1)Tokenizer的本质作用:实现文本压缩

在"词表大小"与"序列长度"之间寻找平衡点,通过将常见的字节组合合并打包成一个数字ID(Token),实现对文本的压缩。

(2)Tokenizer与BPE

BPE算法是Tokenizer的这一种主流实现方式,它通过统计学手段,利用频次自动寻找值得合并的字节组合。

2.2 预分词(Pre-tokenization)

在真正的工业级实现(如 HuggingFace tokenizers 库或 GPT-2/4)中,BPE 并不是直接把整篇文章当成一个超长字符串来切分,而是首先进行预分词(Pre-tokenization)

为什么必须要有预分词这一步?

  • 防止BPE学习到跨越语义边界的无意义组合,保证token的完整性
  • 案例:"dog."
    • 如果不切分,BPE 可能会把 g. 合并成 g.,这样组合毫无意义
    • 先把 "dog""." 强行分开,BPE 只能在 "dog" 内部合并,不能跨过去吃掉标点

2.2 BPE算法原理

2.2.1 核心概念

BPE (Byte-Pair Encoding) 是现代大模型(如 GPT-4, LLaMA)通用的分词算法。

  • 核心思想: Frequency is all you need(频率即正义)。统计语料中相邻出现频率最高的"字符对",将它们合并成一个新的 Token

  • 算法流程图: 原始文本Text (String) 通过正则表达式 (Regex) 将长文本强行切断得到预分词(Pre-tokenization) ,再通过统计每个独立单词的频率得到统计字典 (Stats Construction),最后在每个独立单词中分别使用BPE迭代合并,寻找最高频的字节对 (Pair),将其替换为新 ID,最后得到Token IDs (Int 0-N)

Regex Split
Count Freqs
BPE Merging
Text (String)
Word List (Pre-tokens)
Stats Dictionary (Counts)
Token IDs (Int 0-N)

2.2.2 BPE训练迭代流程

实验背景:

  • 语料: "Hello World! Hello Python! Hello CS336!"

  • 初始词表: 256 (ASCII/UTF-8 Bytes) -> 基础字节 (0-255)

  • 目标: 执行 4 次合并 (Merges) -> 新学习的合并词 (256-259)

⚪️1️⃣:字节化与初始化

将语料库通过UTF-8 编码转为字节。初始词表(Vocab)为0-255,每个字节都是一个独立的Token

  • H=72, e=101, l=108, o=111
  • 初始序列(字节值):[72, 101, 108, 108, 111, ...]
  • 初始序列(文本) : ['H', 'e', 'l','l','o', ...]

⚫️2️⃣:预分词 (Pre-tokenization)

  • 工具:正则表达式 (Regex)
  • 输入"Hello World! Hello Python! Hello CS336!"
  • 输出['Hello', ' World', '!', ' Hello', ' Python', '!', ' Hello', ' CS336', '!']

⚫️3️⃣:构建统计字典(Stats Construction)

BPE 算法不再扫描全文本,而是扫描这个加权字典。

  • 注意:"Hello" (句首) 和 " Hello" (句中,带空格) 被视为两个不同的基础词。
Token Bytes (IDs) Frequency 说明
b'Hello' 1 句首,无空格
b' Hello' 2 句中,带前导空格
b'!' 3 标点符号独立
b' World' 1
b' Python' 1
b' CS336' 1

🔴4️⃣:BPE 迭代合并 (Merging Process)

将词表扩大,覆盖高频出现的序列。合并操作在统计字典(Stats Construction)内部同时进行。

复制代码
Merge 1/4: (72, 101) -> 256 (Count: 3)   # 'H'+'e'
Merge 2/4: (256, 108) -> 257 (Count: 3)  # 'He'+'l'
Merge 3/4: (257, 108) -> 258 (Count: 3)  # 'Hel'+'l'
Merge 4/4: (258, 111) -> 259 (Count: 3)  # 'Hell'+'o'

💻Round 1: 寻找最佳 Pair

  • 统计: 扫描字典所有单词。
    • ('H', 'e')b'Hello' 中出现 1 次。
    • ('H', 'e')b' Hello' 中出现 2 次。
    • Total Count: 3 次。
  • 操作: 注册新 ID 256 (H+e)。
  • 效果: 所有的 H, e 变成 256

💻Round 2: 继续合并

  • 统计: (256, 108)('He', 'l') 出现 3 次。
  • 操作: 注册新 ID 257

💻Round 3 & 4: 完成单词构建

  • ... (中间步骤省略) ...
  • 最终合并: ('Hell', 'o') -> 259
  • 结果:
    • b'Hello' 变成了 [259]
    • b' Hello' 变成了 [32, 259] (空格 + Hello)。

最终生成的 **ID 字典 (Vocabulary)**包含两部分:基础字节 (0-255)新学习的合并词 (256-259)

新增的合并词,这是 BPE 算法根据高频共现关系"学会"的 Token,如下表:

ID Token (Bytes) Token (String) 来源说明
256 b'He' "He" Merge 1: H + e
257 b'Hel' "Hel" Merge 2: He (256) + l
258 b'Hell' "Hell" Merge 3: Hel (257) + l
259 b'Hello' "Hello" Merge 4: Hell (258) + o

2.3 Tokenizer封装

简易 BPE Tokenizer 类,包含训练、编码和解码功能

python 复制代码
import re
from collections import Counter

class BPETokenizer:
    def __init__(self):
        self.merges = {}  # 记录合并规则: (pair) -> new_id
        self.vocab = {}   # 记录解码映射: id -> bytes
        # 初始化基础词表 (0-255)
        for idx in range(256):
            self.vocab[idx] = bytes([idx])

    def pre_tokenize(self, text):
        """
        Step 1: 预分词 (Pre-tokenization)
        使用 GPT-4 风格正则 (简化版),将文本切分为单词列表。
        """
        # 匹配逻辑: 缩写 OR 单词(带空格) OR 数字 OR 标点 OR 纯空格
        pat = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\w+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""")
        return re.findall(pat, text)

    def get_stats(self, vocab_dict):
        """
        Step 2: 基于字典统计频率 (Stats Construction)
        """
        stats = {}
        for ids, count in vocab_dict.items():
            # 遍历单个单词内的所有相邻对,并乘以该单词的频率
            for pair in zip(ids, ids[1:]):
                stats[pair] = stats.get(pair, 0) + count
        return stats

    def merge_vocab(self, vocab_dict, pair, new_id):
        """
        Step 3: 在字典中执行合并 (Merging)
        """
        new_vocab = {}
        for ids, count in vocab_dict.items():
            new_ids = []
            i = 0
            while i < len(ids):
                # 匹配 pair 则合并
                if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                    new_ids.append(new_id)
                    i += 2
                else:
                    new_ids.append(ids[i])
                    i += 1
            new_vocab[tuple(new_ids)] = count
        return new_vocab

    def train(self, text, vocab_size, verbose=False):
        print(f"--- Training BPE (Target Vocab: {vocab_size}) ---")
        
        # 1. 预分词 & 构建初始统计字典
        words = self.pre_tokenize(text)
        vocab_dict = Counter()
        for word in words:
            vocab_dict[tuple(word.encode("utf-8"))] += 1
            
        print(f"Stats: {len(words)} total words, {len(vocab_dict)} unique words.")
        
        num_merges = vocab_size - 256
        
        # 2. 循环合并
        for i in range(num_merges):
            stats = self.get_stats(vocab_dict)
            if not stats: break
            
            pair = max(stats, key=stats.get)
            idx = 256 + i
            
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
            
            if verbose:
                print(f"Merge {i+1}/{num_merges}: {pair} -> {idx} (Count: {stats[pair]})")
            
            vocab_dict = self.merge_vocab(vocab_dict, pair, idx)
            
        print("--- Training Complete ---")

    def encode(self, text):
        """推理: 同样需要先预分词,再对每个词单独编码"""
        words = self.pre_tokenize(text)
        out_ids = []
        for word in words:
            word_ids = list(word.encode("utf-8"))
            while len(word_ids) >= 2:
                stats = self.get_stats({tuple(word_ids): 1})
                pair_to_merge = None
                min_merge_idx = float("inf")
                # 寻找最早训练出的 pair (贪心策略)
                for pair in stats:
                    if pair in self.merges and self.merges[pair] < min_merge_idx:
                        min_merge_idx = self.merges[pair]
                        pair_to_merge = pair
                if pair_to_merge is None: break
                
                # 执行一次合并
                new_ids = []
                i = 0
                while i < len(word_ids):
                    if i < len(word_ids)-1 and word_ids[i] == pair_to_merge[0] and word_ids[i+1] == pair_to_merge[1]:
                        new_ids.append(min_merge_idx)
                        i += 2
                    else:
                        new_ids.append(word_ids[i])
                        i += 1
                word_ids = new_ids
            out_ids.extend(word_ids)
        return out_ids

    def decode(self, ids):
        tokens = b"".join([self.vocab[idx] for idx in ids])
        return tokens.decode("utf-8", errors="replace")

# --- 验证测试 ---
if __name__ == "__main__":
    text = "Hello World! Hello Python! Hello CS336!"
    tokenizer = AdvancedBPETokenizer()
    tokenizer.train(text, vocab_size=260, verbose=True)
    
    print(f"\n[Test] Encode 'Hello World': {tokenizer.encode('Hello World')}")

训练日志

复制代码
Stats: 9 total words, 6 unique words.
Merge 1/4: (72, 101) -> 256 (Count: 3)   # 'H'+'e'
Merge 2/4: (256, 108) -> 257 (Count: 3)  # 'He'+'l'
Merge 3/4: (257, 108) -> 258 (Count: 3)  # 'Hel'+'l'
Merge 4/4: (258, 111) -> 259 (Count: 3)  # 'Hell'+'o'

最终编码测试

  • Input: "Hello World"
  • Encoded IDs: [259, 32, 87, 111, 114, 108, 100]
  • 对应文本["Hello", " ", "w", "o", "r", "l", "d"]

2.4 BPE与transformer

BPE 和 Transformer 的关系:供需关系

BPE 负责把人类语言变成数字(ID),Transformer 负责计算这些数字。

相关推荐
Yeats_Liao1 小时前
显存瓶颈分析:大模型推理过程中的内存管理机制
python·深度学习·神经网络·架构·开源
_OP_CHEN1 小时前
【算法基础篇】(四十七)乘法逆元终极宝典:从模除困境到三种解法全解析
c++·算法·蓝桥杯·数论·算法竞赛·乘法逆元·acm/icpc
杭州杭州杭州1 小时前
pta考试
数据结构·c++·算法
YuTaoShao1 小时前
【LeetCode 每日一题】2975. 移除栅栏得到的正方形田地的最大面积
算法·leetcode·职场和发展
来两个炸鸡腿1 小时前
【Datawhale组队学习202601】Base-NLP task02 预训练语言模型
学习·语言模型·自然语言处理
junziruruo1 小时前
损失函数(以FMTrack频率感知交互与多专家模型的损失为例)
图像处理·深度学习·学习·计算机视觉
少许极端2 小时前
算法奇妙屋(二十五)-递归问题
算法·递归·汉诺塔
li星野2 小时前
OpenCV4X学习-图像边缘检测、图像分割
深度学习·学习·计算机视觉
Remember_9932 小时前
【数据结构】初识 Java 集合框架:概念、价值与底层原理
java·c语言·开发语言·数据结构·c++·算法·游戏