pytorch中的可学习查找表实现之nn.Embedding

假设我们需要一个查找表(Lookup Table),我们可以根据索引数字快速定位查找表中某个具体位置并读取出来。最简单的方法,可以通过一个二维数组或者二维list来实现。但如果我希望查找表的值可以通过梯度反向传播来修改,那么就需要用到nn.Embedding来实现了。

其实,我们需要用反向传播来修正表值的场景还是很多的,比如我们想存储数据的通用特征时,这个通用特征就可以用nn.Embedding来表示,常见于现在的各种codebook的trick。闲话不多说,我们来看栗子:

python 复制代码
import torch
from torch import nn

table = nn.Embedding(10, 3)
print(table.weight)
idx = torch.LongTensor([[1]])
b = table(idx)
print(b)

'''
output
Parameter containing:
tensor([[-0.2317, -0.9679, -1.9324],
        [ 0.2473,  1.1043, -0.7218],
        [ 0.5425, -0.3109, -0.1330],
        [-1.4006, -0.0675,  0.1376],
        [-0.1995,  0.7168,  0.5692],
        [-1.3572, -0.6407, -0.0128],
        [-0.0773,  1.1928, -1.0836],
        [ 0.1721, -0.9232, -0.4059],
        [ 1.6108, -0.4640,  0.3535],
        [ 0.6975,  1.6554, -0.2217]], requires_grad=True)
tensor([[[ 0.2473,  1.1043, -0.7218]]], grad_fn=<EmbeddingBackward0>)
'''

这段代码实际上就实现了一个查找表的功能,索引值为[[1]](注意有两个中括弧),返回值为对应的表值。我们还可以批量查找表值:

python 复制代码
import torch
from torch import nn

table = nn.Embedding(10, 3)
print(table)
print(table.weight)

indices = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
print(indices)

out = table(indices)
print(out)
print(out.shape)

通过输入索引张量来获取表值:[2,4] -> [2,4,3],请注意这个shape变化,即对应位置的索引获得对应位置的表值。

参考:https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

本人亲自整理,有问题可留言交流~

相关推荐
骇城迷影2 小时前
从零复现GPT-2 124M
人工智能·pytorch·python·gpt·深度学习
zhangfeng11335 小时前
GitHub 知名博主 hiyouga 及其明星项目 LlamaFactory项目介绍 详细介绍
人工智能·pytorch·语言模型·github
钱彬 (Qian Bin)1 天前
基于Qwen3-VL-Embedding-2B与vLLM构建高精度多模态图像检索系统
embedding·vllm·多模态检索·qwen3-vl
AI资源库1 天前
解构嵌入模型之王:All-MiniLM-L6-v2 的文件树解密、蒸馏机制与工业级应用生态
langchain·nlp·bert·embedding·hugging face·fine-tuning·ai agent
查无此人byebye1 天前
从DDPM到DiT:扩散模型3大核心架构演进|CNN到Transformer的AIGC生成革命(附实操要点)
人工智能·pytorch·深度学习·架构·cnn·音视频·transformer
love530love1 天前
突破 Windows 编译禁区:BitNet 1-bit LLM 推理框架 GPU 加速部署编译 BitNet CUDA 算子全记录
c++·人工智能·pytorch·windows·python·cuda·bitnet
CCPC不拿奖不改名1 天前
Langflow源代码解析01:源代码拉取、安装依赖项,并运行langflow
人工智能·python·深度学习·langchain·embedding·rag·langflow
盼小辉丶1 天前
PyTorch实战(28)——PyTorch深度学习模型部署
人工智能·pytorch·深度学习·模型部署
呆萌小新@渊洁2 天前
LoRA 与参数高效微调:低秩适配实战指南
人工智能·pytorch·python·ai·语音识别
码行拾光2 天前
踩坑90分钟血泪复盘:Windows装PyTorch报DLL错误?根本原因是Python 3.12不兼容!
pytorch·windows·python