pytorch-LSTM

目录

  • [1. RNN存在的问题](#1. RNN存在的问题)
  • [2. LSTM的由来](#2. LSTM的由来)
  • [3. LSTM门](#3. LSTM门)
    • [3.1 遗忘门](#3.1 遗忘门)
    • [3.2 输入门](#3.2 输入门)
    • [3.3 输出门](#3.3 输出门)
  • [4. LSTM是如何减轻梯度弥散问题](#4. LSTM是如何减轻梯度弥散问题)

1. RNN存在的问题

如下图:RNN能满足预测下一个单词,但是对于获取更多的上下文信息就做不到了。

2. LSTM的由来

RNN能做到短时记忆即short time memory,而LSTM相对RNN能够处理更长的时间序列,因此被称为LSTM即long short time memory

RNN有一串重复的模块,这些模块使用统一的权重Whh和Wih

LSTM也有一连串的类似结构,但是重复模块是不同的结构,它用四个单层的神经网络替代,并以指定的方式相互作用。它有三个门,分别是遗忘门、输入门和输出门。

3. LSTM门

门是一种信息过滤方式,他们由sigmod函数和点乘操作组成,sigmod范围是0~1,因此通过sigmod函数可以控制输出。

3.1 遗忘门

遗忘门f~t~是h~t-1~和x~t~经过一系列运算,再经过sigmod函数得到的

3.2 输入门

输入门由两部分组成,一个是i~t~输入门层,它是通过h~t-1~和x~t~经过一系列运算,再经过sigmod函数得到的。

另一个是新的输入C~t~^'^,这里没有直接使用x~t~作为输入,而是通过h~t-1~和x~t~经过一系列运算,再经过tanh函数得到新的输入C~t~^'^。

最后输出C~t~ = f~t~*C~t-1~ + i~t~*C~t~^'^

3.3 输出门

输出门o~t~也是通过h~t-1~和x~t~经过一系列运算,再经过sigmod函数得到的。

最后的输出h~t~ = o~t~*tanh(C~t~)
注意:LSTM中h~t~已经不是memory了,而是输出,C~t~才是memory

可以看出每个门的运算都与h~t-1~和x~t~相关,并且通过sigmod函数来控制门的开度,最后的输出h~t~使用了tanh

输入们和遗忘门门的组合,会得到不同的值,如下图:

4. LSTM是如何减轻梯度弥散问题

从梯度计算公式可以知道,RNN的梯度中有W~hh~的累乘,当W~hh~<1时,就可能出现梯度弥散,而LSTM梯度由几项累加得到,即使W很小也很难出现梯度弥散。

相关推荐
Shy9604189 小时前
Pytorch实现transformer语言模型
人工智能·pytorch
shuyeah10 小时前
LSTM结构原理
人工智能·rnn·lstm
夜猫程序猿10 小时前
RNN中的梯度消失与梯度爆炸问题
rnn·深度学习
YRr YRr11 小时前
如何解决RNN在处理深层序列数据时遇到的如梯度消失、长期以来等问题
人工智能·rnn·lstm
周末不下雨17 小时前
跟着小土堆学习pytorch(六)——神经网络的基本骨架(nn.model)
pytorch·神经网络·学习
蜡笔小新星1 天前
针对初学者的PyTorch项目推荐
开发语言·人工智能·pytorch·经验分享·python·深度学习·学习
矩阵猫咪1 天前
【深度学习】时间序列预测、分类、异常检测、概率预测项目实战案例
人工智能·pytorch·深度学习·神经网络·机器学习·transformer·时间序列预测
zs1996_1 天前
深度学习注意力机制类型总结&pytorch实现代码
人工智能·pytorch·深度学习
阿亨仔1 天前
Pytorch猴痘病识别
人工智能·pytorch·python·深度学习·算法·机器学习
AI视觉网奇1 天前
nvlink 训练笔记
pytorch·笔记·深度学习