Embedding查表操作

1. 🔍 代码解析

python 复制代码
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(
    config.vocab_size,
    config.hidden_size,
    self.padding_idx
)

1. self.padding_idx = config.pad_token_id

  • 含义:指定 padding token 的 id。
  • 用途 :embedding 层遇到 padding token(一般用 <pad>)时,会自动把该位的 embedding 输出设为 全零 (因为 PyTorch 的 nn.Embedding 会自动将 padding_idx 对应向量固定为 0 且不会更新)。

2. self.vocab_size = config.vocab_size

  • 保存词表大小,比如 32000、50257 等。
  • embedding 的 lookup 范围就是 [0, vocab_size - 1]

3. nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)

python 复制代码
nn.Embedding(
    num_embeddings=vocab_size,      # 字典大小,每个 token 对应一行 embedding
    embedding_dim=hidden_size,      # 每个 token 的向量维度
    padding_idx=pad_token_id        # padding 的 embedding 会被固定为 0
)

为了简单,我把参数设很小:

  • batch_size = 2
  • seq_len = 3
  • vocab_size = 5
  • hidden_size = 4

也就是:

Embedding 的权重矩阵 weight 是一个形状为:

复制代码
(vocab_size=5, hidden_size=4)

的查表:

复制代码
token_id   embedding vector(4 dim)
-------------------------------------
0        [0.1, 0.2, 0.3, 0.4]
1        [0.5, 0.6, 0.7, 0.8]
2        [0.9, 1.0, 1.1, 1.2]
3        [1.3, 1.4, 1.5, 1.6]
4        [1.7, 1.8, 1.9, 2.0]

✔️ 输入 Tensor

假设输入 token id 是:

复制代码
input_ids = [
  [2, 1, 0],   # 第一个 batch 的序列
  [3, 3, 4]    # 第二个 batch 的序列
]

形状为 (2, 3)(batch_size, seq_len)


✔️ Embedding 的查表过程(逐元素)

对第一个样本:

输入 [2, 1, 0]

token id embedding(从 weight 查表)
2 [0.9, 1.0, 1.1, 1.2]
1 [0.5, 0.6, 0.7, 0.8]
0 [0.1, 0.2, 0.3, 0.4]

So 这个序列变成:

复制代码
[
  [0.9, 1.0, 1.1, 1.2],
  [0.5, 0.6, 0.7, 0.8],
  [0.1, 0.2, 0.3, 0.4]
]

形状:(3, 4)(seq_len, hidden_size)


对第二个样本:

输入 [3, 3, 4]

token id embedding
3 [1.3, 1.4, 1.5, 1.6]
3 [1.3, 1.4, 1.5, 1.6]
4 [1.7, 1.8, 1.9, 2.0]

序列变成:

复制代码
[
  [1.3, 1.4, 1.5, 1.6],
  [1.3, 1.4, 1.5, 1.6],
  [1.7, 1.8, 1.9, 2.0]
]

形状也是 (3, 4)


🔚 最终输出

把两个样本堆叠:

复制代码
[
  [  # batch 0
    [0.9, 1.0, 1.1, 1.2],
    [0.5, 0.6, 0.7, 0.8],
    [0.1, 0.2, 0.3, 0.4]
  ],

  [  # batch 1
    [1.3, 1.4, 1.5, 1.6],
    [1.3, 1.4, 1.5, 1.6],
    [1.7, 1.8, 1.9, 2.0]
  ]
]

输出形状:

复制代码
(2, 3, 4)
 = (batch_size, seq_len, hidden_size)

✔️ PyTorch 代码版(完全一致)

python 复制代码
import torch
import torch.nn as nn

embed = nn.Embedding(5, 4)  # vocab=5, hidden=4

# 手动设置 weight,使其跟上面的示例一样好理解
embed.weight = nn.Parameter(torch.tensor([
    [0.1, 0.2, 0.3, 0.4],  # token 0
    [0.5, 0.6, 0.7, 0.8],  # token 1
    [0.9, 1.0, 1.1, 1.2],  # token 2
    [1.3, 1.4, 1.5, 1.6],  # token 3
    [1.7, 1.8, 1.9, 2.0],  # token 4
]))

input_ids = torch.tensor([
    [2, 1, 0],
    [3, 3, 4]
])

out = embed(input_ids)
print(out.shape)
print(out)

输出:

复制代码
torch.Size([2, 3, 4])

相关推荐
小狗照亮每一天1 小时前
【菜狗学深度学习】注意力机制手撕——20251201
人工智能·深度学习·机器学习
AI视觉网奇1 小时前
数字人 语音驱动
人工智能·python
宁大小白1 小时前
pythonstudy Day24
人工智能·机器学习
伯远医学1 小时前
CUT&RUN
java·服务器·网络·人工智能·python·算法·eclipse
一晌小贪欢1 小时前
Python-11 Python作用域与闭包:LEGB规则深度解析
开发语言·python·python基础·python小白·python作用域·python小庄
战南诚1 小时前
如何查看正在执行的事务
python·flask·sqlalchemy
丸码1 小时前
JDK1.8新特性全解析
linux·windows·python
@游子1 小时前
Python学习笔记-Day4
笔记·python·学习
艾莉丝努力练剑1 小时前
【Python基础:语法第二课】Python 流程控制详解:条件语句 + 循环语句 + 人生重开模拟器实战
人工智能·爬虫·python·pycharm