PyTorch nn.Embedding() 嵌入详解

在对文本序列进行分词(tokenize)并映射后,字符串序列就转变为了数字(token id)序列,这些 token id 可以直接输入到模型中,但需要明白的是,模型并不能直接从一个纯粹的数字中获取丰富的信息。类比到人类的认知,我们理解一个字或词并不是仅靠符号,而是其背后的含义。

nn.Embedding 嵌入层

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)

A simple lookup table that stores embeddings of a fixed dictionary and size.

一个简单的查找表,用于存储固定大小的字典中每个词的嵌入向量。

参数

  • num_embeddings (int): 嵌入字典的大小,即词汇表的大小 (vocab size)。
  • embedding_dim (int): 每个嵌入向量的维度大小。
  • padding_idx (int, 可选): 指定填充对应的索引值。该索引对应的嵌入向量在训练过程中不会更新,即梯度不参与反向传播,通常作为"填充"标记使用。对于新构建的 Embedding 模块,此索引的嵌入向量默认值为全零,但可以更改为其他值。
  • max_norm (float, 可选): 如果设置,超过此值的嵌入向量范数将被重新归一化,使其最大范数等于 max_norm
  • norm_type (float, 可选): 用于计算 max_norm 的 p-范数,默认为 2,即计算 2 范数。
  • scale_grad_by_freq (bool, 可选): 如果为 True,梯度将根据单词在 mini-batch 中的频率的倒数进行缩放,适用于高频词的梯度调整。默认为 False
  • sparse (bool, 可选): 如果设置为 True,则权重矩阵的梯度为稀疏张量,适合大规模词汇表的内存优化。

变量

  • weight (Tensor): 模块的可学习权重,形状为 (num_embeddings, embedding_dim),初始值从正态分布 N(0, 1) 中采样。

方法

from_pretrained(embeddings, freeze=True, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)

Create Embedding instance from given 2-dimensional FloatTensor.

用于从给定的 2 维浮点张量(FloatTensor)创建一个 Embedding 实例。

参数

  • embeddings (Tensor): 一个包含嵌入权重的 FloatTensor。第一个维度代表 num_embeddings(词汇表大小),第二个维度代表 embedding_dim(嵌入向量维度)。
  • freeze (bool, 可选): 如果为 True,则嵌入张量在训练过程中保持不变,相当于设置 embedding.weight.requires_grad = False。默认值为 True
  • 其余参数参考之前定义。

要点示例未完待续...(预计 11.6 前上传)

QA

Q1:对于神经网络来说,什么是"符号"及其"背后的含义"?

答案是:Token IDEmbedding

那么,什么是 Embedding?

我们可以通过 PyTorch 中的 nn.Embedding 类来理解它,先跳过繁琐的介绍,运行代码来直观感受:

python 复制代码
import torch
import torch.nn as nn

# 设置随机种子以确保结果可复现
torch.manual_seed(42)

# 定义嵌入层参数
num_embeddings = 5  # 假设词汇表中有 5 个 token
embedding_dim = 3   # 每个 token 对应 3 维嵌入向量

# 初始化嵌入层
embedding = nn.Embedding(num_embeddings, embedding_dim)

# 定义整数索引
input_indices = torch.tensor([0, 2, 4])

# 查找嵌入向量
output = embedding(input_indices)

# 打印结果
print("权重矩阵:")
print(embedding.weight.data)
print("\nEmbedding 输出:")
print(output)

输出:

复制代码
权重矩阵:
tensor([[ 0.3367,  0.1288,  0.2345],
        [ 0.2303, -1.1229, -0.1863],
        [ 2.2082, -0.6380,  0.4617],
        [ 0.2674,  0.5349,  0.8094],
        [ 1.1103, -1.6898, -0.9890]])

Embedding 输出:
tensor([[ 0.3367,  0.1288,  0.2345],
        [ 2.2082, -0.6380,  0.4617],
        [ 1.1103, -1.6898, -0.9890]], grad_fn=<EmbeddingBackward0>)

在这里,input_indices = [0, 2, 4] 从权重矩阵中选择第 0、2 和 4 行作为对应的嵌入表示。是的没错,Embedding 的获取就是这么简单。

接下来,构建一个 Embedding 类进行理解:

python 复制代码
class Embedding():
    def __init__(self, num_embeddings, embedding_dim):
        self.weight = torch.nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        
    def forward(self, indices):
        return self.weight[indices]  # 没错,就是返回对应的行

可以看出,Embedding 类的本质是一个查找表(lookup table)。在上面的示例中,embedding.weight 中存储了 5 个(num_embeddings)嵌入向量,每个向量有 3 个维度(embedding_dim)。当提供 input_indices 时,查找表返回对应的嵌入向量(权重矩阵的行)。

Q2: 最初的权重矩阵是什么?最终的嵌入向量由什么决定?

最初的权重矩阵是一般随机初始化的,在训练过程中会更新权重,使其能有效地表达背后的含义。

Q3: 什么是语义?

举个简单的例子来理解"语义"关系:像"猫"和"狗"在向量空间中的表示应该非常接近,因为它们都是宠物;"男人"和"女人"之间的向量差异可能代表性别的区别。此外,不同语言的词汇,如"男人"(中文)和"man"(英文),如果在相同的嵌入空间中,它们的向量也会非常接近,反映出跨语言的语义相似性。同时,【"女人"和"woman"(中文-英文)】与【"男人"和"man"(中文-英文)】之间的差异也可能非常相似。

本文"狭义"地解读了与 Token id 一起出现的 Embedding,这个概念在自然语言处理(NLP)中有着更具体的称呼:Word Embedding。

相关推荐
盼小辉丶25 分钟前
PyTorch强化学习实战(11)——N步DQN(N-step DQN)
pytorch·python·深度学习·强化学习
星越华夏2 小时前
深度学习项目实战:基于PyTorch的图像分类与目标检测(YOLOv8)
pytorch·深度学习·yolo·分类
imDwAaY2 小时前
从感知机到 Attention:我用 PyTorch 打穿 CS188 机器学习终章 CS188 Proj5 学习笔记
人工智能·pytorch·笔记·python·学习·机器学习
zlkingdom20 小时前
Jetson Orin开发板,在conda环境中直接实现Pytorch的GPU加速
人工智能·pytorch·conda·随笔·jetson orin
月疯1 天前
PyTorch 中定义了一个 LeakyReLU 激活函数层
人工智能·pytorch·python
Y学院1 天前
PyTorch深度学习框架核心概念精讲
人工智能·pytorch·深度学习
zhangfeng11331 天前
联邦学习 合并权重 合并权重。导致内存溢出解决办法和类库 mergekit 包依赖版本
人工智能·pytorch·机器学习
nashane1 天前
HarmonyOS 6学习:应用无响应(AppFreeze)故障排查与性能优化指南
人工智能·pytorch·python
zhangfeng11332 天前
超算/曙光DCU集群 昆山站 根目录文件夹逐项释义(HTC调度集群环境、国产DCU算力节点)
人工智能·pytorch·机器学习
zhangfeng11332 天前
国家超算中心 htc 如果只有gpu资源 没有cpu资源 操作文件的时候会不会很卡呢
人工智能·pytorch·python·机器学习