假设我们需要一个查找表(Lookup Table),我们可以根据索引数字快速定位查找表中某个具体位置并读取出来。最简单的方法,可以通过一个二维数组或者二维list来实现。但如果我希望查找表的值可以通过梯度反向传播来修改,那么就需要用到nn.Embedding
来实现了。
其实,我们需要用反向传播来修正表值的场景还是很多的,比如我们想存储数据的通用特征时,这个通用特征就可以用nn.Embedding来表示,常见于现在的各种codebook的trick。闲话不多说,我们来看栗子:
python
import torch
from torch import nn
table = nn.Embedding(10, 3)
print(table.weight)
idx = torch.LongTensor([[1]])
b = table(idx)
print(b)
'''
output
Parameter containing:
tensor([[-0.2317, -0.9679, -1.9324],
[ 0.2473, 1.1043, -0.7218],
[ 0.5425, -0.3109, -0.1330],
[-1.4006, -0.0675, 0.1376],
[-0.1995, 0.7168, 0.5692],
[-1.3572, -0.6407, -0.0128],
[-0.0773, 1.1928, -1.0836],
[ 0.1721, -0.9232, -0.4059],
[ 1.6108, -0.4640, 0.3535],
[ 0.6975, 1.6554, -0.2217]], requires_grad=True)
tensor([[[ 0.2473, 1.1043, -0.7218]]], grad_fn=<EmbeddingBackward0>)
'''
这段代码实际上就实现了一个查找表的功能,索引值为[[1]](注意有两个中括弧),返回值为对应的表值。我们还可以批量查找表值:
python
import torch
from torch import nn
table = nn.Embedding(10, 3)
print(table)
print(table.weight)
indices = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
print(indices)
out = table(indices)
print(out)
print(out.shape)
通过输入索引张量来获取表值:[2,4] -> [2,4,3],请注意这个shape变化,即对应位置的索引获得对应位置的表值。
参考:https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
本人亲自整理,有问题可留言交流~