nn.Embedding 理解及其参数 padding_idx含义

看到一些文章对Embedding层理解上存在误区,故贡献一点自己的想法

误区文章:https://blog.csdn.net/weixin_38257276/article/details/114195454

python 复制代码
import torch
import torch.nn as nn

# 10 x 3的向量矩阵
embed = nn.Embedding(10,3)

# Embedding输入必须是tensor
input1 = torch.tensor(1)
print(input1)  # tensor(1)

input2 = torch.tensor([1, 1])
print(input2)  # tensor([1, 1])

input3 = torch.tensor([1, 2])  # tensor([1, 2])
print(input3)

input4 = torch.tensor([1, 10])  # tensor([ 1, 10])
print(input4)

out1 = embed(input1)
print(out1)
# tensor([ 0.1294, -0.1507, -0.0476], grad_fn=<EmbeddingBackward0>)

out2 = embed(input2)
print(out2)
# tensor([[-0.4178,  0.8059,  0.0863],
#         [-0.4178,  0.8059,  0.0863]], grad_fn=<EmbeddingBackward0>)

out3 = embed(input3)
print(out3)
# tensor([[-0.4178,  0.8059,  0.0863],
#         [ 0.9092, -0.8834, -0.5366]], grad_fn=<EmbeddingBackward0>)

out4 = embed(input4)
print(out4)
# IndexError: index out of range in self

# 综上,nn.Embedding(10, 3),10表示num_embeddings, 3表示embedding_dim
# 也就是10个嵌入向量,每个向量是3维(向量长度是3)
# nn.Embedding层的过程可以理解成根据索引查询Embedding向量矩阵的过程,当输入的索引值是0,即返回Embedding矩阵的第一行
# 当输入的索引值是10,由于定义的Embedding矩阵大小是10x3,最多只支持0-9索引,所以会报错(见out4)
# 另一种理解:输入的每个数字都可以表示成one-hot向量,这个向量维度就是10,比如输入的数字是2(索引为2),则对应向量[0 0 1 0 0...0]
# 这个one-hot向量和Embedding向量矩阵相乘,依然是得到Embedding矩阵的第三行。

# 故:10限定了输入的数字大小,正常情况是词表大小作为Embedding的num_embeddings,这样就可以根据各个词的索引查询到对应的向量;
# 3 是输出的向量维度

import torch.nn as nn

embed1 = nn.Embedding(3, 3)
print(embed1.weight)
# tensor([[ 1.0503,  1.2954,  0.0826],
#         [ 1.3010, -0.1322,  2.4299],
#         [ 0.2982, -0.0534, -0.0754]], requires_grad=True)


embed2 = nn.Embedding(3, 3, padding_idx=0)
print(embed2.weight)
# tensor([[ 0.0000,  0.0000,  0.0000],
#         [ 1.1654,  1.5345,  0.9253],
#         [ 1.0780, -1.8185, -1.4120]], requires_grad=True)


embed3 = nn.Embedding(3, 3, padding_idx=1)
print(embed3.weight)
# tensor([[-0.4296,  0.3443, -0.3189],
#         [ 0.0000,  0.0000,  0.0000],
#         [-0.8069,  0.9383,  0.9449]], requires_grad=True)


embed4 = nn.Embedding(3, 3, padding_idx=2)
print(embed4.weight)
# tensor([[-0.8485,  1.5352,  1.1185],
#         [-0.6012, -1.5501, -0.2466],
#         [ 0.0000,  0.0000,  0.0000]], requires_grad=True)


# 综上,padding_idx就是把Embeddings矩阵某一行置为0

input1 = torch.tensor([0, 1, 2, 2, 1, 0])
print(embed4(input1))


tensor([[ 0.4167, -0.5717, -0.9844],
        [ 1.1028, -0.3473,  0.5762],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 1.1028, -0.3473,  0.5762],
        [ 0.4167, -0.5717, -0.9844]], grad_fn=<EmbeddingBackward0>)
相关推荐
曲幽40 分钟前
你的FastAPI又在服务器上“跑不起来”了?来,今天咱把打包这件事彻底聊透
linux·windows·python·docker·fastapi·web·pyinstaller·nssm·services
AI玫瑰助手42 分钟前
Python函数:局部变量与全局变量的作用域
开发语言·python·信息可视化
imDwAaY43 分钟前
机器学习入门:从感知机到逻辑回归,理解线性分类器与Softmax CS188 Note20 学习笔记
人工智能·笔记·python·学习·机器学习·逻辑回归
2601_9611940244 分钟前
2026初级会计实务教材电子版|章节讲义+习题PDF
python·考研·django·pdf·virtualenv·pygame
极客笔记Jack1 小时前
Scanpy 富集分析实战:gseapy 从基因列表到通路解读
python
岁月宁静1 小时前
Hermes Agent:让你的AI智能体越用越聪明
python·agent
财经资讯数据_灵砚智能1 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年5月29日
人工智能·python·信息可视化·自然语言处理·ai编程
触底反弹1 小时前
从数据结构到 Prompt 设计:前端工程师的 AI 时代进阶指南
javascript·人工智能·python
好好风格1 小时前
这个开源项目,把本地大模型做成会说话的 Live2D 桌宠
人工智能·python·开源
Ada's2 小时前
【计算机基础系列】python语言:环境搭建
开发语言·python