nn.Embedding
Embedding其实是构造了一个巨大的张量表,对于输入tensor某个位置的标量,在Embedding表中查表进行赋值:
python
# 伪代码演示
# 输入size: (1, 3133)
# Embedding size: (15536, 2048)
# output = torch.zeros(1,3133,2048)
for batch_idx in range(1):
for seq_idx in range(3133):
# 取出当前位置的Token ID,比如token_id=151656
token_id = input_tensor[batch_idx, seq_idx]
# 在Embedding权重字典里把151656那个长度为2048的向量取出来,直接赋值到输出的对应位置
output[batch_idx, seq_idx, :] = embedding_weight[token_id, :]
从原理上可以看到,input_tensor的每一个值,一定是在[0, Embedding.shape(0)],且是整数