nn.Embedding 根据索引生成的向量有权重吗

import torch

import torch.nn as nn

假设有一个大小为 10x3 的 Embedding 层,其中有 10 个单词,每个单词用一个长度为 3 的向量表示

num_words = 10

embedding_dim = 3

创建 Embedding 层

embedding_layer = nn.Embedding(num_words, embedding_dim)

print(embedding_layer.weight)

embedded_vectors = embedding_layer(torch.LongTensor([4]))

print(embedded_vectors)

embedded_vectors = embedding_layer(torch.LongTensor([5]))

print(embedded_vectors)

embedded_vectors = embedding_layer(torch.LongTensor([6]))

print(embedded_vectors)

nn.Embedding层单词转向量实测

1.nn.Embedding创建对象embedding_layer

2.可以看到embedding_layer创建完成,其属性weight已经有值了

3.embedding_layer方法传入分别torch.LongTensor([4]),torch.LongTensor([5]),(torch.LongTensor([6])生成的结果就是根据索引值去weight里取值。

打破猜测:

1.原以为embedding_layer里进行的一个乘法,传参*随机权重,如embedded_vectors =torch.LongTensor([4])*W,实际不是,没有乘法

2.实际是nn.Embedding(num_words, embedding_dim)根据参数已经随机生成了所有的向量,之后仅需根据索引取值

原始猜测:

1.由于程序每次重启embedding_layer.weight生成的参数随机,为供断点续训和预测用,这些参数不能每次都随机生成,所以这些应该是要保存在模型中。即断点续训或预测时,embedding层的向量不应是随机生成了,而是读取模型文件中存储的模型参数。

2.embedding_layer.weight参与梯度更新,一开始以为此处没有

3.一开始根据各种信息判断nn.Embedding的内部机制是,或许有一个随机参数,乘以输入单词索引,得到嵌入向量,并且这个参数不参加更新,潜意识是参数不保存。

复制代码
1.为什么有或许有一个随机参数,乘以输入单词索引,得到嵌入向量这样的理解?
	因为看着是传入了索引,得到了一个随机向量,合理猜测应该是有个随即参数与传参相乘。
	所以这里第一步的猜测就错了。
	首先这个参数的确是有的,机制的实际是随机生成所有的可能的索引的向量,供直接取用。这里参数即嵌入向量

2.这个参数不参加更新?
	这种不参与更新的参数,模型会保存吗?猜测应该不会,如果不会,那每次重启的得到的嵌入向量都变了,怎么供续训和预测用?

对于transformer/bert,网络上的确对nn.Embedding这一步骤的机制讲解不够清晰,不知道嵌入向量是怎么得出的,不知道其中是否有需要训练的参数。

嵌入参数不参加更新这说法主要是来自李宏毅讲的注意力机制那块的误解,说是除了wq,wk,wv参数参与训练,没有别的参数了。这就和续训预测产生了极强的的矛盾,难以判断。

当你创建 nn.Embedding 层时,PyTorch 会随机初始化权重。这些权重在训练过程中会通过反向传播进行更新,以拟合模型的输入和输出数据,确保模型能够更好地进行预测或分类任务。

相关推荐
陈天伟教授5 小时前
人工智能应用- 语言处理:02.机器翻译:规则方法
人工智能·深度学习·神经网络·语言模型·自然语言处理·机器翻译
却道天凉_好个秋5 小时前
Tensorflow数据增强(三):高级裁剪
人工智能·深度学习·tensorflow
Lun3866buzha6 小时前
【深度学习应用】鸡蛋裂纹检测与分类:基于YOLOv3的智能识别系统,从图像采集到缺陷分类的完整实现
深度学习·yolo·分类
大江东去浪淘尽千古风流人物7 小时前
【VLN】VLN仿真与训练三要素 Dataset,Simulators,Benchmarks(2)
深度学习·算法·机器人·概率论·slam
cyyt7 小时前
深度学习周报(2.2~2.8)
人工智能·深度学习
2401_836235867 小时前
财务报表识别产品:从“数据搬运”到“智能决策”的技术革命
人工智能·科技·深度学习·ocr·生活
啊森要自信8 小时前
CANN runtime 深度解析:异构计算架构下运行时组件的性能保障与功能增强实现逻辑
深度学习·架构·transformer·cann
kyle~8 小时前
深度学习---长短期记忆网络LSTM
人工智能·深度学习·lstm
DatGuy8 小时前
Week 36: 量子深度学习入门:辛量子神经网络与物理守恒
人工智能·深度学习·神经网络
肾透侧视攻城狮8 小时前
《解锁计算机视觉:深度解析 PyTorch torchvision 核心与进阶技巧》
人工智能·深度学习·计算机视觉模快·支持的数据集类型·常用变换方法分类·图像分类流程实战·视觉模快高级功能