Embedding查表操作

1. 🔍 代码解析

python 复制代码
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(
    config.vocab_size,
    config.hidden_size,
    self.padding_idx
)

1. self.padding_idx = config.pad_token_id

  • 含义:指定 padding token 的 id。
  • 用途 :embedding 层遇到 padding token(一般用 <pad>)时,会自动把该位的 embedding 输出设为 全零 (因为 PyTorch 的 nn.Embedding 会自动将 padding_idx 对应向量固定为 0 且不会更新)。

2. self.vocab_size = config.vocab_size

  • 保存词表大小,比如 32000、50257 等。
  • embedding 的 lookup 范围就是 [0, vocab_size - 1]

3. nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)

python 复制代码
nn.Embedding(
    num_embeddings=vocab_size,      # 字典大小,每个 token 对应一行 embedding
    embedding_dim=hidden_size,      # 每个 token 的向量维度
    padding_idx=pad_token_id        # padding 的 embedding 会被固定为 0
)

为了简单,我把参数设很小:

  • batch_size = 2
  • seq_len = 3
  • vocab_size = 5
  • hidden_size = 4

也就是:

Embedding 的权重矩阵 weight 是一个形状为:

复制代码
(vocab_size=5, hidden_size=4)

的查表:

复制代码
token_id   embedding vector(4 dim)
-------------------------------------
0        [0.1, 0.2, 0.3, 0.4]
1        [0.5, 0.6, 0.7, 0.8]
2        [0.9, 1.0, 1.1, 1.2]
3        [1.3, 1.4, 1.5, 1.6]
4        [1.7, 1.8, 1.9, 2.0]

✔️ 输入 Tensor

假设输入 token id 是:

复制代码
input_ids = [
  [2, 1, 0],   # 第一个 batch 的序列
  [3, 3, 4]    # 第二个 batch 的序列
]

形状为 (2, 3)(batch_size, seq_len)


✔️ Embedding 的查表过程(逐元素)

对第一个样本:

输入 [2, 1, 0]

token id embedding(从 weight 查表)
2 [0.9, 1.0, 1.1, 1.2]
1 [0.5, 0.6, 0.7, 0.8]
0 [0.1, 0.2, 0.3, 0.4]

So 这个序列变成:

复制代码
[
  [0.9, 1.0, 1.1, 1.2],
  [0.5, 0.6, 0.7, 0.8],
  [0.1, 0.2, 0.3, 0.4]
]

形状:(3, 4)(seq_len, hidden_size)


对第二个样本:

输入 [3, 3, 4]

token id embedding
3 [1.3, 1.4, 1.5, 1.6]
3 [1.3, 1.4, 1.5, 1.6]
4 [1.7, 1.8, 1.9, 2.0]

序列变成:

复制代码
[
  [1.3, 1.4, 1.5, 1.6],
  [1.3, 1.4, 1.5, 1.6],
  [1.7, 1.8, 1.9, 2.0]
]

形状也是 (3, 4)


🔚 最终输出

把两个样本堆叠:

复制代码
[
  [  # batch 0
    [0.9, 1.0, 1.1, 1.2],
    [0.5, 0.6, 0.7, 0.8],
    [0.1, 0.2, 0.3, 0.4]
  ],

  [  # batch 1
    [1.3, 1.4, 1.5, 1.6],
    [1.3, 1.4, 1.5, 1.6],
    [1.7, 1.8, 1.9, 2.0]
  ]
]

输出形状:

复制代码
(2, 3, 4)
 = (batch_size, seq_len, hidden_size)

✔️ PyTorch 代码版(完全一致)

python 复制代码
import torch
import torch.nn as nn

embed = nn.Embedding(5, 4)  # vocab=5, hidden=4

# 手动设置 weight,使其跟上面的示例一样好理解
embed.weight = nn.Parameter(torch.tensor([
    [0.1, 0.2, 0.3, 0.4],  # token 0
    [0.5, 0.6, 0.7, 0.8],  # token 1
    [0.9, 1.0, 1.1, 1.2],  # token 2
    [1.3, 1.4, 1.5, 1.6],  # token 3
    [1.7, 1.8, 1.9, 2.0],  # token 4
]))

input_ids = torch.tensor([
    [2, 1, 0],
    [3, 3, 4]
])

out = embed(input_ids)
print(out.shape)
print(out)

输出:

复制代码
torch.Size([2, 3, 4])

相关推荐
iAm_Ike1 小时前
Go 中自定义类型与基础类型间的显式类型转换详解
jvm·数据库·python
iuvtsrt1 小时前
Golang怎么实现方法集与接口的匹配_Golang如何理解值类型和指针类型实现接口的区别【详解】
jvm·数据库·python
旦莫2 小时前
AI驱动的纯视觉自动化测试:知识库里应该积累什么知识内容
人工智能·python·测试开发·pytest·ai测试
知识领航员3 小时前
蘑兔AI音乐深度实测:功能拆解、实测表现与适用场景
java·c语言·c++·人工智能·python·算法·github
如何原谅奋力过但无声4 小时前
【灵神高频面试题合集06-08】反转链表、快慢指针(环形链表/重排链表)、前后指针(删除链表/链表去重)
数据结构·python·算法·leetcode·链表
deephub5 小时前
2026 RAG 选型指南:Vector、Graph、Vectorless 该怎么挑
人工智能·python·大语言模型·rag
狐狐生风7 小时前
使用 UV 创建并运行 Python 项目(完整步骤)
python·uv
噜噜噜阿鲁~7 小时前
python学习笔记 | 9.2、模块-安装第三方模块
笔记·python·学习
现代野蛮人7 小时前
【深度学习】 —— VGG-16 网络实现猫狗识别
网络·人工智能·python·深度学习·tensorflow