人人都能看懂的长短期记忆网络(LSTM)解析

本文从RNN的缺陷入手,通过大量的图片讲解了LSTM的结构,核心思想,计算过程,并讲解了几种LSTM的变体。

传统NN的缺陷

在自然语言处理中,每一个词语的意思主要由词语的上下文决定,因此,在处理文本中的每一个词语时必须要考虑上词语上下文的信息,但是传统的神经网络却不能做到这点,传统的NN其输入都是相互无关的,输出只能由输入本身的信息所决定,而无法让一条输入数据与其之前/之后的数据相关联。

递归神经网络(RNN)

递归神经网络(RNNs)应运而生,解决了这一难题。这类网络中包含了循环结构,使信息能够被保存下来。如下图所示:

在递归神经网络中,网络的某个部分 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A 会处理输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 并产出结果 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht。同时 <math xmlns="http://www.w3.org/1998/Math/MathML"> A A </math>A计算出的信息也会作为输入,这种循环结构确保信息能从网络的一个阶段传递到下一个阶段。

RNNs和普通神经网络其实差别不大。递归神经网络就像是为一个网络复制出多个副本,每个副本都将计算出的信息传递给下一个网络。如果我们展开这个循环,如下图所示:

展开后的递归神经网络呈链式结构,网络间如同有序序列一样,因此,RNN极其适合处理有序序列问题,比如时序问题等。

RNN的缺陷:长序列遗忘

RNNs的优势在于它能将一个序列中过去的信息与当前的任务挂钩,比如基于已有的文本预测下一个词。当序列比较短的时候,RNN可以做到很好的预测,比如如果要预测"太阳从__升起"这句话中缺少的词,RNN能快速的填入东边。在相关信息与当前预测位置之间距离较近的情况下,RNNs能够良好的利用过去的信息。

然而,假设我们要进行长序列预测,比如预测"我在法国长大...(中间省略300字)...我说流利的法语。"这段话的最后一个词。附近的信息让我们倾向于认为下一个词可能是某种语言的名字,但要精确到是哪种语言,就需要从"法国"这一更早的背景信息中寻找线索。即相关信息与其被需要的位置隔得太远时,RNN就难以进行处理了。

为了解决长序列遗忘问题,LSTM产生了。

LSTM

长短时记忆网络(LSTMs)是递归神经网络(RNNs)的一种特殊形式,其擅长处理长期的数据依赖关系。即相对于RNN,其能够处理更长的序列,注意到更远的相关信息。

LSTM和RNN在模型结构上的区别

所有的递归神经网络都由一系列重复的神经网络模块组成。在标准的 RNN 中,这个重复模块结构很简单:只有一个 tanh 激活层。

标准RNN的重复模块。 相比之下,LSTMs 保持了这种链式结构,但其重复的模块结构却大不相同。它们不是只有一个神经网络层,而是包含四个相互配合的层。

下一节将解读LSTM的工作原理。在此之前首先先定义好所用的符号:

在上述结构图中,每一条线表示一个传输中的向量 ,这个向量沿着箭头所指的方向从一个单元的输出传递到另一个单元的输入。粉红色圆圈代表针对向量的逐元素运算操作 ,比如向量加法(两个向量中每个元素对应相加)。黄色的方块代表训练好的神经网络层线条的合并代表了向量的拼接分叉则表示将输入的向量复制并分发到不同的运算

由此,传统的RNN就是将两个输入向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x和 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h拼接在一起而后输入tanh层中,而LSTM就复杂得多了。

LSTM核心idea

LSTM的关键概念只有两个:细胞状态和门。

首先说明细胞状态(cell state),如下图所示,细胞状态指的是在图中顶部自始至终横贯的线。细胞状态像是一个运送带,其在网络中只会进行简单的线性运算,让历史信息轻松流转而不受干扰。

LSTM 通过被称为"门"(gate)的结构,精准地向细胞状态添加或移除信息。这些门可以根据需要选择性地传递信息。门由一个输出范围在 0 到 1 之间的 sigmoid 神经网络层和逐点乘法操作组成。

这些门的决策过程涉及两个步骤:

  1. 一个sigmoid神经网络层计算出一组在0和1之间的值。这些值确定了输入信息的每个部分应该被允许通过的程度。0表示"不允许任何信息通过",而1则表示"允许所有信息通过"。

  2. sigmoid层的结果随后通过逐点乘法应用于信息 <math xmlns="http://www.w3.org/1998/Math/MathML"> c t − 1 c_{t-1} </math>ct−1。这一步实际上是缩放信息的组成部分,根据门控信号要么抑制它们,要么完全让它们通过。

在一个LSTM网络中,通常有三个这样的门:

  • 遗忘门决定细胞状态的哪些部分不再需要并且可以被抛弃。
  • 输入门控制新信息加入到细胞状态中。
  • 输出门确定当前细胞状态的哪一部分应该被用来计算输出。

而LSTM的计算过程为:输入信息先通过遗忘门和输入门,而后计算出的信息用于状态更新,而后将所有信息合并在一起计算出输出。接下来就对于各个部分进行讲解。

遗忘门

LSTM的首要步骤是确定从细胞状态中要舍弃哪些信息。这个选择是通过一个名为"遗忘门层"的sigmoid层来实施的,它检视 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t − 1 h_{t−1} </math>ht−1 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_{t} </math>xt 的信息,并为细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t−1} </math>Ct−1 中每个数分配一个0到1之间的值。1代表"完全保留",0则代表"彻底抛弃"。

简单地说,遗忘门可以类比为一个信息的过滤器,它决定哪些旧信息是过时的,而后在更新状态时将其抛弃。

举个例子,假设我们的语言模型要基于之前输入的所有话预测一句话中的下一个词,细胞状态可能包含当前主语的性别信息,以确保使用正确的代词。当出现新主语时,我们需要先忘记曾经记忆的旧主语的性别信息。

如上图所示,遗忘门中的运算就是:将当前的输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 和上一次的输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t − 1 h_{t−1} </math>ht−1 这两个向量拼接在一起,经过一个线性层后,再经过一个sigmoid运算,得到一个权重值。这个权重确定旧序列的历史信息中哪些信息需要去除。

输入门

随后的步骤是决定我们将在细胞状态中保存哪些新信息。其步骤如下:

  1. 由"输入门层"的sigmoid层确定哪些值需要更新,计算出 <math xmlns="http://www.w3.org/1998/Math/MathML"> i t i_t </math>it;
  2. 通过tanh层生成新的可能加入状态的候选值向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> C ^ t \hat C_t </math>C^t。
  3. 接下来,我们会合并这两部分信息来更新状态。

输入门与候选值向量的协同工作可以被看作是一个决策过程,确定新的、有价值的信息如何整合到细胞状态中。这个过程不仅仅是简单地添加信息,而是根据当前的输入和过去的状态精细地调整和选择信息。

同样以词语预测模型为例,我们需要在这一步计算出了需要添加到细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t-1} </math>Ct−1 的新主语的性别信息。

状态更新

细胞状态的更新是LSTM的核心,使其有能力在长时间跨度内保持信息。此过程确保了网络能够记住重要的信息并遗忘不必要的信息,从而解决了传统RNNs中的长序列遗忘问题。

接下来,我们将过时的细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t−1} </math>Ct−1 更新为新的细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t C_t </math>Ct。先前的步骤已经计算出了更新所需要的向量。我们通过 <math xmlns="http://www.w3.org/1998/Math/MathML"> f t f_t </math>ft 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t-1} </math>Ct−1 进行向量点积来缩减旧状态,放弃之前决定不再保留的信息。然后加入 <math xmlns="http://www.w3.org/1998/Math/MathML"> i t ∗ C ^ t i_t∗ \hat C_t </math>it∗C^t------将新候选值与状态更新向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> i t i_t </math>it 相乘。

在语言模型的情境中,这一步骤对应我们遗忘旧主语性别信息并添加新主语性别信息过程。

输出门

最后,我们需要确定输出内容 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht。这个输出所需要的输入包括细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t-1} </math>Ct−1 和利用输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 和上一轮输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t − 1 h_{t-1} </math>ht−1 所计算出的 <math xmlns="http://www.w3.org/1998/Math/MathML"> a t a_t </math>at,历史信息 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t-1} </math>Ct−1 经过过滤后和当前信息融合,得到输出,其步骤包括:

  1. 通过一个sigmoid层确定细胞状态中哪些部分将被输出。
  2. 将细胞状态通过tanh处理(使值位于-1至1之间)并乘以sigmoid门的输出。从而得到最终输出。

LSTM变体

迄今为止讲的都是是标准的LSTM模型。然而,在LSTM模型的基础上衍生了各种变体。

  • Gers和Schmidhuber在2000年提出了一个流行的LSTM变种,即增加了"窥视孔连接"。这样做允许门层能够直接观察细胞状态。如下图所示:其将历史信息/之前的细胞状态也加入了运算之中。
  • 有一个变体耦合遗忘门和输入门。该篇论文认为与其分开决定哪些信息该遗忘,哪些新信息该加入,不如将这两个决策过程合二为一。我们仅在有新信息加入时才遗忘旧信息,仅在丢弃旧信息时才加入新信息。
  • Cho等人在2014年提出GRU(门控循环单元)。它将遗忘门和输入门合二为一,称为"更新门"。同时,它还将细胞状态和隐藏状态合并,并作出了其他一些更改。最终得到比标准的LSTM模型简单的模型。

不过,Greff等人在2015年的研究中对流行的几种LSTM的变体进行了比较,结果表明它们基本上效果相同。

相关推荐
AI量化投资实验室25 分钟前
deap系统重构,再新增一个新的因子,年化39.1%,卡玛提升至2.76(附python代码)
大数据·人工智能·重构
张登杰踩33 分钟前
如何快速下载Huggingface上的超大模型,不用梯子,以Deepseek-R1为例子
人工智能
AIGC大时代34 分钟前
分享14分数据分析相关ChatGPT提示词
人工智能·chatgpt·数据分析
TMT星球1 小时前
生数科技携手央视新闻《文博日历》,推动AI视频技术的创新应用
大数据·人工智能·科技
AI视觉网奇1 小时前
图生3d算法学习笔记
人工智能
天乐敲代码1 小时前
JAVASE入门九脚-集合框架ArrayList,LinkedList,HashSet,TreeSet,迭代
java·开发语言·算法
十年一梦实验室1 小时前
【Eigen教程】矩阵、数组和向量类(二)
线性代数·算法·矩阵
Kent_J_Truman2 小时前
【子矩阵——优先队列】
算法
小锋学长生活大爆炸2 小时前
【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式
人工智能·pytorch·深度学习·图神经网络·gnn·dgl
佛州小李哥2 小时前
在亚马逊云科技上用AI提示词优化功能写出漂亮提示词(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技