LSTM长短时记忆网络【数学+图解】

文章目录

🍃作者介绍:双非本科大三网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发、数据结构和算法,初步涉猎人工智能和前端开发。

🦅个人主页:@逐梦苍穹

📕所属专栏:人工智能

🌻gitee地址:xzl的人工智能代码仓库

✈ 您的一键三连,是我创作的最大动力🌹

1、简介

长短时记忆网络(LSTM)和门控循环单元(GRU)是循环神经网络(RNN)的两种改进变体。

它们通过引入 门控机制 解决了RNN在处理长序列时的梯度消失和梯度爆炸问题。

先复习一下RNN:https://xzl-tech.blog.csdn.net/article/details/140940642

有兴趣可以继续学习GRU:https://xzl-tech.blog.csdn.net/article/details/140940794

2、门控机制

  1. 门控机制的基本思想是使用"门"来控制信息在网络中的流动。
  2. 每个门都是通过 神经网络层计算出来的 权重向量 ,其值通常在 0到1之间
  3. 不同的门在不同 时间步 上控制信息的选择、遗忘和更新。
  4. 这些门是通过可学习的参数在训练过程中自动调整的。

3、LSTM

LSTM:Long Short-Term Memory

3.1、概念

LSTM是一种 特殊的RNN结构,它通过引入 记忆单元 门控机制来控制信息的流动,以此解决长时依赖问题。
LSTM网络包含一个称为
记忆单元
(cell state)的特殊单元,用于维护跨越时间步的长期信息

记忆单元通过三种 (门控机制)来控制信息的更新:

  1. 输入门(Input Gate):决定哪些新信息需要被写入记忆单元。
  2. 遗忘门(Forget Gate):决定哪些旧信息需要被从记忆单元中移除。
  3. 输出门(Output Gate):决定从记忆单元中输出哪些信息。

3.2、公式⭐

下文有图解,此处看不懂可以先跳过)

LSTM在每个时间步的更新过程可以用以下公式描述:

  1. 遗忘门 : f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
  • f t f_t ft 表示遗忘门的输出。
  • σ \sigma σ 是sigmoid激活函数,用于将输出值限制在0到1之间。
  1. 输入门 : i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
  • i t i_t it 表示输入门的输出。
  1. 候选记忆单元更新 : C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)
  • C ~ t \tilde{C}_t C~t 表示候选的记忆单元状态。
  1. 记忆单元更新 : C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t Ct=ft∗Ct−1+it∗C~t
  • C t C_t Ct 表示当前时间步的记忆单元状态。
  1. 输出门 : o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
  • o t o_t ot 表示输出门的输出。
  1. 隐藏状态更新 : h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t \ast \tanh(C_t) ht=ot∗tanh(Ct)
  • h t h_t ht 是当前时间步的隐藏状态。

3.3、特点

  • 有效捕获长时依赖:LSTM通过门控机制,有效地捕获序列数据中的长时依赖关系。
  • 复杂性:相对于标准RNN,LSTM的结构更为复杂,计算量也更大。

4、图解LSTM⭐

4.1、RNN

多维的角度:

二维的角度:

其实就是在原本的前馈神经网络中加入了时间的维度

4.2、时间链条

在原来的RNN的基础上,LSTM就是增加了一条时间链条 C t C_t Ct

连起来:

这个时间链条并不是跟 S t S_t St隐藏层同平面的,旋转一下即为:

4.3、记忆单元🔺

下面关于 S t S_t St和 C t C_t Ct的关系进行展开:

S t S_t St和 C t C_t Ct这条线展开平面为:

S t S_t St和 C t C_t Ct一条线拆成了三条线:

那么,关于 f 1 f_1 f1和 f 2 f_2 f2两个函数关系,
f 1 = σ ( W 1 ⋅ [ h t − 1 , x t ] + b 1 ) f_1 = \sigma(W_1 \cdot [h_{t-1}, x_t] + b_1) f1=σ(W1⋅[ht−1,xt]+b1)
i t = σ ( W 2 ⋅ [ h t − 1 , x t ] + b 2 ) i_t = \sigma(W_2 \cdot [h_{t-1}, x_t] + b_2) it=σ(W2⋅[ht−1,xt]+b2)
C ~ t = tanh ⁡ ( W ~ 2 ⋅ [ h t − 1 , x t ] + b ~ 2 ) \tilde{C}_t = \tanh(\tilde{W}2 \cdot [h{t-1}, x_t] + \tilde{b}_2) C~t=tanh(W~2⋅[ht−1,xt]+b~2)
f 2 = i t ∗ C ~ t f_2=i_t*\tilde{C}t f2=it∗C~t
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t \ast C
{t-1} + i_t \ast \tilde{C}_t Ct=ft∗Ct−1+it∗C~t

所以根据这张图,以及上面的公式,不难看出:

图中的"删除"就是遗忘门 f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf);

图中的"增加"就是输入门 i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)和候选记忆单元更新 C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)的乘积

4.4、LSTM

关于LSTM,有这么一张经典图:

这张图展示了LSTM单元的详细结构,包含了三个主要的门:遗忘门、输入门和输出门,以及记忆单元的更新过程。

从输入到输出,LSTM单元通过门控机制控制信息的流动,允许网络在长时间跨度内捕获依赖关系。

5、LSTM与GRU的对比

  1. 复杂性
    • LSTM更复杂,参数更多。
    • GRU较为简洁,参数更少,训练速度更快。
  2. 性能
    • 两者在处理长时依赖性任务时表现都很优异,具体选择往往取决于数据集和计算资源。
    • 在一些特定任务和数据集上,GRU可能比LSTM表现更好,尤其是在计算资源有限的情况下。
  3. 使用场景
    • 对于需要更强的长期记忆和复杂信息流动的任务,LSTM可能更合适。
    • 对于实时性要求较高或者模型简单性要求较高的任务,GRU可能更具优势。

LSTM和GRU是两种非常成功的RNN变体,通过改进信息传递机制,有效解决了传统RNN在处理长序列数据时的局限性。

它们在自然语言处理、语音识别和时间序列预测等领域得到广泛应用。

6、应用

RNN及其变体广泛应用于以下领域:

  • 自然语言处理:如语言模型、机器翻译和文本生成。
  • 语音识别:将音频序列转换为文本。
  • 时间序列预测:如股票价格预测和天气预报。
  • 视频分析:从视频帧中提取时间信息。

7、训练技巧

  • 梯度裁剪:限制梯度的大小以防止梯度爆炸。
  • 正则化:使用Dropout等技术防止过拟合。
  • 预训练和转移学习:利用大规模预训练模型微调特定任务。

RNN模型在序列数据处理中具有强大的表现力和适应能力,但也面临一些挑战。通过使用LSTM、GRU等改进模型,结合适当的训练技巧,能够有效地应用于各种实际问题。

相关推荐
抓个马尾女孩4 分钟前
MoCo对比损失
人工智能·深度学习
阿里云大数据AI技术26 分钟前
TAG:BladeLLM 的纯异步推理架构
人工智能·tag·llm推理
m0_603888711 小时前
什么是上采样什么是下采样
人工智能·深度学习·计算机视觉
TSINGSEE1 小时前
人员抽烟AI检测算法在智慧安防领域的创新应用,助力监控智能化
人工智能·算法·视频编解码·安防视频监控·视频监控管理平台
一枚游戏干饭人1 小时前
【运营攻略】怎样进行游戏产品的定位
人工智能·游戏·语音识别
Python极客之家1 小时前
基于机器学习的乳腺癌肿瘤智能分析预测系统
人工智能·python·机器学习·毕业设计·xgboost·可视化分析
嵌入式杂谈1 小时前
深入理解AI大模型:参数、Token、上下文窗口、上下文长度和温度
人工智能
范范08251 小时前
自然语言处理入门:从基础概念到实战项目
人工智能·自然语言处理
_feivirus_1 小时前
神经网络_使用TensorFlow预测气温
人工智能·神经网络·算法·tensorflow·预测气温
deflag2 小时前
第T1周:Tensorflow实现mnist手写数字识别
人工智能·python·机器学习·分类·tensorflow