Embedding查表操作

1. 🔍 代码解析

python 复制代码
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(
    config.vocab_size,
    config.hidden_size,
    self.padding_idx
)

1. self.padding_idx = config.pad_token_id

  • 含义:指定 padding token 的 id。
  • 用途 :embedding 层遇到 padding token(一般用 <pad>)时,会自动把该位的 embedding 输出设为 全零 (因为 PyTorch 的 nn.Embedding 会自动将 padding_idx 对应向量固定为 0 且不会更新)。

2. self.vocab_size = config.vocab_size

  • 保存词表大小,比如 32000、50257 等。
  • embedding 的 lookup 范围就是 [0, vocab_size - 1]

3. nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)

python 复制代码
nn.Embedding(
    num_embeddings=vocab_size,      # 字典大小,每个 token 对应一行 embedding
    embedding_dim=hidden_size,      # 每个 token 的向量维度
    padding_idx=pad_token_id        # padding 的 embedding 会被固定为 0
)

为了简单,我把参数设很小:

  • batch_size = 2
  • seq_len = 3
  • vocab_size = 5
  • hidden_size = 4

也就是:

Embedding 的权重矩阵 weight 是一个形状为:

复制代码
(vocab_size=5, hidden_size=4)

的查表:

复制代码
token_id   embedding vector(4 dim)
-------------------------------------
0        [0.1, 0.2, 0.3, 0.4]
1        [0.5, 0.6, 0.7, 0.8]
2        [0.9, 1.0, 1.1, 1.2]
3        [1.3, 1.4, 1.5, 1.6]
4        [1.7, 1.8, 1.9, 2.0]

✔️ 输入 Tensor

假设输入 token id 是:

复制代码
input_ids = [
  [2, 1, 0],   # 第一个 batch 的序列
  [3, 3, 4]    # 第二个 batch 的序列
]

形状为 (2, 3)(batch_size, seq_len)


✔️ Embedding 的查表过程(逐元素)

对第一个样本:

输入 [2, 1, 0]

token id embedding(从 weight 查表)
2 [0.9, 1.0, 1.1, 1.2]
1 [0.5, 0.6, 0.7, 0.8]
0 [0.1, 0.2, 0.3, 0.4]

So 这个序列变成:

复制代码
[
  [0.9, 1.0, 1.1, 1.2],
  [0.5, 0.6, 0.7, 0.8],
  [0.1, 0.2, 0.3, 0.4]
]

形状:(3, 4)(seq_len, hidden_size)


对第二个样本:

输入 [3, 3, 4]

token id embedding
3 [1.3, 1.4, 1.5, 1.6]
3 [1.3, 1.4, 1.5, 1.6]
4 [1.7, 1.8, 1.9, 2.0]

序列变成:

复制代码
[
  [1.3, 1.4, 1.5, 1.6],
  [1.3, 1.4, 1.5, 1.6],
  [1.7, 1.8, 1.9, 2.0]
]

形状也是 (3, 4)


🔚 最终输出

把两个样本堆叠:

复制代码
[
  [  # batch 0
    [0.9, 1.0, 1.1, 1.2],
    [0.5, 0.6, 0.7, 0.8],
    [0.1, 0.2, 0.3, 0.4]
  ],

  [  # batch 1
    [1.3, 1.4, 1.5, 1.6],
    [1.3, 1.4, 1.5, 1.6],
    [1.7, 1.8, 1.9, 2.0]
  ]
]

输出形状:

复制代码
(2, 3, 4)
 = (batch_size, seq_len, hidden_size)

✔️ PyTorch 代码版(完全一致)

python 复制代码
import torch
import torch.nn as nn

embed = nn.Embedding(5, 4)  # vocab=5, hidden=4

# 手动设置 weight,使其跟上面的示例一样好理解
embed.weight = nn.Parameter(torch.tensor([
    [0.1, 0.2, 0.3, 0.4],  # token 0
    [0.5, 0.6, 0.7, 0.8],  # token 1
    [0.9, 1.0, 1.1, 1.2],  # token 2
    [1.3, 1.4, 1.5, 1.6],  # token 3
    [1.7, 1.8, 1.9, 2.0],  # token 4
]))

input_ids = torch.tensor([
    [2, 1, 0],
    [3, 3, 4]
])

out = embed(input_ids)
print(out.shape)
print(out)

输出:

复制代码
torch.Size([2, 3, 4])

相关推荐
知乎的哥廷根数学学派42 分钟前
面向可信机械故障诊断的自适应置信度惩罚深度校准算法(Pytorch)
人工智能·pytorch·python·深度学习·算法·机器学习·矩阵
且去填词1 小时前
DeepSeek :基于 Schema 推理与自愈机制的智能 ETL
数据仓库·人工智能·python·语言模型·etl·schema·deepseek
数字化转型20251 小时前
企业数字化架构集成能力建设
大数据·程序人生·机器学习
人工干智能1 小时前
OpenAI Assistants API 中 client.beta.threads.messages.create方法,兼谈一星*和两星**解包
python·llm
databook1 小时前
当条形图遇上极坐标:径向与圆形条形图的视觉革命
python·数据分析·数据可视化
阿部多瑞 ABU2 小时前
`chenmo` —— 可编程元叙事引擎 V2.3+
linux·人工智能·python·ai写作
acanab2 小时前
VScode python插件
ide·vscode·python
知乎的哥廷根数学学派3 小时前
基于生成对抗U-Net混合架构的隧道衬砌缺陷地质雷达数据智能反演与成像方法(以模拟信号为例,Pytorch)
开发语言·人工智能·pytorch·python·深度学习·机器学习
WangYaolove13143 小时前
Python基于大数据的电影市场预测分析(源码+文档)
python·django·毕业设计·源码
知乎的哥廷根数学学派3 小时前
基于自适应多尺度小波核编码与注意力增强的脉冲神经网络机械故障诊断(Pytorch)
人工智能·pytorch·python·深度学习·神经网络·机器学习