【Pytorch:nn.Embedding】简介以及使用方法:用于生成固定数量的具有指定维度的嵌入向量embedding vector

文章目录

1、nn.Embedding

  • 首先我们讲解一下关于嵌入向量embedding vector的概念

1)在自然语言处理NLP领域,是将单词、短语或其他文本单位映射到一个固定长度的实数向量空间中 。嵌入向量具有较低的维度,通常在几十到几百维之间,且每个维度都包含一定程度上的语义信息。这意味着在嵌入向量空间中,语义上相似的单词在向量空间中也更加接近。

2)在计算机视觉领域,是将图像或图像中的区域映射到一个固定长度的实数向量空间中。嵌入向量在计算机视觉任务中起到了表示和提取特征的作用。通过将图像映射到嵌入向量空间,可以捕捉到图像的语义信息、视觉特征以及图像之间的相似性。

  • 总之,嵌入向量是具有固定维度的,而不论是在NLP领域还是CV领域,都需要生成多个嵌入向量,因此也有固定数量。
  • 于是,我们就可以简单理解该类为:
python 复制代码
CLASS torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
''
一个简单的查找表,用于存储固定词典和尺寸的embeddings:其实就是存储了固定数量的具有固定维度的嵌入向量
该模块需要使用索引检索嵌入向量:也就是说模块的输入是索引列表,输出是相应存储的嵌入向量。
1) num_embeddings: 嵌入向量的数量
2) embedding_dim: 嵌入向量的维度
注意:
1)它的成员变量weight:具有shape为 (num_embeddings, embedding_dim) 的可学习的参数
2)输入为:任意形状[*]的IntTensor或LongTensor,内部元素为索引值,即0到num_embeddings-1之间的值
   输出为:[*, H]的嵌入向量,H为embedding_dim
''
  • 例如:
python 复制代码
from torch import nn
import torch


# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3)
# a batch of 2 samples of 4 indices each
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print(embedding(input))
print(embedding.weight)
''
输出为:
tensor([[[ 0.4125,  0.1478,  0.3764],
         [ 0.5272, -0.4960,  1.5926],
         [ 0.2231, -0.7653, -0.5333],
         [ 2.8278,  1.5299,  1.4080]],

        [[ 0.2231, -0.7653, -0.5333],
         [-0.3996,  0.3626, -0.3369],
         [ 0.5272, -0.4960,  1.5926],
         [ 0.6222,  1.3385,  0.6861]]], grad_fn=<EmbeddingBackward>)
Parameter containing:
tensor([[-0.1316, -0.2370, -0.8308],
        [ 0.4125,  0.1478,  0.3764],
        [ 0.5272, -0.4960,  1.5926],
        [-0.3996,  0.3626, -0.3369],
        [ 0.2231, -0.7653, -0.5333],
        [ 2.8278,  1.5299,  1.4080],
        [-0.4182,  0.4665,  1.5345],
        [-1.2107,  0.3569,  0.9719],
        [-0.6439, -0.4095,  0.6130],
        [ 0.6222,  1.3385,  0.6861]], requires_grad=True)
''

2、使用场景

  • transformer decoder输入的嵌入向量Output Embedding
  • DETR中的decoder的object queries
相关推荐
陈文锦丫14 小时前
MixFormer: A Mixed CNN–Transformer Backbone
人工智能·cnn·transformer
Coding茶水间16 小时前
基于深度学习的安全帽检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
AI-智能17 小时前
别啃文档了!3 分钟带小白跑完 Dify 全链路:从 0 到第一个 AI 工作流
人工智能·python·自然语言处理·llm·embedding·agent·rag
adjusttraining21 小时前
毁掉孩子视力不是电视和手机,两个隐藏很深因素,很多家长并不知
深度学习·其他
操练起来1 天前
【昇腾CANN训练营·第八期】Ascend C生态兼容:基于PyTorch Adapter的自定义算子注册与自动微分实现
人工智能·pytorch·acl·昇腾·cann
ziwu1 天前
【宠物识别系统】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·深度学习·图像识别
ziwu1 天前
海洋生物识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·深度学习·图像识别
WWZZ20251 天前
快速上手大模型:深度学习12(目标检测、语义分割、序列模型)
深度学习·算法·目标检测·计算机视觉·机器人·大模型·具身智能
Ai173163915791 天前
2025.11.28国产AI计算卡参数信息汇总
服务器·图像处理·人工智能·神经网络·机器学习·视觉检测·transformer
浩浩的代码花园1 天前
自研端侧推理模型实测效果展示
android·深度学习·计算机视觉·端智能