【Pytorch】nn.Embedding函数详解

文章目录

  • [一. nn.Embedding 官方注解](#一. nn.Embedding 官方注解)

一. nn.Embedding 官方注解

Pytorch官方注解: https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.sparse.Embedding.html

定义:

bash 复制代码
class torch.nn.modules.sparse.Embedding(
	num_embeddings, 
	embedding_dim, 
	padding_idx=None, 
	max_norm=None, 
	norm_type=2.0, 
	scale_grad_by_freq=False, 
	sparse=False, 
	_weight=None, 
	_freeze=False, 
	device=None, 
	dtype=None
)

一个简单的查找表,用于存储固定词典和大小的嵌入向量。

该模块通常用于存储词嵌入,并通过索引检索它们。该模块的输入是索引列表,输出是相应的词嵌入。

Parameters 参数 :

  • num_embeddings(int)------嵌入字典的大小, 所有输出的索引必须在 [0, num_embeddings-1]
  • embedding_dim(int)------每个嵌入向量的大小, 即词向量的维度(如 128,256,512)
  • padding_idx(int,可选)------ 若指定,则位于padding_idx的条目不会对梯度产生影响;因此,padding_idx处的嵌入向量在训练期间不会更新,即它会保持为固定的"填充"。对于新构建的嵌入层,padding_idx处的嵌入向量默认全为零,但可以更新为其他值以用作填充向量。该位置对应的嵌入向量不会产生梯度,训练过程中不会被更新,用于序列补齐,避免填充符号影响训练。
  • max_norm(float,可选)------若提供此参数,则每个范数大于max_norm的嵌入向量会被重新归一化,以使其范数等于max_norm。仅限制向量长度,不改变语义方向,用于防止梯度爆炸、训练不稳定。
  • norm_type (float, 可选)------用于计算max_norm选项的p范数中的p值。默认值为2
  • scale_grad_by_freq(bool,可选)------ 若提供该参数,将按小批量中单词频率的倒数来缩放梯度。默认值为False。词出现的越多梯度越小,词出现的越少,梯度越大, 避免高词频梯度过大,让低词频也能正常学习。
  • sparse(bool,可选)------如果为True,则相对于weight矩阵的梯度将是稀疏张量。有关稀疏梯度的更多详细信息,请参见注释。即只计算本次batch出现词的梯度,未出现词不占用显存、不存储梯度,大幅节省内存,提升速度。

Variables 变量 :

  • weight (张量) : 模块的可学习权重,形状为(num_embeddings,embedding_dim),初始值来自 N(0,1)

Shape形状:

  • Input (*):任意形状的IntTensor或LongTensor,包含要提取的索引
  • Output(*,H):其中 * 是输入形状,且 H = embedding_dim

注意事项:

请记住,只有有限数量的优化器支持稀疏梯度:目前是optim.SGD(CUDA和CPU)、optim.SparseAdam(CUDA和CPU)以及optim.Adagrad(CPU)

注意事项:

当 max_norm 不为 None 时,Embedding 的forward 方法会就地修改weight张量。由于梯度计算所需的张量不能就地修改,因此在调用Embedding的forward方法之前对Embedding.weight执行可微分操作时,若max_norm不为None,则需要克隆Embedding.weight,例如:

python 复制代码
n,d,m = 3,5,7
embedding = nn.Embedding(n, d, max_norm=1.0)
w = torch.tensor([1,2])
a = (
	embedding.weight.clone() @ W.t()
)
b = embedding(idx) @ W.t()
out = a.unsqueeze(0) + b.unsqueeze(1)
loss = out.sigmoid().prod()
loss.backward

Example 示例:

python 复制代码
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0, 2, 0, 5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1535, -2.0309,  0.9315],
         [ 0.0000,  0.0000,  0.0000],
         [-0.1655,  0.9897,  0.0635]]])

>>> # example of changing `pad` vector
>>> padding_idx = 0
>>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
>>> embedding.weight

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.7895, -0.7089, -0.0364],
        [ 0.6778,  0.5803,  0.2678]], requires_grad=True)

>>> with torch.no_grad():
>>>    embedding.weight[padding_idx] = torch.ones(3)
>>> embedding.weight
Parameter containing:
tensor([[ 1.0000,  1.0000,  1.0000],
        [-0.7895, -0.7089, -0.0364],
        [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
python 复制代码
classmethod from_pretrained(
	embeddings, 
	freeze=True, 
	padding_idx=None, 
	max_norm=None, 
	norm_type=2.0, 
	scale_grad_by_freq=False, 
	sparse=False
)

从给定的二维 FloatTensor 创建嵌入实例

Parameters 参数 :

  • embeddings (Tensor) ------包含嵌入层权重的浮点张量。第一维度作为num_embeddings传递给嵌入层,第二维度作为embedding_dim传递
  • freeze (bool, optional) ------ 如果为True,则张量在学习过程中不会更新。相当于embedding.weight.requires_grad = False。默认值:True
  • padding_idx (int, optional) ------如果指定,padding_idx处的条目不会对梯度产生影响;因此,padding_idx处的嵌入向量在训练期间不会更新,即它会保持为固定的"填充"
  • max_norm (float, optional) -- 参见模块初始化文档
  • norm_type(浮点数,可选参数)------ 详见模块初始化文档。默认值为 2
  • scale_grad_by_freq(布尔型,可选参数)------ 详见模块初始化文档。默认值为 False
  • sparse(布尔型,可选参数)------ 详见模块初始化文档

Example 示例:

python 复制代码
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000,  5.1000,  6.3000]])
相关推荐
老成说AI2 小时前
SOUNDVIEW视频翻译:SHARK吸尘器如何靠TIKTOK打破高客单魔咒?
人工智能·跨境电商·tiktok·soundview
ByteX2 小时前
AI Coding
人工智能
jiajia_lisa2 小时前
科技暖民心,通行更便捷——车牌识别赋能民生出行
大数据·人工智能
非科班Java出身GISer2 小时前
国产 AI IDE(Agent) 颠覆传统开发方式:codebuddy 介绍,以及简单对比 trae、lingma、Comate
人工智能·ai编程·ai agent·ai ide·ai 开发工具·ai 开发软件
qyr67892 小时前
全球蜂窝分布式天线系统市场报告2026-2032
大数据·人工智能·数据分析·市场报告·蜂窝分布式天线系统
junior_Xin2 小时前
机器学习深度学习beginning5
人工智能·深度学习
电子科技圈2 小时前
SmartDV展示AI & HPC连接与存储IP解决方案,以解锁下一代算力芯片和节点的“速度密码”
网络·数据库·人工智能·嵌入式硬件·aigc·边缘计算
Daydream.V2 小时前
计算机视觉——疲劳检测、基于DNN的年龄性别预测
人工智能·计算机视觉·dnn·疲劳检测·年龄性别预测
龙文浩_2 小时前
AI的jieba分词原理与多模式应用解析
人工智能·pytorch·深度学习·神经网络