【Pytorch】nn.Embedding函数详解

文章目录

  • [一. nn.Embedding 官方注解](#一. nn.Embedding 官方注解)

一. nn.Embedding 官方注解

Pytorch官方注解: https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.sparse.Embedding.html

定义:

bash 复制代码
class torch.nn.modules.sparse.Embedding(
	num_embeddings, 
	embedding_dim, 
	padding_idx=None, 
	max_norm=None, 
	norm_type=2.0, 
	scale_grad_by_freq=False, 
	sparse=False, 
	_weight=None, 
	_freeze=False, 
	device=None, 
	dtype=None
)

一个简单的查找表,用于存储固定词典和大小的嵌入向量。

该模块通常用于存储词嵌入,并通过索引检索它们。该模块的输入是索引列表,输出是相应的词嵌入。

Parameters 参数 :

  • num_embeddings(int)------嵌入字典的大小, 所有输出的索引必须在 [0, num_embeddings-1]
  • embedding_dim(int)------每个嵌入向量的大小, 即词向量的维度(如 128,256,512)
  • padding_idx(int,可选)------ 若指定,则位于padding_idx的条目不会对梯度产生影响;因此,padding_idx处的嵌入向量在训练期间不会更新,即它会保持为固定的"填充"。对于新构建的嵌入层,padding_idx处的嵌入向量默认全为零,但可以更新为其他值以用作填充向量。该位置对应的嵌入向量不会产生梯度,训练过程中不会被更新,用于序列补齐,避免填充符号影响训练。
  • max_norm(float,可选)------若提供此参数,则每个范数大于max_norm的嵌入向量会被重新归一化,以使其范数等于max_norm。仅限制向量长度,不改变语义方向,用于防止梯度爆炸、训练不稳定。
  • norm_type (float, 可选)------用于计算max_norm选项的p范数中的p值。默认值为2
  • scale_grad_by_freq(bool,可选)------ 若提供该参数,将按小批量中单词频率的倒数来缩放梯度。默认值为False。词出现的越多梯度越小,词出现的越少,梯度越大, 避免高词频梯度过大,让低词频也能正常学习。
  • sparse(bool,可选)------如果为True,则相对于weight矩阵的梯度将是稀疏张量。有关稀疏梯度的更多详细信息,请参见注释。即只计算本次batch出现词的梯度,未出现词不占用显存、不存储梯度,大幅节省内存,提升速度。

Variables 变量 :

  • weight (张量) : 模块的可学习权重,形状为(num_embeddings,embedding_dim),初始值来自 N(0,1)

Shape形状:

  • Input (*):任意形状的IntTensor或LongTensor,包含要提取的索引
  • Output(*,H):其中 * 是输入形状,且 H = embedding_dim

注意事项:

请记住,只有有限数量的优化器支持稀疏梯度:目前是optim.SGD(CUDA和CPU)、optim.SparseAdam(CUDA和CPU)以及optim.Adagrad(CPU)

注意事项:

当 max_norm 不为 None 时,Embedding 的forward 方法会就地修改weight张量。由于梯度计算所需的张量不能就地修改,因此在调用Embedding的forward方法之前对Embedding.weight执行可微分操作时,若max_norm不为None,则需要克隆Embedding.weight,例如:

python 复制代码
n,d,m = 3,5,7
embedding = nn.Embedding(n, d, max_norm=1.0)
w = torch.tensor([1,2])
a = (
	embedding.weight.clone() @ W.t()
)
b = embedding(idx) @ W.t()
out = a.unsqueeze(0) + b.unsqueeze(1)
loss = out.sigmoid().prod()
loss.backward

Example 示例:

python 复制代码
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0, 2, 0, 5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1535, -2.0309,  0.9315],
         [ 0.0000,  0.0000,  0.0000],
         [-0.1655,  0.9897,  0.0635]]])

>>> # example of changing `pad` vector
>>> padding_idx = 0
>>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
>>> embedding.weight

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.7895, -0.7089, -0.0364],
        [ 0.6778,  0.5803,  0.2678]], requires_grad=True)

>>> with torch.no_grad():
>>>    embedding.weight[padding_idx] = torch.ones(3)
>>> embedding.weight
Parameter containing:
tensor([[ 1.0000,  1.0000,  1.0000],
        [-0.7895, -0.7089, -0.0364],
        [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
python 复制代码
classmethod from_pretrained(
	embeddings, 
	freeze=True, 
	padding_idx=None, 
	max_norm=None, 
	norm_type=2.0, 
	scale_grad_by_freq=False, 
	sparse=False
)

从给定的二维 FloatTensor 创建嵌入实例

Parameters 参数 :

  • embeddings (Tensor) ------包含嵌入层权重的浮点张量。第一维度作为num_embeddings传递给嵌入层,第二维度作为embedding_dim传递
  • freeze (bool, optional) ------ 如果为True,则张量在学习过程中不会更新。相当于embedding.weight.requires_grad = False。默认值:True
  • padding_idx (int, optional) ------如果指定,padding_idx处的条目不会对梯度产生影响;因此,padding_idx处的嵌入向量在训练期间不会更新,即它会保持为固定的"填充"
  • max_norm (float, optional) -- 参见模块初始化文档
  • norm_type(浮点数,可选参数)------ 详见模块初始化文档。默认值为 2
  • scale_grad_by_freq(布尔型,可选参数)------ 详见模块初始化文档。默认值为 False
  • sparse(布尔型,可选参数)------ 详见模块初始化文档

Example 示例:

python 复制代码
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000,  5.1000,  6.3000]])
相关推荐
AI医影跨模态组学1 天前
Ann Oncol(IF=65.4)广东省人民医院放射科刘再毅&阿里巴巴达摩院等团队:基于非增强CT与深度学习的结直肠癌检测
人工智能·深度学习·论文·医学影像
学习论之费曼学习法1 天前
AI 入门 30 天挑战 - Day 19 费曼学习法版 - GAN 生成对抗网络
人工智能·学习·生成对抗网络
guslegend1 天前
第17节:模型忽略关键实体怎么办?注意力权重分配机制引导生成拒绝重点
人工智能·大模型·rag
Deepoch1 天前
Deepoc 具身模型开发板赋能智能轮椅自主随行与安全控制技术研究
人工智能·科技·安全·开发板·deepoc·智能轮椅
Magic-Yuan1 天前
算力的迷雾
人工智能·算法·机器学习
财迅通Ai1 天前
德福科技2025年净利增长145.91% 高端突破引领成长新篇
大数据·人工智能·科技·德福科技
AI医影跨模态组学1 天前
Nature Reviews Cancer(IF=66.8)澳门科技大学张康教授等团队:人工智能推动多组学与临床数据整合在基础和转化癌症研究中的进展
人工智能·科技·深度学习·论文·医学影像
天使的翅膀20251 天前
BM25为何精准匹配专有名词?
人工智能
weixin_669545201 天前
支持 18W 快充的 2 节/3 节串联锂电池高效同步升压充电芯片 SW7306
人工智能·单片机·嵌入式硬件·硬件工程
wayz111 天前
Day 16:PCA主成分分析与降维
人工智能·算法·机器学习