示例回顾
python
import torch
import torch.nn as nn
# 定义嵌入字典的大小和嵌入维度
num_embeddings = 10
embedding_dim = 3
# 创建一个 nn.EmbeddingBag 实例
embedding_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='mean')
# 定义输入索引和偏移量
input_indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
# 计算嵌入并进行归约
output = embedding_bag(input_indices, offsets)
print("EmbeddingBag output:")
print(output)
解释
-
嵌入字典:
num_embeddings = 10
表示嵌入字典的大小,即词汇表的大小。embedding_dim = 3
表示每个嵌入向量的维度。
-
输入索引和偏移量:
input_indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
是输入的索引张量,表示需要嵌入的词汇索引。offsets = torch.tensor([0, 4])
是偏移量张量,表示每个序列的起始位置。
-
嵌入向量:
nn.EmbeddingBag
会根据input_indices
从嵌入字典中查找对应的嵌入向量。
计算过程
假设嵌入字典中的嵌入向量如下(随机初始化):
embedding_matrix = [
[0.1, 0.2, 0.3], # index 0
[0.4, 0.5, 0.6], # index 1
[0.7, 0.8, 0.9], # index 2
[1.0, 1.1, 1.2], # index 3
[1.3, 1.4, 1.5], # index 4
[1.6, 1.7, 1.8], # index 5
[1.9, 2.0, 2.1], # index 6
[2.2, 2.3, 2.4], # index 7
[2.5, 2.6, 2.7], # index 8
[2.8, 2.9, 3.0] # index 9
]
计算步骤
-
查找嵌入向量 - 对于
input_indices = [1, 2, 4, 5, 4, 3, 2, 9]
,查找对应的嵌入向量:[ [0.4, 0.5, 0.6], # index 1 [0.7, 0.8, 0.9], # index 2 [1.3, 1.4, 1.5], # index 4 [1.6, 1.7, 1.8], # index 5 [1.3, 1.4, 1.5], # index 4 [1.0, 1.1, 1.2], # index 3 [0.7, 0.8, 0.9], # index 2 [2.8, 2.9, 3.0] # index 9 ]
-
应用偏移量:
offsets = [0, 4]
表示两个序列的起始位置:-
第一个序列:
input_indices[0:4]
对应的嵌入向量:[ [0.4, 0.5, 0.6], # index 1 [0.7, 0.8, 0.9], # index 2 [1.3, 1.4, 1.5], # index 4 [1.6, 1.7, 1.8] # index 5 ]
-
第二个序列:
input_indices[4:8]
对应的嵌入向量:[ [1.3, 1.4, 1.5], # index 4 [1.0, 1.1, 1.2], # index 3 [0.7, 0.8, 0.9], # index 2 [2.8, 2.9, 3.0] # index 9 ]
-
-
计算平均值:
- 对每个序列的嵌入向量进行平均计算:
-
第一个序列的平均值:
mean([ [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.3, 1.4, 1.5], [1.6, 1.7, 1.8] ]) = [1.0, 1.1, 1.2]
-
第二个序列的平均值:
mean([ [1.3, 1.4, 1.5], [1.0, 1.1, 1.2], [0.7, 0.8, 0.9], [2.8, 2.9, 3.0] ]) = [1.45, 1.55, 1.65]
-
- 对每个序列的嵌入向量进行平均计算:
-
输出结果:
-
最终输出的嵌入向量为:
[ [1.0, 1.1, 1.2], [1.45, 1.55, 1.65] ]
-
总结
在 nn.EmbeddingBag
中,mean
模式会对输入索引对应的嵌入向量进行平均计算。具体步骤如下:
- 根据输入索引查找对应的嵌入向量。
- 根据偏移量将输入索引分成多个序列。
- 对每个序列的嵌入向量进行平均计算。
- 输出归约后的嵌入向量。
通过这种方式,nn.EmbeddingBag
可以高效地处理变长序列的嵌入操作,并进行归约计算。