Embedding查表操作

1. 🔍 代码解析

python 复制代码
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(
    config.vocab_size,
    config.hidden_size,
    self.padding_idx
)

1. self.padding_idx = config.pad_token_id

  • 含义:指定 padding token 的 id。
  • 用途 :embedding 层遇到 padding token(一般用 <pad>)时,会自动把该位的 embedding 输出设为 全零 (因为 PyTorch 的 nn.Embedding 会自动将 padding_idx 对应向量固定为 0 且不会更新)。

2. self.vocab_size = config.vocab_size

  • 保存词表大小,比如 32000、50257 等。
  • embedding 的 lookup 范围就是 [0, vocab_size - 1]

3. nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)

python 复制代码
nn.Embedding(
    num_embeddings=vocab_size,      # 字典大小,每个 token 对应一行 embedding
    embedding_dim=hidden_size,      # 每个 token 的向量维度
    padding_idx=pad_token_id        # padding 的 embedding 会被固定为 0
)

为了简单,我把参数设很小:

  • batch_size = 2
  • seq_len = 3
  • vocab_size = 5
  • hidden_size = 4

也就是:

Embedding 的权重矩阵 weight 是一个形状为:

复制代码
(vocab_size=5, hidden_size=4)

的查表:

复制代码
token_id   embedding vector(4 dim)
-------------------------------------
0        [0.1, 0.2, 0.3, 0.4]
1        [0.5, 0.6, 0.7, 0.8]
2        [0.9, 1.0, 1.1, 1.2]
3        [1.3, 1.4, 1.5, 1.6]
4        [1.7, 1.8, 1.9, 2.0]

✔️ 输入 Tensor

假设输入 token id 是:

复制代码
input_ids = [
  [2, 1, 0],   # 第一个 batch 的序列
  [3, 3, 4]    # 第二个 batch 的序列
]

形状为 (2, 3)(batch_size, seq_len)


✔️ Embedding 的查表过程(逐元素)

对第一个样本:

输入 [2, 1, 0]

token id embedding(从 weight 查表)
2 [0.9, 1.0, 1.1, 1.2]
1 [0.5, 0.6, 0.7, 0.8]
0 [0.1, 0.2, 0.3, 0.4]

So 这个序列变成:

复制代码
[
  [0.9, 1.0, 1.1, 1.2],
  [0.5, 0.6, 0.7, 0.8],
  [0.1, 0.2, 0.3, 0.4]
]

形状:(3, 4)(seq_len, hidden_size)


对第二个样本:

输入 [3, 3, 4]

token id embedding
3 [1.3, 1.4, 1.5, 1.6]
3 [1.3, 1.4, 1.5, 1.6]
4 [1.7, 1.8, 1.9, 2.0]

序列变成:

复制代码
[
  [1.3, 1.4, 1.5, 1.6],
  [1.3, 1.4, 1.5, 1.6],
  [1.7, 1.8, 1.9, 2.0]
]

形状也是 (3, 4)


🔚 最终输出

把两个样本堆叠:

复制代码
[
  [  # batch 0
    [0.9, 1.0, 1.1, 1.2],
    [0.5, 0.6, 0.7, 0.8],
    [0.1, 0.2, 0.3, 0.4]
  ],

  [  # batch 1
    [1.3, 1.4, 1.5, 1.6],
    [1.3, 1.4, 1.5, 1.6],
    [1.7, 1.8, 1.9, 2.0]
  ]
]

输出形状:

复制代码
(2, 3, 4)
 = (batch_size, seq_len, hidden_size)

✔️ PyTorch 代码版(完全一致)

python 复制代码
import torch
import torch.nn as nn

embed = nn.Embedding(5, 4)  # vocab=5, hidden=4

# 手动设置 weight,使其跟上面的示例一样好理解
embed.weight = nn.Parameter(torch.tensor([
    [0.1, 0.2, 0.3, 0.4],  # token 0
    [0.5, 0.6, 0.7, 0.8],  # token 1
    [0.9, 1.0, 1.1, 1.2],  # token 2
    [1.3, 1.4, 1.5, 1.6],  # token 3
    [1.7, 1.8, 1.9, 2.0],  # token 4
]))

input_ids = torch.tensor([
    [2, 1, 0],
    [3, 3, 4]
])

out = embed(input_ids)
print(out.shape)
print(out)

输出:

复制代码
torch.Size([2, 3, 4])

相关推荐
audyxiao0012 分钟前
智能交通顶刊TITS论文分享|如何利用驾驶感知世界模型实现无信号灯路口自动驾驶?
人工智能·机器学习·自动驾驶·tits
90的程序爱好者3 分钟前
Flask 用户注册功能实现
python·flask
lisw057 分钟前
氛围炒股概述!
大数据·人工智能·机器学习
张3蜂2 小时前
Gunicorn深度解析:Python WSGI服务器的王者
服务器·python·gunicorn
Godspeed Zhao6 小时前
自动驾驶中的传感器技术24.3——Camera(18)
人工智能·机器学习·自动驾驶
rayufo8 小时前
【工具】列出指定文件夹下所有的目录和文件
开发语言·前端·python
Python 老手9 小时前
Python while 循环 极简核心讲解
java·python·算法
开源技术10 小时前
如何将本地LLM模型与Ollama和Python集成
开发语言·python
weixin_4370446410 小时前
Netbox批量添加设备——堆叠设备
linux·网络·python
我有医保我先冲10 小时前
AI 时代 “任务完成“ 与 “专业能力“ 的区分:理论基础、行业影响与个人发展策略
人工智能·python·机器学习