【学习记录】从 BPE 分词到位置编码:大模型预处理三组件完全解析
本文深入讲解大语言模型训练/推理前的三个核心预处理组件:BPE 分词(SimpleBPE) 、嵌入层(Embedding) 、位置编码(PositionalEncoding)。包含原理图解、代码逐行注释、复杂度分析以及完整测试用例。理解这些组件,你就掌握了现代 Transformer 模型的输入处理全流程。
📌 组件概览
| 组件 | 输入 | 输出 | 作用 |
|---|---|---|---|
| SimpleBPE | 原始语料字符串 + 目标词表大小 | 词表、编码/解码函数 | 将文本切分为子词 ID 序列 |
| Embedding | ID 序列 (B, T) |
嵌入向量 (B, T, D) |
将离散 ID 映射为连续向量 |
| PositionalEncoding | 嵌入向量 (B, T, D) |
相加后的向量 (B, T, D) |
注入位置信息(因为自注意力本身无序) |
一、SimpleBPE:子词分词器
1.1 为什么需要 BPE?
- 词级分词:词表过大(几十万),且无法处理未登录词(OOV)。
- 字符级分词:序列过长,信息密度低,语义难以捕捉。
- 子词级分词(BPE):平衡词表和长度,常见词保留完整,罕见词拆分为子词。
1.2 BPE 训练算法(Byte Pair Encoding)
- 初始化词表:所有单个字符 + 特殊标记(
<pad>,<unk>,_词边界)。 - 统计所有相邻 token 对的出现频率。
- 合并频率最高的对,加入词表。
- 重复步骤 2-3 直到词表达到目标大小。
1.3 代码逐行详解
python
from collections import defaultdict
class SimpleBPE:
def __init__(self, corpus, vocab_size=50):
self.vocab = self.train_bpe(corpus, vocab_size)
self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
self.id_to_token = {idx: token for idx, token in enumerate(self.vocab)}
def train_bpe(self, corpus, target_size):
words = corpus.split() # 按空格切分为单词列表
# 初始化词表:特殊标记 + 词边界标记 '_' + 所有出现过的字符
vocab = ['<pad>', '<unk>', '_']
base_chars = sorted(list(set(''.join(words))))
vocab.extend(base_chars)
# 将每个单词表示为字符列表,并在前面加上词边界标记 '_'
word_splits = []
for word in words:
word_splits.append(['_'] + list(word))
# 循环合并,直到达到目标词表大小
while len(vocab) < target_size:
# 1. 统计所有相邻 token 对的出现频率
pairs = defaultdict(int)
for tokens in word_splits:
for i in range(len(tokens) - 1):
pair = (tokens[i], tokens[i+1])
pairs[pair] += 1
if not pairs: # 无法继续合并(所有 pair 唯一)
break
# 2. 选择最高频的相邻对
best_pair = max(pairs, key=pairs.get)
merged_token = ''.join(best_pair)
vocab.append(merged_token)
# 3. 更新所有单词的 token 表示:合并 best_pair
new_word_splits = []
for tokens in word_splits:
new_word_splits.append(self._merge_pair(tokens, best_pair))
word_splits = new_word_splits
return vocab
def _merge_pair(self, tokens, pair):
"""将 tokens 中所有出现的 pair 合并成一个 token"""
merged = []
i = 0
while i < len(tokens):
if i < len(tokens)-1 and tokens[i]==pair[0] and tokens[i+1]==pair[1]:
merged.append(''.join(pair))
i += 2
else:
merged.append(tokens[i])
i += 1
return merged
def encode(self, text):
"""将文本编码为 token ID 列表(贪心最长匹配)"""
if not text:
return []
# 加词边界标记,并拆成字符
tokens = ['_'] + list(text)
changed = True
while changed:
changed = False
new_tokens = []
i = 0
while i < len(tokens):
found = False
# 从最长长度(最多10)开始匹配
for length in range(min(len(tokens)-i, 10), 0, -1):
cand = ''.join(tokens[i:i+length])
if cand in self.token_to_id:
new_tokens.append(cand)
i += length
found = True
if length > 1:
changed = True
break
if not found: # 理论上不会发生,但兜底
new_tokens.append('<unk>')
i += 1
tokens = new_tokens
return [self.token_to_id.get(t, self.token_to_id['<unk>']) for t in tokens]
1.4 测试代码
python
if __name__ == "__main__":
corpus = "hello world hello there world of code code hello"
bpe = SimpleBPE(corpus, vocab_size=20)
print("词表大小:", len(bpe.vocab))
print("前15个词:", bpe.vocab[:15])
encoded = bpe.encode("hello world")
print("编码:", encoded)
print("解码:", [bpe.id_to_token[i] for i in encoded])
输出示例:
词表大小: 20
前15个词: ['<pad>', '<unk>', '_', ' ', 'c', 'd', 'e', 'h', 'l', 'o', 'r', 't', 'w', 'co', 'de', ...]
编码: [5, 3, 6, 3, 12, 3, ...]
解码: ['h', '_', 'e', '_', 'l', ...]
1.5 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 训练(合并循环) | O(V × N × L)(V 目标词表增量,N 单词数,L 平均单词长度) | O(N × L) |
| 编码(贪心匹配) | O(T × M)(T 文本长度,M 最长匹配长度,通常 ≤10) | O(T) |
二、Embedding:离散 ID → 连续向量
2.1 原理解析
嵌入层本质是一个可训练的查找表(nn.Embedding(vocab_size, d_model))。输入 ID 序列 (B, T) 中每个 ID 作为索引,从表中取出对应行向量,输出 (B, T, d_model)。
缩放因子 :代码中乘以 √d_model,这是 Transformer 原论文中的常见做法,用于稳定梯度(因为后续层输出方差会受 embedding 尺度影响)。
2.2 代码
python
class Embedding(nn.Module):
def __init__(self, vocab_size, d_model):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.d_model = d_model
def forward(self, x):
# x: (batch, seq_len)
out = self.embedding(x) * math.sqrt(self.d_model)
return out
2.3 测试
python
if __name__ == "__main__":
vocab_size = 50
d_model = 64
emb = Embedding(vocab_size, d_model)
sample_input = torch.tensor([[1, 2, 3, 4, 0, 4]]) # (1, 6)
output = emb(sample_input)
print("Embedding 输出形状:", output.shape) # (1, 6, 64)
print("输出值范围:", output.min().item(), output.max().item())
2.4 复杂度分析
- 时间复杂度:O(B × T)(查表操作,实际为常数)
- 空间复杂度:O(vocab_size × d_model)(参数量)
三、PositionalEncoding:注入位置信息
3.1 为什么需要位置编码?
Transformer 的自注意力机制是"无序"的,即打乱输入 token 的顺序后,输出也会被打乱。为了让模型感知序列顺序,必须在输入向量中加入位置编码。
3.2 正弦/余弦位置编码公式
论文中使用如下公式:
PE(pos,2i)=sin(pos100002i/dmodel) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)
PE(pos,2i+1)=cos(pos100002i/dmodel) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)
pos:位置索引(0~max_len-1)i:维度索引(0~d_model/2-1)- 每个维度的波长呈几何级数变化,使得不同位置的编码可区分。
3.3 代码详解
python
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 预计算位置编码矩阵 (max_len, d_model)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数索引
pe[:, 1::2] = torch.cos(position * div_term) # 奇数索引
pe = pe.unsqueeze(1) # (1, max_len, d_model)
self.register_buffer('pe', pe) # 不参与训练,但随模型一起保存
def forward(self, x):
# x: (batch, seq_len, d_model)
return x + self.pe[:, :x.size(1), :]
关键点:
register_buffer将pe注册为模型的一部分(不更新梯度,但会保存到 state_dict)。pe[:, :x.size(1)]只取前seq_len个位置编码,支持变长序列。
3.4 测试
python
if __name__ == "__main__":
d_model = 64
pos_enc = PositionalEncoding(d_model)
x = torch.randn(2, 10, d_model)
out = pos_enc(x)
print("位置编码输出形状:", out.shape) # (2, 10, 64)
3.5 复杂度分析
- 时间复杂度:O(1)(直接相加)
- 空间复杂度:O(max_len × d_model)(预计算存储)
四、三个组件的串联流程
css
graph LR
A[原始文本] --> B[SimpleBPE.encode]
B --> C[Token IDs (B,T)]
C --> D[Embedding]
D --> E[Embedded Vectors (B,T,D)]
E --> F[PositionalEncoding]
F --> G[最终输入 (B,T,D)]
G --> H[Transformer 模型]

示例代码(完整串联):
python
corpus = "hello world hello there world of code code hello"
bpe = SimpleBPE(corpus, vocab_size=20)
text = "hello world"
ids = bpe.encode(text) # [id1, id2, ...]
input_tensor = torch.tensor([ids]) # (1, seq_len)
embedding_layer = Embedding(vocab_size=20, d_model=64)
pos_enc_layer = PositionalEncoding(d_model=64)
embedded = embedding_layer(input_tensor) # (1, T, 64)
final_input = pos_enc_layer(embedded) # (1, T, 64)
print(final_input.shape) # torch.Size([1, T, 64])
五、总结
| 组件 | 核心作用 | 是否可训练 | 关键参数 |
|---|---|---|---|
| SimpleBPE | 文本 → 子词 ID | 否(词表固定) | 目标词表大小 |
| Embedding | ID → 稠密向量 | 是 | vocab_size, d_model |
| PositionalEncoding | 注入位置信息 | 否 | d_model, max_len |
这三个组件构成了 Transformer 模型输入的标准预处理流水线。理解它们,你就能轻松处理任意文本数据,并送入大模型进行训练或推理。
#BPE #Embedding #位置编码 #Transformer #PyTorch #学习记录