机器学习详解(19):长短期记忆网络LSTM原理详解

在处理序列数据时,传统的神经网络往往难以捕捉长期依赖关系,这时循环神经网络(RNN)应运而生。然而,标准RNN在面对长序列时容易出现梯度消失或爆炸的问题,难以有效学习长期信息。为了解决这一难题,长短期记忆网络(Long Short-Term Memory,LSTM)被提出。LSTM 是一种特殊的 RNN 结构,它通过引入门控机制,有效保留关键信息并抑制无关内容,在自然语言处理、语音识别、时间序列预测等领域表现出色。

文章目录

  • [1 引入](#1 引入)
    • [1.1 RNN](#1.1 RNN)
    • [1.2 长期依赖问题](#1.2 长期依赖问题)
  • [2 LSTM网络](#2 LSTM网络)
  • [3 LSTM数学原理](#3 LSTM数学原理)
    • [3.1 核心思想](#3.1 核心思想)
    • [3.2 逐步解析](#3.2 逐步解析)
    • [3.3 LSTM的变体](#3.3 LSTM的变体)
  • [4 总结](#4 总结)

1 引入

1.1 RNN

在读文章时,我们会基于前面的内容来理解当前的词句,这种思维的延续性是人类理解的基础。而传统的神经网络并不具备这种能力,这限制了它们在处理序列数据(比如电影情节、语音、文本)时的表现。

为了解决这一问题,循环神经网络(Recurrent Neural Networks, RNN)被提出。RNN通过引入"循环结构",使得信息能够在时间步骤之间传递,从而保留上下文信息。虽然看起来神秘,但RNN其实可以理解为多个共享参数的神经网络副本,依次传递消息,适合处理序列和列表类型的数据。

如下图所示,一个神经网络单元 A A A接收某个输入 x t x_t xt并输出一个值 h t h_t ht。这个循环结构允许信息从网络的一个时间步传递到下一个时间步。

这其实与普通的神经网络并没有太大不同。一个循环神经网络可以被看作是同一个网络的多个副本,每个副本向下一个传递信息。思考一下当我们将这个循环展开时会发生什么:

  • x t x_t xt 表示在时间点 t t t 时刻输入给神经网络的一个数据(比如一句话里的一个词,一段语音的一帧等)。
  • A A A 是一个神经网络模块(可以理解为一层神经网络,它有自己的参数和计算方式)。
  • h t h_t ht 表示在时间点 t t t 时刻这个模块的输出,也可以理解为"这个时间点的理解/记忆"。
  • 黑色的弯曲箭头表示 输出 h t − 1 h_{t-1} ht−1 被反馈到下一次作为输入的一部分。

注意,在 RNN 中,前面的状态不是只影响下一步,而是会"间接影响所有后续步骤"。信息是逐步传递的,前面的状态对后面有长期影响。

正因为 RNN 是一步一步传信息,当序列很长时(比如 100 个词),最前面的 h 0 h_0 h0 影响 h 100 h_{100} h100 时可能就 "记不清了",我们把这个现象叫"长期依赖问题"(Long-Term Dependencies)。

1.2 长期依赖问题

RNN 的特性是它可能能够将之前的信息连接到当前任务中,例如使用前面的视频帧来帮助理解当前帧。但它真的能做到吗?这取决于情况。

有时候,我们只需要参考最近的信息就能完成当前任务。例如,考虑一个语言模型在生成句子时预测下一个词的情况。如果我们已经知道前面的部分是:"天空中飘着白云,太阳照耀着",那么即使只依靠这段近期上下文,我们也能较容易地判断下一个词可能是"很暖和"或"天气很好"等。这是因为上下文已经很清晰地暗示了接下来的内容。在这种情况下,相关信息与它被使用的位置之间的距离很短,RNN 通常能够学习到如何利用这些近期的信息。

但也有一些情况下,我们需要更多的上下文信息。比如一句话是:"我从小在法国长大......我说一口流利的______。" 仅仅根据最近的内容"我说一口流利的",我们可能猜测下一个词是一种语言,但如果想准确判断是哪种语言,就需要回到更早出现的"法国"这个信息。也就是说,真正关键的信息离它被使用的地方相隔较远。像这种相关信息与使用位置之间间隔较大的情况,就会让 RNN 很难学会正确地建立联系。

不幸的是,随着这个间隔变大,RNN 往往就无法学会如何连接这些信息。

从理论上讲,RNN 是具备处理"长期依赖"问题能力的。我们甚至可以通过人为地设置合适的参数,让它们在特定任务中取得不错的效果。但在实际应用中,RNN 往往难以自动学会这种长距离的信息关联。这种能力的缺失,被认为是 RNN 在处理复杂序列任务时的一个重要瓶颈。

好在,LSTM 网络很好地解决了这个问题!

2 LSTM网络

长短期记忆网络(Long Short Term Memory networks,通常简称为 "LSTM")是一种特殊类型的 RNN,能够学习长期依赖关系。LSTM 是为了解决长期依赖问题而专门设计的。对于它们来说,长期记忆是默认行为,而不是一种需要刻意学习的特性。

所有的循环神经网络本质上都是由一系列重复的神经网络模块组成的链式结构。在标准的 RNN 中,这种重复模块的结构非常简单,通常只有一层,例如一个 tanh ⁡ \tanh tanh 层。

  • tanh 层是 RNN 中常用的激活函数,用于生成隐藏状态,使模型具有非线性表达能力,并保持数值稳定。

LSTM 也具有这种链式结构,但其重复模块的结构更加复杂。与其只有一层神经网络不同,LSTM 模块中包含了四个子结构,它们之间以一种非常特别的方式进行交互。

在讲解 LSTM 的结构之前,我们先熟悉一下这里面的符号。

在上图中,每一条线代表一个完整的向量,从一个节点的输出传递到其他节点的输入。粉红色圆圈表示逐点操作(pointwise operation),例如向量加法;黄色方块表示神经网络中的学习层。直线表示数据流传递,分叉线表示内容被复制后分别传递到多个位置。

3 LSTM数学原理

3.1 核心思想

LSTM 的关键在于"单元状态"(cell state),也就是图中贯穿顶部的那条横线。

这个单元状态有点像一条传送带。它沿着整个链条一直向前传递,只在某些位置进行一些轻微的线性交互。信息可以非常容易地沿着这条路径几乎不变地流动。

LSTM 具有在单元状态中添加或移除信息 的能力,这一过程由一些被称为"门"(gates)的结构精细地调控。

门的作用是有选择地允许信息通过。它们由一个 sigmoid 神经网络层和一个逐点(pointwise)乘法操作组成。

sigmoid 层的输出值在 0 到 1 之间,表示每个成分通过的"程度"。输出为 0 意味着"完全阻止信息通过",输出为 1 则表示"让所有信息通过"。一个 LSTM 有三个这样的门,用来控制、保护和更新单元状态。

3.2 逐步解析

我们在 LSTM 中的第一步,是决定从单元状态中要丢弃哪些信息。这个决策是由一个被称为"遗忘门层"(forget gate layer)的 sigmoid 层完成的。它会查看 h t − 1 h_{t-1} ht−1(h为hidden state, 这里表示历史记忆) 和 x t x_t xt(新观察到的信息),并为单元状态 C t − 1 C_{t-1} Ct−1 中的每一个数输出一个 0 到 1 之间的值。值为 1 表示"完全保留这个信息",而值为 0 表示"完全丢弃这个信息"。

回到前面语言模型的例子,它需要根据已有的上下文来预测下一个词。在中文语境下,假设模型正在处理句子"这个男孩跑得很快,他......",那么单元状态中可能存储了"男孩"这个主语的信息,以便后续正确使用"他"这个代词。当接下来出现一个新主语,比如"女孩",我们就希望能够忘记"男孩"相关的信息,转而关注新的上下文内容。

接下来我们要决定将哪些新信息存入单元状态。这个过程包含两个部分:首先,一个称为"输入门层"的 sigmoid 层决定我们将更新哪些值;然后,一个 tanh 层会生成一组新的候选值 C ~ t \tilde{C}_t C~t,这些值可以被加入到状态中。在下一步中,我们会将这两个部分组合起来,对单元状态进行更新。

在语言模型的例子中,我们可能想把新主语的性别信息加入到单元状态中,以替换之前被遗忘的旧信息。

现在我们该把旧的单元状态 C t − 1 C_{t-1} Ct−1 更新为新的状态 C t C_t Ct 了。前面的步骤已经决定好了该做什么,我们现在只是实际执行它。

我们用 f t f_t ft 去乘以旧的状态,表示我们要忘记的那部分信息。接着我们加上 i t ⋅ C ~ t i_t \cdot \tilde{C}_t it⋅C~t,这部分是我们要加入的新候选值,它的每个分量都经过了输入门的调节。

在语言模型的例子中,这一步就是真正抛弃旧主语性别信息、并引入新主语性别信息的时刻。

最后一步是决定我们要输出什么。这个输出基于我们更新后的单元状态,但会经过一个"过滤"过程。首先,我们通过一个 sigmoid 层决定单元状态的哪些部分将被输出。然后,我们将新的单元状态通过 tanh,使其值压缩到 [ − 1 , 1 ] [-1, 1] [−1,1] 区间,并将其与 sigmoid 门的输出相乘,从而仅输出我们希望保留的部分。

在语言模型中,比如它刚刚看到一个主语,它可能希望输出一些与动词搭配相关的信息,例如主语是"他说"还是"他们说",这样模型就能正确生成后续的动词或语气助词,确保句子语法通顺。

个人理解

你可以把 LSTM 想象成你在对话时的大脑:

  • C t − 1 C_{t-1} Ct−1 就像是你上一时刻脑海中积累的长期记忆(比如你记得朋友昨天说过的话)。
  • h t − 1 h_{t-1} ht−1 是你上一次说出来的内容,是你短期表达的结果。
  • x t x_t xt 是你现在看到或听到的新信息,比如你此刻听到朋友问了你一个问题。

现在,你需要结合你听到的问题( x t x_t xt)、你刚刚说过什么( h t − 1 h_{t-1} ht−1)、还有你脑海里已有的知识( C t − 1 C_{t-1} Ct−1)来判断:

  1. 你要不要遗忘一些旧信息 ?(由遗忘门 f t f_t ft 决定,可能觉得朋友昨天说的那句现在不重要了)
  2. 你要不要学习一些新信息 ?(由输入门 i t i_t it 决定,比如朋友现在说的这个新问题)
  3. 最终形成了新的"知识状态" C t C_t Ct,即你现在的大脑记忆更新版。
  4. 但你不会把脑海里所有的想法都说出来,这时候输出门 o t o_t ot 决定你要表达哪些内容
  5. 最终,你说出了 h t h_t ht,这是你对当前局面基于记忆和输入的回应。

所以:

  • C t C_t Ct 是你内心全部思考的结果(包括过去记忆和当前理解)
  • h t h_t ht 是你对外说出来的一句话
  • 而这一句话( h t h_t ht)是根据你听到的信息( x t x_t xt)、你记得的内容( C t − 1 C_{t-1} Ct−1)、你之前说过的话( h t − 1 h_{t-1} ht−1),经过大脑处理之后,由输出门 o t o_t ot 精选出来的内容。

LSTM 就像你的大脑一样,一边保留长期记忆,一边随时根据上下文决定该记住什么、忘掉什么、说出什么。

3.3 LSTM的变体

前面我们讲述的是一种比较标准的 LSTM,事实上,几乎每篇涉及 LSTM 的论文都使用了稍有不同的版本。虽然这些差异通常很小,但仍然值得提一提其中的一些变体。

其中一个常见的变体是窥视孔连接(peephole connections)。意思是,我们允许门控层查看单元状态 C t C_t Ct。

上图将窥视孔连接添加到了所有的门上,但实际上有些论文只会为部分门添加窥视孔。

另一种变体是将遗忘门和输入门结合起来使用。我们不再分别决定"遗忘什么"和"加入什么新信息",而是把两个决策放在一起进行:我们只在决定忘记某些内容时,才用新内容去替代它。

还有一种结构上的更大变化是门控循环单元(Gated Recurrent Unit,GRU)。GRU 将遗忘门和输入门合并为一个"更新门",并合并了单元状态和隐藏状态,同时做了一些其他调整。这个模型比标准 LSTM 更简单,并且越来越受欢迎。

上述只是最常见的一些 LSTM 变体。实际上还有很多其他变体,比如 Depth-Gated RNN、Clockwork RNN 等等。有些方法甚至完全改变了处理长期依赖的问题方式,比如 Clockwork RNN 就是通过不同时间尺度的机制来处理长依赖。

那么,这些变体中哪一个最好?它们之间的差异真的重要吗?有人对各种变体进行了全面比较,发现它们其实差别并不大。甚至在某些任务中,RNN一些结构甚至比 LSTM 更好用。

4 总结

虽然用公式描述时,LSTM 的结构可能看起来复杂甚至有些吓人,但只要我们一步步梳理其内部逻辑,就会发现它其实并不难理解。它通过门控机制实现了对记忆的"读、写、忘"操作,构建了一种既能保留上下文又能灵活应变的序列建模能力。

但LSTM 并不是终点,注意力机制(Attention)不再依赖固定的上下文状态,而是**"主动选择"**需要关注的信息来源。例如,在图像字幕生成任务中,模型在生成每一个词时,都能自动聚焦图像中不同的区域。这种自由灵活的对齐方式,在自然语言处理和多模态任务中都展现出巨大潜力。

随着 Transformer 的提出和普及,Attention 的重要性日益凸显,或许,理解 LSTM 正是你探索深度学习更远世界的第一步!

相关推荐
DUTBenjamin1 分钟前
计算机视觉6——相机基础
人工智能·计算机视觉
睡觉zzz14 分钟前
React写ai聊天对话,如何实现聊天makedown输出转化
前端·人工智能·react.js
deephub16 分钟前
PyTorch CUDA内存管理优化:深度理解GPU资源分配与缓存机制
人工智能·pytorch·python·深度学习·英伟达
闰土_RUNTU17 分钟前
机器学习中的数学(PartⅡ)——线性代数:2.2矩阵
线性代数·机器学习·矩阵
菜小包17 分钟前
通义万相 vs 豆包:AI领域文生图/文生视频全面对比
人工智能
Lx35218 分钟前
📌《AI生成代码的边界测试:哪些场景人类仍需主导》
人工智能
Qiu的博客27 分钟前
一文读懂 AI
人工智能·算法·开源
孔令飞28 分钟前
Go 1.24 新方法:编写性能测试用例方法 testing.B.Loop 介绍
人工智能·云原生·go
pen-ai31 分钟前
【NLP】18. Encoder 和 Decoder
人工智能·自然语言处理
Elastic 中国社区官方博客40 分钟前
Elasticsearch:使用稀疏向量提升相关性
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索