【PyTorch单点知识】torch.nn.Embedding模块介绍:理解词向量与实现

文章目录

      • [0. 前言](#0. 前言)
      • [1. 基础介绍](#1. 基础介绍)
        • [1.1 基本参数](#1.1 基本参数)
        • [1.2 可选参数](#1.2 可选参数)
        • [1.3 属性](#1.3 属性)
        • [1.4 PyTorch源码注释](#1.4 PyTorch源码注释)
      • [2. 实例演示](#2. 实例演示)
      • [3. `embedding_dim`的合理设定](#3. embedding_dim的合理设定)
      • [4. 结论](#4. 结论)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在自然语言处理(NLP)中,torch.nn.Embedding是PyTorch框架中一个至关重要的模块,用于将离散的词汇转换成连续的向量空间表示。这种转换允许模型捕捉词汇之间的语义关系,并在诸如情感分析、文本分类和机器翻译等任务中发挥关键作用。

本文将深入探讨torch.nn.Embedding的工作原理,并通过示例代码演示其在PyTorch中的使用。

1. 基础介绍

torch.nn.Embedding的本质是一个映射表(Lookup table) ,它用于储存自然语言词典嵌入向量的映射关系。

1.1 基本参数

torch.nn.Embedding的初始化接受两个基本参数:num_embeddingsembedding_dim

  • num_embeddings:这个参数直观理解为"要嵌入的自然语言的词汇数量",表示上面所述的自然语言词典 的大小,即可能的唯一词汇数量。比如英语中的常用单词,从abandon开始一共有3000个,那num_embeddings就可以设定为3000;
  • embedding_dim:表示每个词汇映射的嵌入向量的维度。
1.2 可选参数
  • padding_idx:用于指定词汇表中的填充词汇索引,该位置的向量将被初始化为零。
  • max_norm:用于限制嵌入向量的L2范数。
  • norm_type:用于指定范数类型。
  • scale_grad_by_freq:如果设置为True,则将梯度按词汇频率缩放。
  • sparse:如果设置为True,则将嵌入梯度标记为稀疏。
1.3 属性

torch.nn.Embedding 模块只有一个属性 weight。这个属性代表了嵌入层要学习的权重,即存储所有嵌入向量的矩阵。这是嵌入层的学习权重,形状为 (num_embeddings, embedding_dim),也就是上文所说的lookup table映射表。这些权重代表实际的嵌入向量,它们是可学习的参数,并且在训练过程中会被优化算法更新。默认情况下,weight 是从标准正态分布 N(0, 1) 随机初始化的。这意味着每个元素都独立地从均值为 0、标准差为 1 的正态分布中采样。

1.4 PyTorch源码注释

以下是nn.Embedding的源码注释,用于上面说明的参考:

python 复制代码
Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                                     therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                                     i.e. it remains as a fixed "pad". For a newly constructed Embedding,
                                     the embedding vector at :attr:`padding_idx` will default to all zeros,
                                     but can be updated to another value to be used as the padding vector.
        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                                    is renormalized to have norm :attr:`max_norm`.
        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                                 See Notes for more details regarding sparse gradients.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
                         initialized from :math:`\mathcal{N}(0, 1)`

    Shape:
        - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`

2. 实例演示

这里我将给出一个简单的例子来说明如何使用 PyTorch 的 torch.nn.Embedding 模块创建一个嵌入层,并获取一些单词的嵌入向量。

假设我们有一个小型的词汇表,包含以下单词:

  • "the"
  • "cat"
  • "dog"
  • "sat"
  • "on"
  • "mat"

我们将这些单词映射到索引上,例如:

  • "the" -> 0
  • "cat" -> 1
  • "dog" -> 2
  • "sat" -> 3
  • "on" -> 4
  • "mat" -> 5

现在我们可以创建一个 torch.nn.Embedding 层,将这些单词映射到嵌入向量中。我们将使用一个 3 维的嵌入向量来表示每个单词。

下面是具体的代码示例:

python 复制代码
import torch
import torch.nn as nn

# 创建一个 Embedding 层
# num_embeddings: 词汇表的大小,这里是 6
# embedding_dim: 嵌入向量的维度,这里是 3
embedding = nn.Embedding(num_embeddings=6, embedding_dim=3)

# 定义一些单词的索引
word_indices = torch.LongTensor([0, 1, 2, 3, 4, 5])  # "the", "cat", "dog", "sat", "on", "mat"

# 通过索引获取嵌入向量
word_embeddings = embedding(word_indices)

# 输出嵌入向量
print(word_embeddings)

运行上述代码后,word_embeddings 将是一个形状为 (6, 3) 的张量,其中每一行代表一个单词的嵌入向量。

python 复制代码
tensor([[ 0.0439,  0.7314, -0.3546],
        [ 0.6975,  1.2725,  1.4042],
        [-1.7532, -2.0642, -0.1434],
        [ 0.2538,  1.1123, -0.8636],
        [-0.7238, -0.0585,  0.5242],
        [ 0.6485,  0.6885, -1.2045]], grad_fn=<EmbeddingBackward0>)

例如,word_embeddings[0] 对应于单词 "the" 的嵌入向量,word_embeddings[1] 对应于单词 "cat" 的嵌入向量,以此类推。

这就是一个简单的英语单词嵌入向量的例子。在实际应用中,词汇表会更大,嵌入向量的维度也会更高,而且通常会使用预训练的嵌入向量来初始化这些权重。

3. embedding_dim的合理设定

通过上文说明,我们可以轻松地掌握nn.Embedding模块的使用,但是这里有个问题:embedding_dim设定为多少比较合适呢?

这里首先要说明下嵌入向量:它应该是代表单词"语义"的向量,而不是像one-hot那样是简单的字母映射。

举个例子:meetmeat两个词,拼写十分接近,即它们的one-hot编码十分接近,但是它们的语义完全不同,也就是说嵌入向量应该相差很远。而hugeenormous情况刚好相反,它们的one-hot编码完全不同,而嵌入向量应该比较接近。

那回到embedding_dim的设定选择上来,我觉得可以参考以下3个方面来设定比较合理的embedding_dim

  1. 平衡信息量与过拟合风险
    • 信息量: 较高的 embedding_dim 可以捕获更多的信息和细微差别,从而提高模型的表达能力。然而,这也可能会导致过拟合,因为高维空间容易出现稀疏性问题。
    • 过拟合风险: 较低的 embedding_dim 可以减少参数数量,降低过拟合的风险,但可能会丢失一些信息。
  2. 考虑词汇表的大小
    • 较小的词汇表: 如果词汇表相对较小(例如几千个词),较低的 embedding_dim(如 50 或 100)可能就足够了。
    • 较大的词汇表: 对于较大的词汇表(例如几十万或更多),可以选择较高的 embedding_dim(如 200 至 500)以更好地捕捉语义信息。
  3. 实验验证
    • 交叉验证: 最终的选择通常需要通过实验来确定。使用交叉验证来评估不同 embedding_dim 下的模型性能,可以帮助找到最佳值。
    • 预训练嵌入: 如果有可用的预训练嵌入(如 Word2Vec、GloVe 或 FastText),可以考虑使用它们的维度作为参考。

一点点思考:在Embedding方法中,embedding_dim一般是要比num_embeddings小(很多)的,这会导致矩阵的秩不满,最终会导致Embedding方法中的单词可以通过线性变换变成另一个单词。比如把abandon的词向量×2得到get的词向量,而one-hot不会有这个问题,这是Embedding小小的局限性。

4. 结论

torch.nn.Embedding模块在PyTorch中为NLP任务提供了强大的工具,允许模型从词汇索引中学习有意义的向量表示。通过初始化和调用这个模块,我们可以轻松地将文本数据转换为适合深度学习模型的格式,从而挖掘文本数据中的丰富语义信息。

相关推荐
2403_8757368711 分钟前
道品科技智慧农业中的自动气象检测站
网络·人工智能·智慧城市
海阔天空_201324 分钟前
Python pyautogui库:自动化操作的强大工具
运维·开发语言·python·青少年编程·自动化
零意@32 分钟前
ubuntu切换不同版本的python
windows·python·ubuntu
学术头条35 分钟前
AI 的「phone use」竟是这样练成的,清华、智谱团队发布 AutoGLM 技术报告
人工智能·科技·深度学习·语言模型
准橙考典36 分钟前
怎么能更好的通过驾考呢?
人工智能·笔记·自动驾驶·汽车·学习方法
ai_xiaogui39 分钟前
AIStarter教程:快速学会卸载AI项目【AI项目管理平台】
人工智能·ai作画·语音识别·ai写作·ai软件
思忖小下43 分钟前
Python基础学习_01
python
孙同学要努力44 分钟前
《深度学习》——深度学习基础知识(全连接神经网络)
人工智能·深度学习·神经网络
q567315231 小时前
在 Bash 中获取 Python 模块变量列
开发语言·python·bash
是萝卜干呀1 小时前
Backend - Python 爬取网页数据并保存在Excel文件中
python·excel·table·xlwt·爬取网页数据