nn.Embedding 理解及其参数 padding_idx含义

看到一些文章对Embedding层理解上存在误区,故贡献一点自己的想法

误区文章:https://blog.csdn.net/weixin_38257276/article/details/114195454

python 复制代码
import torch
import torch.nn as nn

# 10 x 3的向量矩阵
embed = nn.Embedding(10,3)

# Embedding输入必须是tensor
input1 = torch.tensor(1)
print(input1)  # tensor(1)

input2 = torch.tensor([1, 1])
print(input2)  # tensor([1, 1])

input3 = torch.tensor([1, 2])  # tensor([1, 2])
print(input3)

input4 = torch.tensor([1, 10])  # tensor([ 1, 10])
print(input4)

out1 = embed(input1)
print(out1)
# tensor([ 0.1294, -0.1507, -0.0476], grad_fn=<EmbeddingBackward0>)

out2 = embed(input2)
print(out2)
# tensor([[-0.4178,  0.8059,  0.0863],
#         [-0.4178,  0.8059,  0.0863]], grad_fn=<EmbeddingBackward0>)

out3 = embed(input3)
print(out3)
# tensor([[-0.4178,  0.8059,  0.0863],
#         [ 0.9092, -0.8834, -0.5366]], grad_fn=<EmbeddingBackward0>)

out4 = embed(input4)
print(out4)
# IndexError: index out of range in self

# 综上,nn.Embedding(10, 3),10表示num_embeddings, 3表示embedding_dim
# 也就是10个嵌入向量,每个向量是3维(向量长度是3)
# nn.Embedding层的过程可以理解成根据索引查询Embedding向量矩阵的过程,当输入的索引值是0,即返回Embedding矩阵的第一行
# 当输入的索引值是10,由于定义的Embedding矩阵大小是10x3,最多只支持0-9索引,所以会报错(见out4)
# 另一种理解:输入的每个数字都可以表示成one-hot向量,这个向量维度就是10,比如输入的数字是2(索引为2),则对应向量[0 0 1 0 0...0]
# 这个one-hot向量和Embedding向量矩阵相乘,依然是得到Embedding矩阵的第三行。

# 故:10限定了输入的数字大小,正常情况是词表大小作为Embedding的num_embeddings,这样就可以根据各个词的索引查询到对应的向量;
# 3 是输出的向量维度

import torch.nn as nn

embed1 = nn.Embedding(3, 3)
print(embed1.weight)
# tensor([[ 1.0503,  1.2954,  0.0826],
#         [ 1.3010, -0.1322,  2.4299],
#         [ 0.2982, -0.0534, -0.0754]], requires_grad=True)


embed2 = nn.Embedding(3, 3, padding_idx=0)
print(embed2.weight)
# tensor([[ 0.0000,  0.0000,  0.0000],
#         [ 1.1654,  1.5345,  0.9253],
#         [ 1.0780, -1.8185, -1.4120]], requires_grad=True)


embed3 = nn.Embedding(3, 3, padding_idx=1)
print(embed3.weight)
# tensor([[-0.4296,  0.3443, -0.3189],
#         [ 0.0000,  0.0000,  0.0000],
#         [-0.8069,  0.9383,  0.9449]], requires_grad=True)


embed4 = nn.Embedding(3, 3, padding_idx=2)
print(embed4.weight)
# tensor([[-0.8485,  1.5352,  1.1185],
#         [-0.6012, -1.5501, -0.2466],
#         [ 0.0000,  0.0000,  0.0000]], requires_grad=True)


# 综上,padding_idx就是把Embeddings矩阵某一行置为0

input1 = torch.tensor([0, 1, 2, 2, 1, 0])
print(embed4(input1))


tensor([[ 0.4167, -0.5717, -0.9844],
        [ 1.1028, -0.3473,  0.5762],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 1.1028, -0.3473,  0.5762],
        [ 0.4167, -0.5717, -0.9844]], grad_fn=<EmbeddingBackward0>)
相关推荐
CoovallyAIHub2 小时前
开源的消逝与新生:从 TensorFlow 的落幕到开源生态的蜕变
pytorch·深度学习·llm
数据智能老司机7 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机8 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机8 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机8 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i8 小时前
drf初步梳理
python·django
每日AI新事件8 小时前
python的异步函数
python
这里有鱼汤9 小时前
miniQMT下载历史行情数据太慢怎么办?一招提速10倍!
前端·python
databook18 小时前
Manim实现脉冲闪烁特效
后端·python·动效
程序设计实验室19 小时前
2025年了,在 Django 之外,Python Web 框架还能怎么选?
python