Embedding查表操作

1. 🔍 代码解析

python 复制代码
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(
    config.vocab_size,
    config.hidden_size,
    self.padding_idx
)

1. self.padding_idx = config.pad_token_id

  • 含义:指定 padding token 的 id。
  • 用途 :embedding 层遇到 padding token(一般用 <pad>)时,会自动把该位的 embedding 输出设为 全零 (因为 PyTorch 的 nn.Embedding 会自动将 padding_idx 对应向量固定为 0 且不会更新)。

2. self.vocab_size = config.vocab_size

  • 保存词表大小,比如 32000、50257 等。
  • embedding 的 lookup 范围就是 [0, vocab_size - 1]

3. nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)

python 复制代码
nn.Embedding(
    num_embeddings=vocab_size,      # 字典大小,每个 token 对应一行 embedding
    embedding_dim=hidden_size,      # 每个 token 的向量维度
    padding_idx=pad_token_id        # padding 的 embedding 会被固定为 0
)

为了简单,我把参数设很小:

  • batch_size = 2
  • seq_len = 3
  • vocab_size = 5
  • hidden_size = 4

也就是:

Embedding 的权重矩阵 weight 是一个形状为:

复制代码
(vocab_size=5, hidden_size=4)

的查表:

复制代码
token_id   embedding vector(4 dim)
-------------------------------------
0        [0.1, 0.2, 0.3, 0.4]
1        [0.5, 0.6, 0.7, 0.8]
2        [0.9, 1.0, 1.1, 1.2]
3        [1.3, 1.4, 1.5, 1.6]
4        [1.7, 1.8, 1.9, 2.0]

✔️ 输入 Tensor

假设输入 token id 是:

复制代码
input_ids = [
  [2, 1, 0],   # 第一个 batch 的序列
  [3, 3, 4]    # 第二个 batch 的序列
]

形状为 (2, 3)(batch_size, seq_len)


✔️ Embedding 的查表过程(逐元素)

对第一个样本:

输入 [2, 1, 0]

token id embedding(从 weight 查表)
2 [0.9, 1.0, 1.1, 1.2]
1 [0.5, 0.6, 0.7, 0.8]
0 [0.1, 0.2, 0.3, 0.4]

So 这个序列变成:

复制代码
[
  [0.9, 1.0, 1.1, 1.2],
  [0.5, 0.6, 0.7, 0.8],
  [0.1, 0.2, 0.3, 0.4]
]

形状:(3, 4)(seq_len, hidden_size)


对第二个样本:

输入 [3, 3, 4]

token id embedding
3 [1.3, 1.4, 1.5, 1.6]
3 [1.3, 1.4, 1.5, 1.6]
4 [1.7, 1.8, 1.9, 2.0]

序列变成:

复制代码
[
  [1.3, 1.4, 1.5, 1.6],
  [1.3, 1.4, 1.5, 1.6],
  [1.7, 1.8, 1.9, 2.0]
]

形状也是 (3, 4)


🔚 最终输出

把两个样本堆叠:

复制代码
[
  [  # batch 0
    [0.9, 1.0, 1.1, 1.2],
    [0.5, 0.6, 0.7, 0.8],
    [0.1, 0.2, 0.3, 0.4]
  ],

  [  # batch 1
    [1.3, 1.4, 1.5, 1.6],
    [1.3, 1.4, 1.5, 1.6],
    [1.7, 1.8, 1.9, 2.0]
  ]
]

输出形状:

复制代码
(2, 3, 4)
 = (batch_size, seq_len, hidden_size)

✔️ PyTorch 代码版(完全一致)

python 复制代码
import torch
import torch.nn as nn

embed = nn.Embedding(5, 4)  # vocab=5, hidden=4

# 手动设置 weight,使其跟上面的示例一样好理解
embed.weight = nn.Parameter(torch.tensor([
    [0.1, 0.2, 0.3, 0.4],  # token 0
    [0.5, 0.6, 0.7, 0.8],  # token 1
    [0.9, 1.0, 1.1, 1.2],  # token 2
    [1.3, 1.4, 1.5, 1.6],  # token 3
    [1.7, 1.8, 1.9, 2.0],  # token 4
]))

input_ids = torch.tensor([
    [2, 1, 0],
    [3, 3, 4]
])

out = embed(input_ids)
print(out.shape)
print(out)

输出:

复制代码
torch.Size([2, 3, 4])

相关推荐
小黎14757789853643 小时前
OpenClaw 连接飞书完整指南:插件安装、配置与踩坑记录
机器学习
IVEN_5 小时前
只会Python皮毛?深入理解这几点,轻松进阶全栈开发
python·全栈
哥布林学者6 小时前
高光谱成像(二)光谱角映射 SAM
机器学习·高光谱成像
Ray Liang6 小时前
用六边形架构与整洁架构对比是伪命题?
java·python·c#·架构设计
AI攻城狮6 小时前
如何给 AI Agent 做"断舍离":OpenClaw Session 自动清理实践
python
千寻girling6 小时前
一份不可多得的 《 Python 》语言教程
人工智能·后端·python
AI攻城狮9 小时前
用 Playwright 实现博客一键发布到稀土掘金
python·自动化运维
曲幽10 小时前
FastAPI分布式系统实战:拆解分布式系统中常见问题及解决方案
redis·python·fastapi·web·httpx·lock·asyncio
哥布林学者1 天前
高光谱成像(一)高光谱图像
机器学习·高光谱成像