从 BPE 分词到位置编码:大模型预处理三组件完全解析

【学习记录】从 BPE 分词到位置编码:大模型预处理三组件完全解析

本文深入讲解大语言模型训练/推理前的三个核心预处理组件:BPE 分词(SimpleBPE)嵌入层(Embedding)位置编码(PositionalEncoding)。包含原理图解、代码逐行注释、复杂度分析以及完整测试用例。理解这些组件,你就掌握了现代 Transformer 模型的输入处理全流程。


📌 组件概览

组件 输入 输出 作用
SimpleBPE 原始语料字符串 + 目标词表大小 词表、编码/解码函数 将文本切分为子词 ID 序列
Embedding ID 序列 (B, T) 嵌入向量 (B, T, D) 将离散 ID 映射为连续向量
PositionalEncoding 嵌入向量 (B, T, D) 相加后的向量 (B, T, D) 注入位置信息(因为自注意力本身无序)

一、SimpleBPE:子词分词器

1.1 为什么需要 BPE?

  • 词级分词:词表过大(几十万),且无法处理未登录词(OOV)。
  • 字符级分词:序列过长,信息密度低,语义难以捕捉。
  • 子词级分词(BPE):平衡词表和长度,常见词保留完整,罕见词拆分为子词。

1.2 BPE 训练算法(Byte Pair Encoding)

  1. 初始化词表:所有单个字符 + 特殊标记(<pad>, <unk>, _ 词边界)。
  2. 统计所有相邻 token 对的出现频率。
  3. 合并频率最高的对,加入词表。
  4. 重复步骤 2-3 直到词表达到目标大小。

1.3 代码逐行详解

python 复制代码
from collections import defaultdict

class SimpleBPE:
    def __init__(self, corpus, vocab_size=50):
        self.vocab = self.train_bpe(corpus, vocab_size)
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
        self.id_to_token = {idx: token for idx, token in enumerate(self.vocab)}

    def train_bpe(self, corpus, target_size):
        words = corpus.split()          # 按空格切分为单词列表
        # 初始化词表:特殊标记 + 词边界标记 '_' + 所有出现过的字符
        vocab = ['<pad>', '<unk>', '_']
        base_chars = sorted(list(set(''.join(words))))
        vocab.extend(base_chars)
        
        # 将每个单词表示为字符列表,并在前面加上词边界标记 '_'
        word_splits = []
        for word in words:
            word_splits.append(['_'] + list(word))

        # 循环合并,直到达到目标词表大小
        while len(vocab) < target_size:
            # 1. 统计所有相邻 token 对的出现频率
            pairs = defaultdict(int)
            for tokens in word_splits:
                for i in range(len(tokens) - 1):
                    pair = (tokens[i], tokens[i+1])
                    pairs[pair] += 1
            if not pairs:               # 无法继续合并(所有 pair 唯一)
                break
            # 2. 选择最高频的相邻对
            best_pair = max(pairs, key=pairs.get)
            merged_token = ''.join(best_pair)
            vocab.append(merged_token)
            # 3. 更新所有单词的 token 表示:合并 best_pair
            new_word_splits = []
            for tokens in word_splits:
                new_word_splits.append(self._merge_pair(tokens, best_pair))
            word_splits = new_word_splits
        return vocab

    def _merge_pair(self, tokens, pair):
        """将 tokens 中所有出现的 pair 合并成一个 token"""
        merged = []
        i = 0
        while i < len(tokens):
            if i < len(tokens)-1 and tokens[i]==pair[0] and tokens[i+1]==pair[1]:
                merged.append(''.join(pair))
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        return merged

    def encode(self, text):
        """将文本编码为 token ID 列表(贪心最长匹配)"""
        if not text:
            return []
        # 加词边界标记,并拆成字符
        tokens = ['_'] + list(text)
        changed = True
        while changed:
            changed = False
            new_tokens = []
            i = 0
            while i < len(tokens):
                found = False
                # 从最长长度(最多10)开始匹配
                for length in range(min(len(tokens)-i, 10), 0, -1):
                    cand = ''.join(tokens[i:i+length])
                    if cand in self.token_to_id:
                        new_tokens.append(cand)
                        i += length
                        found = True
                        if length > 1:
                            changed = True
                        break
                if not found:   # 理论上不会发生,但兜底
                    new_tokens.append('<unk>')
                    i += 1
            tokens = new_tokens
        return [self.token_to_id.get(t, self.token_to_id['<unk>']) for t in tokens]

1.4 测试代码

python 复制代码
if __name__ == "__main__":
    corpus = "hello world hello there world of code code hello"
    bpe = SimpleBPE(corpus, vocab_size=20)
    print("词表大小:", len(bpe.vocab))
    print("前15个词:", bpe.vocab[:15])
    encoded = bpe.encode("hello world")
    print("编码:", encoded)
    print("解码:", [bpe.id_to_token[i] for i in encoded])

输出示例

复制代码
词表大小: 20
前15个词: ['<pad>', '<unk>', '_', ' ', 'c', 'd', 'e', 'h', 'l', 'o', 'r', 't', 'w', 'co', 'de', ...]
编码: [5, 3, 6, 3, 12, 3, ...]
解码: ['h', '_', 'e', '_', 'l', ...]

1.5 复杂度分析

操作 时间复杂度 空间复杂度
训练(合并循环) O(V × N × L)(V 目标词表增量,N 单词数,L 平均单词长度) O(N × L)
编码(贪心匹配) O(T × M)(T 文本长度,M 最长匹配长度,通常 ≤10) O(T)

二、Embedding:离散 ID → 连续向量

2.1 原理解析

嵌入层本质是一个可训练的查找表(nn.Embedding(vocab_size, d_model))。输入 ID 序列 (B, T) 中每个 ID 作为索引,从表中取出对应行向量,输出 (B, T, d_model)

缩放因子 :代码中乘以 √d_model,这是 Transformer 原论文中的常见做法,用于稳定梯度(因为后续层输出方差会受 embedding 尺度影响)。

2.2 代码

python 复制代码
class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # x: (batch, seq_len)
        out = self.embedding(x) * math.sqrt(self.d_model)
        return out

2.3 测试

python 复制代码
if __name__ == "__main__":
    vocab_size = 50
    d_model = 64
    emb = Embedding(vocab_size, d_model)
    sample_input = torch.tensor([[1, 2, 3, 4, 0, 4]])   # (1, 6)
    output = emb(sample_input)
    print("Embedding 输出形状:", output.shape)          # (1, 6, 64)
    print("输出值范围:", output.min().item(), output.max().item())

2.4 复杂度分析

  • 时间复杂度:O(B × T)(查表操作,实际为常数)
  • 空间复杂度:O(vocab_size × d_model)(参数量)

三、PositionalEncoding:注入位置信息

3.1 为什么需要位置编码?

Transformer 的自注意力机制是"无序"的,即打乱输入 token 的顺序后,输出也会被打乱。为了让模型感知序列顺序,必须在输入向量中加入位置编码。

3.2 正弦/余弦位置编码公式

论文中使用如下公式:

PE(pos,2i)=sin⁡(pos100002i/dmodel) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)
PE(pos,2i+1)=cos⁡(pos100002i/dmodel) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)

  • pos:位置索引(0~max_len-1)
  • i:维度索引(0~d_model/2-1)
  • 每个维度的波长呈几何级数变化,使得不同位置的编码可区分。

3.3 代码详解

python 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # 预计算位置编码矩阵 (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # 偶数索引
        pe[:, 1::2] = torch.cos(position * div_term)   # 奇数索引
        pe = pe.unsqueeze(1)   # (1, max_len, d_model)
        self.register_buffer('pe', pe)   # 不参与训练,但随模型一起保存

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]

关键点

  • register_bufferpe 注册为模型的一部分(不更新梯度,但会保存到 state_dict)。
  • pe[:, :x.size(1)] 只取前 seq_len 个位置编码,支持变长序列。

3.4 测试

python 复制代码
if __name__ == "__main__":
    d_model = 64
    pos_enc = PositionalEncoding(d_model)
    x = torch.randn(2, 10, d_model)
    out = pos_enc(x)
    print("位置编码输出形状:", out.shape)   # (2, 10, 64)

3.5 复杂度分析

  • 时间复杂度:O(1)(直接相加)
  • 空间复杂度:O(max_len × d_model)(预计算存储)

四、三个组件的串联流程

css 复制代码
graph LR
    A[原始文本] --> B[SimpleBPE.encode]
    B --> C[Token IDs (B,T)]
    C --> D[Embedding]
    D --> E[Embedded Vectors (B,T,D)]
    E --> F[PositionalEncoding]
    F --> G[最终输入 (B,T,D)]
    G --> H[Transformer 模型]
 

示例代码(完整串联):

python 复制代码
corpus = "hello world hello there world of code code hello"
bpe = SimpleBPE(corpus, vocab_size=20)
text = "hello world"
ids = bpe.encode(text)                     # [id1, id2, ...]
input_tensor = torch.tensor([ids])         # (1, seq_len)
embedding_layer = Embedding(vocab_size=20, d_model=64)
pos_enc_layer = PositionalEncoding(d_model=64)

embedded = embedding_layer(input_tensor)   # (1, T, 64)
final_input = pos_enc_layer(embedded)      # (1, T, 64)
print(final_input.shape)                   # torch.Size([1, T, 64])

五、总结

组件 核心作用 是否可训练 关键参数
SimpleBPE 文本 → 子词 ID 否(词表固定) 目标词表大小
Embedding ID → 稠密向量 vocab_size, d_model
PositionalEncoding 注入位置信息 d_model, max_len

这三个组件构成了 Transformer 模型输入的标准预处理流水线。理解它们,你就能轻松处理任意文本数据,并送入大模型进行训练或推理。

#BPE #Embedding #位置编码 #Transformer #PyTorch #学习记录

相关推荐
石榴树下的七彩鱼7 小时前
图片去水印 API 详解:从单图到批量自动化去水印(附 Python/JS/PHP 完整教程)
python·自动化·图片处理·图片去水印·石榴智能·api教程
Li emily13 小时前
解决了加密货币api多币种订阅时的数据乱序问题
人工智能·python·api·fastapi
2301_7815714213 小时前
Golang格式化输出占位符都有什么_Golang fmt占位符教程【通俗】
jvm·数据库·python
asdzx6713 小时前
使用 Python 为 PDF 添加页码 (详细教程)
python·pdf·页码
AI技术控13 小时前
《Transformers are Inherently Succinct》论文解读:从“能表达什么”到“多紧凑地表达”
人工智能·python·深度学习·机器学习·自然语言处理
金融大 k16 小时前
Python 全球指数监控面板:TickDB + REST + WebSocket 完整方案
python·websocket
啊哈哈1213816 小时前
系统设计复盘:为什么 Agent 的 ReAct 循环必须内嵌确定性保护层——以 FitMind 健康助手的路由与步骤控制为例
人工智能·python·react
一颗牙牙17 小时前
安装mmcv
开发语言·python·深度学习
大数据魔法师17 小时前
Streamlit(二)- Streamlit 架构与运行机制
python·web