【课程总结】Day17(中):LSTM及GRU模型简介

前言

在上一章【课程总结】Day17(上):NLP自然语言处理及RNN网络我们初步了解RNN的基本概念和原理。本章内容,我们将继续了解RNN的变种模型,如LSTM和GRU。

RNN发展历史

早期发展

  • 1980年代:RNN 的概念最早由 David Rumelhart 和 Geoffrey Hinton 提出。早期的 RNN 结构相对简单,主要用于字符级别的序列处理。
  • 1986年:Rumelhart 等人提出的反向传播算法(Backpropagation Through Time, BPTT)使得 RNN 的训练变得可行。

问题与挑战

  • 1990年代:随着 RNN 的应用增多,研究人员发现传统 RNN 在处理长序列时面临梯度消失和梯度爆炸的问题,限制了其在实际应用中的有效性。

LSTM 和 GRU 的提出

  • 1997年:Sepp Hochreiter 和 Jürgen Schmidhuber 提出了长短期记忆网络(LSTM),引入了门控机制,有效解决了传统 RNN 的梯度消失问题。
  • 2014年:门控循环单元(GRU)被提出,作为 LSTM 的简化版本,进一步提高了 RNN 的性能和训练效率。

Transformer 的出现

  • 2017年:Vaswani 等人提出了 Transformer 模型,完全基于自注意力机制,摆脱了 RNN 的结构限制。Transformer 能够并行处理序列数据,显著提高了训练速度和效率。

LSTM模型

LSTM模型介绍

长短期记忆网络(LSTM)是一种特殊类型的递归神经网络(RNN),旨在解决传统 RNN 在处理长序列时常见的梯度消失和梯度爆炸问题。LSTM 通过引入门控机制,能够有效地捕捉长期依赖关系。

LSTM模型结构

LSTM 单元主要由以下几个部分组成:

  • 第一部分:遗忘门(Forget Gate)
    • 作用:遗忘不想要的。输出在 0 到 1 之间的值,0 表示完全遗忘,1 表示完全保留。
  • 第二部分:输入门(Input Gate)
    • 作用:输出短期状态,决定当前输入 (x_t) 有多少信息被添加到单元状态中。
  • 第三部分:输出门(Output Gate)
    • 作用:决定当前单元状态 (C_t) 的多少信息将输出到隐状态。

LSTM的API使用

自动循环
python 复制代码
# 创建模型
lstm = nn.LSTM(input_size=128, hidden_size=256)

# 定义输入(待处理信息)
X = torch.randn(13, 2, 128)
# 初始的长期状态
c0 = torch.zeros(1, 2, 256, dtype=torch.float32)
# 初始的短期状态
h0 = torch.zeros(1, 2, 256, dtype=torch.float32)

# 执行循环
out, (hn, cn) = lstm(X, (h0, c0))
out.shape, hn.shape, cn.shape
# 输出结果
# (torch.Size([13, 2, 256]), torch.Size([1, 2, 256]), torch.Size([1, 2, 256]))
手写循环
python 复制代码
# 创建手动循环
lstm_cell = nn.LSTMCell(input_size=128, hidden_size=256)

# 定义输入(待处理信息)
X = torch.randn(13, 2, 128)
c0 = torch.zeros(2, 256, dtype=torch.float32)
h0 = torch.zeros(2, 256, dtype=torch.float32)

out = []
# 循环处理
for x in X:
    h0, c0 = lstm_cell(x, (h0, c0))
    out.append(h0)
# 将结果堆叠
out = torch.stack(out)
len(out), out.shape
# 输出结果
# (13, torch.Size([13, 2, 256]))

GRU模型

GRU模型介绍

门控循环单元(Gated Recurrent Unit, GRU)是一种改进的递归神经网络(RNN)架构,旨在解决传统 RNN 在处理长序列时遇到的梯度消失问题。GRU 于 2014 年由 Kyunghyun Cho 等人提出,作为 LSTM 的一种简化版本,具有更少的参数和更高的计算效率。

GRU模型结构

GRU 主要由以下两个门控机制组成:

  • 更新门(Update Gate)
    • 作用:决定当前隐状态 (h_t) 中保留多少前一个隐状态 (h_{t-1}) 的信息。
  • 重置门(Reset Gate)
    • 作用:决定在计算当前隐状态时,前一个隐状态的信息被遗忘的程度。

GRU的API使用

自动循环
python 复制代码
# 创建模型
gru = nn.GRU(input_size=128, hidden_size=256)

# 定义输入(待处理信息)
X = torch.randn(13, 2, 128)
# 初始的短期状态
h0 = torch.zeros(1, 2, 256, dtype=torch.float32)

# 执行循环
output, hn = gru(X, h0)
output.shape, hn.shape
# 输出结果
# (torch.Size([13, 2, 256]), torch.Size([1, 2, 256]))
手写循环
python 复制代码
gru_cell = nn.GRUCell(input_size=128, hidden_size=256)

X = torch.randn(13, 2, 128)
h0 = torch.zeros(2, 256, dtype=torch.float32)

out = []
for x in X:
    h0 = gru_cell(x, h0)
    out.append(h0)
out = torch.stack(out)
len(out), out.shape

RNN、LSTM、GRU 的优势与劣势

RNN优势与劣势

优势

  1. 简单性:RNN 结构相对简单,易于理解和实现。
  2. 适合序列数据 :RNN 能够处理任意长度的序列数据(例如:股票数据、销售数据、自然语言、语音信号等),实现了统计机器学习深度学习 的过渡。

劣势

  1. 梯度消失与爆炸:传统 RNN 在处理长序列时,容易出现梯度消失或梯度爆炸的问题,导致训练困难。
  2. 训练时间较长:由于序列数据的特性,RNN 的训练时间通常较长,尤其是在长序列上。
  3. 并行化困难:RNN 的计算依赖于前一个时间步的输出,导致其在训练和推理时难以进行并行化,效率较低,无法有效利用硬件加速。

LSTM优势与劣势

优势

  1. 解决梯度消失问题:通过引入遗忘门、输入门和输出门,有效缓解了梯度消失问题。
  2. 捕捉长距离依赖:能够更好地捕捉长距离依赖关系,适用于长序列数据。

劣势

  1. 复杂性:结构较为复杂,参数较多,训练和调优的难度增加。
  2. 计算开销:由于门控机制的引入,计算成本较高,训练速度相对较慢。
  3. 过拟合风险:参数较多可能导致过拟合,尤其是在数据量不足的情况下。

GRU优势与劣势

优势

  1. 简化结构:相较于 LSTM,GRU 只有两个门,结构更简单,参数更少。
  2. 高效训练:由于参数较少,训练速度通常更快,适合资源有限的环境。
  3. 良好性能:在许多任务中,GRU 的表现与 LSTM 相当,有时甚至更好。

劣势

  1. 灵活性不足:虽然 GRU 在许多任务中表现良好,但在某些特定任务上,LSTM 可能会表现得更好。
  2. 可解释性问题:与其他深度学习模型一样,GRU 的内部机制较难以解释,可能导致模型的可解释性问题。

内容小结

  • LSTM 和 GRU 都是 RNN 的变种,它们都旨在解决传统 RNN 在处理长序列时遇到的梯度消失问题。
  • LSTM 通过引入遗忘门(Forget Gate)和输入门(Input Gate)来控制短期记忆和长期记忆,能够有效地捕捉长期依赖关系。
  • GRU 通过引入更新门(Update Gate)和重置门(Reset Gate)来控制短期记忆和长期记忆,能够有效地捕捉长期依赖关系。
  • LSTM 和 GRU 的API使用方式都较为相似,都是通过调用 nn.LSTM()nn.GRU() 函数创建模型,他们与RNN类似都有自动循环和手写循环两种方式。
  • RNN的优势是简单,易于理解和实现;劣势是梯度消失和训练时间较长。
  • LSTM的优势是能够捕捉长距离依赖关系,缺点是参数较多,训练和调优的难度增加。
  • GRU的优势是简化结构,缺点是参数较多,训练和调优的难度增加。

参考资料

一文搞懂 LSTM(长短期记忆网络)

相关推荐
肥猪猪爸31 分钟前
使用卡尔曼滤波器估计pybullet中的机器人位置
数据结构·人工智能·python·算法·机器人·卡尔曼滤波·pybullet
LZXCyrus1 小时前
【杂记】vLLM如何指定GPU单卡/多卡离线推理
人工智能·经验分享·python·深度学习·语言模型·llm·vllm
我感觉。1 小时前
【机器学习chp4】特征工程
人工智能·机器学习·主成分分析·特征工程
YRr YRr1 小时前
深度学习神经网络中的优化器的使用
人工智能·深度学习·神经网络
DieYoung_Alive1 小时前
一篇文章了解机器学习(下)
人工智能·机器学习
夏沫的梦1 小时前
生成式AI对产业的影响与冲击
人工智能·aigc
goomind2 小时前
YOLOv8实战木材缺陷识别
人工智能·yolo·目标检测·缺陷检测·pyqt5·木材缺陷识别
只怕自己不够好2 小时前
《OpenCV 图像基础操作全解析:从读取到像素处理与 ROI 应用》
人工智能·opencv·计算机视觉
幻风_huanfeng2 小时前
人工智能之数学基础:线性代数在人工智能中的地位
人工智能·深度学习·神经网络·线性代数·机器学习·自然语言处理
嵌入式大圣2 小时前
嵌入式系统与OpenCV
人工智能·opencv·计算机视觉