PyTorch 中nn.Embedding

核心参数与用法

nn.Embedding的核心参数:

num_embeddings:嵌入表的大小(即离散特征的总类别数,如词汇表大小)。

embedding_dim:每个嵌入向量的维度(输出向量的长度)。

padding_idx(可选):指定一个索引,其对应的嵌入向量将始终为 0(用于处理填充符号)。

css 复制代码
import torch
import torch.nn as nn

# 定义嵌入层:词汇表大小为10(索引0-9),嵌入维度为3
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

# 输入:形状为(batch_size, seq_len)的整数张量(索引必须在[0, num_embeddings-1]范围内)
input_indices = torch.tensor([[1, 3, 5], [2, 4, 6]])  # 批量大小为2,序列长度为3

# 前向传播:获取嵌入向量
output_embeddings = embedding(input_indices)

print("输入形状:", input_indices.shape)  # 输出:torch.Size([2, 3])
print("输出形状:", output_embeddings.shape)  # 输出:torch.Size([2, 3, 3])(每个索引被映射为3维向量)
print("输出内容:\n", output_embeddings)
css 复制代码
输入形状: torch.Size([2, 3])
输出形状: torch.Size([2, 3, 3])
输出内容:
 tensor([[[ 0.5095,  0.3979, -1.7759],
         [-0.1456,  1.6262,  0.3929],
         [ 0.8530, -0.6685,  1.6823]],

        [[ 1.0323, -0.0969, -0.6512],
         [ 0.2309, -1.5649,  0.7431],
         [-0.3285, -0.2512, -0.1028]]], grad_fn=<EmbeddingBackward0>)
Parameter containing:
tensor([[-1.8749,  0.2108,  0.4401],
        [ 0.5095,  0.3979, -1.7759],
        [ 1.0323, -0.0969, -0.6512],
        [-0.1456,  1.6262,  0.3929],
        [ 0.2309, -1.5649,  0.7431],
        [ 0.8530, -0.6685,  1.6823],
        [-0.3285, -0.2512, -0.1028],
        [-0.1919,  0.2022, -0.2425],
        [-0.7266,  1.3337, -0.7980],
        [ 0.0791, -0.7093,  0.2264]], requires_grad=True)
相关推荐
华农DrLai4 小时前
什么是LLM做推荐的三种范式?Prompt-based、Embedding-based、Fine-tuning深度解析
人工智能·深度学习·prompt·transformer·知识图谱·embedding
高洁015 小时前
多模态AI模型融合难?核心问题与解决思路
人工智能·深度学习·机器学习·数据挖掘·transformer
renhongxia16 小时前
ORACLE-SWE:量化Oracle 信息信号对SWE代理的贡献
人工智能·深度学习·学习·语言模型·分类
weixin_156241575769 小时前
基于YOLOv8深度学习花卉识别系统摄像头实时图片文件夹多图片等另有其他的识别系统可二开
大数据·人工智能·python·深度学习·yolo
QQ676580089 小时前
AI赋能轨道交通智能巡检 轨道交通故障检测 轨道缺陷断裂检测 轨道裂纹识别 鱼尾板故障识别 轨道巡检缺陷数据集深度学习yolo第10303期
人工智能·深度学习·yolo·智能巡检·轨道交通故障检测·鱼尾板故障识别·轨道缺陷断裂检测
云程笔记9 小时前
002.计算机视觉与目标检测发展简史:从传统方法到深度学习
深度学习·yolo·目标检测·计算机视觉
weixin_156241575769 小时前
基于YOLO深度学习的动物检测与识别系统
人工智能·深度学习·yolo
叶舟10 小时前
LYT-NET:一个超级轻量的低光照图像增强Transformer网络
人工智能·深度学习·transformer·llie·低光照图像增强
管二狗赶快去工作!10 小时前
体系结构论文(九十八):NPUEval: Optimizing NPU Kernels with LLMs and Open Source Compilers
人工智能·深度学习·自然语言处理·体系结构
LaughingZhu10 小时前
Product Hunt 每日热榜 | 2026-04-10
人工智能·经验分享·深度学习·神经网络·产品运营