1. 🔍 代码解析
python
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size,
config.hidden_size,
self.padding_idx
)
1. self.padding_idx = config.pad_token_id
- 含义:指定 padding token 的 id。
- 用途 :embedding 层遇到 padding token(一般用
<pad>)时,会自动把该位的 embedding 输出设为 全零 (因为 PyTorch 的nn.Embedding会自动将padding_idx对应向量固定为 0 且不会更新)。
2. self.vocab_size = config.vocab_size
- 保存词表大小,比如 32000、50257 等。
- embedding 的 lookup 范围就是
[0, vocab_size - 1]。
3. nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
python
nn.Embedding(
num_embeddings=vocab_size, # 字典大小,每个 token 对应一行 embedding
embedding_dim=hidden_size, # 每个 token 的向量维度
padding_idx=pad_token_id # padding 的 embedding 会被固定为 0
)
为了简单,我把参数设很小:
batch_size = 2seq_len = 3vocab_size = 5hidden_size = 4
也就是:
Embedding 的权重矩阵 weight 是一个形状为:
(vocab_size=5, hidden_size=4)
的查表:
token_id embedding vector(4 dim)
-------------------------------------
0 [0.1, 0.2, 0.3, 0.4]
1 [0.5, 0.6, 0.7, 0.8]
2 [0.9, 1.0, 1.1, 1.2]
3 [1.3, 1.4, 1.5, 1.6]
4 [1.7, 1.8, 1.9, 2.0]
✔️ 输入 Tensor
假设输入 token id 是:
input_ids = [
[2, 1, 0], # 第一个 batch 的序列
[3, 3, 4] # 第二个 batch 的序列
]
形状为 (2, 3) → (batch_size, seq_len)
✔️ Embedding 的查表过程(逐元素)
对第一个样本:
输入 [2, 1, 0]
| token id | embedding(从 weight 查表) |
|---|---|
| 2 | [0.9, 1.0, 1.1, 1.2] |
| 1 | [0.5, 0.6, 0.7, 0.8] |
| 0 | [0.1, 0.2, 0.3, 0.4] |
So 这个序列变成:
[
[0.9, 1.0, 1.1, 1.2],
[0.5, 0.6, 0.7, 0.8],
[0.1, 0.2, 0.3, 0.4]
]
形状:(3, 4) → (seq_len, hidden_size)
对第二个样本:
输入 [3, 3, 4]:
| token id | embedding |
|---|---|
| 3 | [1.3, 1.4, 1.5, 1.6] |
| 3 | [1.3, 1.4, 1.5, 1.6] |
| 4 | [1.7, 1.8, 1.9, 2.0] |
序列变成:
[
[1.3, 1.4, 1.5, 1.6],
[1.3, 1.4, 1.5, 1.6],
[1.7, 1.8, 1.9, 2.0]
]
形状也是 (3, 4)。
🔚 最终输出
把两个样本堆叠:
[
[ # batch 0
[0.9, 1.0, 1.1, 1.2],
[0.5, 0.6, 0.7, 0.8],
[0.1, 0.2, 0.3, 0.4]
],
[ # batch 1
[1.3, 1.4, 1.5, 1.6],
[1.3, 1.4, 1.5, 1.6],
[1.7, 1.8, 1.9, 2.0]
]
]
输出形状:
(2, 3, 4)
= (batch_size, seq_len, hidden_size)
✔️ PyTorch 代码版(完全一致)
python
import torch
import torch.nn as nn
embed = nn.Embedding(5, 4) # vocab=5, hidden=4
# 手动设置 weight,使其跟上面的示例一样好理解
embed.weight = nn.Parameter(torch.tensor([
[0.1, 0.2, 0.3, 0.4], # token 0
[0.5, 0.6, 0.7, 0.8], # token 1
[0.9, 1.0, 1.1, 1.2], # token 2
[1.3, 1.4, 1.5, 1.6], # token 3
[1.7, 1.8, 1.9, 2.0], # token 4
]))
input_ids = torch.tensor([
[2, 1, 0],
[3, 3, 4]
])
out = embed(input_ids)
print(out.shape)
print(out)
输出:
torch.Size([2, 3, 4])