LSTM和GRU的介绍以及Pytorch源码解析

介绍一下LSTM模型的结构以及源码,用作自己复习的材料。

LSTM模型所对应的源码在:\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。

上次上一篇文章介绍了RNN序列模型,但是RNN模型存在比较严重的梯度爆炸和梯度消失问题。

本文介绍的LSTM模型解决的RNN的大部分缺陷。

首先展示LSTM的模型框架:

下面是LSTM模型的数学推导公式:

表示时刻的隐藏状态,表示时刻的记忆细胞状态,表示时刻的输入,表示在时间的隐藏状态或在时间的初始隐藏状态。

分别是输入门、遗忘门、单元门和输出门。

这张图片比较好的介绍了各个门之间的交互关系以及输入输出,大家可以放大看一下。

接下来展示GRU的框架模型:

下面是GRU的数学推导公式:

表示时刻的隐藏状态,表示时刻的输入,表示在时间的隐藏状态或在时间的初始隐藏状态。分别表示重置门更新门和新建门

上面的图片可以更直观的看到GRU中是如何迭代的。

接下来我们看一下源码中LSTM和GRU类的初始化(只介绍几个重要的参数):

复制代码
torch.nn.LSTM(self, input_size, hidden_size, num_layers=1,
              bias=True, batch_first=False, dropout=0.0, 
              bidirectional=False, proj_size=0, device=None,
              dtype=None)

torch.nn.GRU(self, input_size, hidden_size, num_layers=1,
             bias=True, batch_first=False, dropout=0.0, 
             bidirectional=False, device=None, dtype=None)
  • input_size:输入数据中的特征数(可以理解为嵌入维度 embedding_dim)。
  • hidden_size:处于隐藏状态 h 的特征数(可以理解为输出的特征维度)。
  • num_layers:代表着RNN的层数,默认是1(层),当该参数大于零时,又称为多层RNN。
  • bidirectional:即是否启用双向LSTM(GRU),默认关闭。

LSTM与GRU都是特殊的RNN,因此输入输出可以参考的上一篇介绍RNN的文章,在这里直接进行代码举例。

复制代码
lstm1 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=True)
lstm2 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=False)

gru1 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=True)
gru2 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=False)

tensor1 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)
tensor2 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)

out_lstm1,(hn, cn) = lstm1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_lstm2,(hn, cn) = lstm2(tensor2)  # (batch_size * seq_len * (hidden_size * bidirectional))

out_gru1,h_n = gru1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_gru2,h_n = gru2(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))

print(out_lstm1.shape)  # torch.Size([5, 10, 80])
print(out_lstm2.shape)  # torch.Size([5, 10, 40])

print(out_gru1.shape)  # torch.Size([5, 10, 50])
print(out_gru2.shape)  # torch.Size([5, 10, 25])

维度已经在注释中给大家标注上了!

相关推荐
胡耀超1 分钟前
大模型架构演进全景:从Transformer到下一代智能系统的技术路径(MoE、Mamba/SSM、混合架构)
人工智能·深度学习·ai·架构·大模型·transformer·技术趋势分析
小杨勇敢飞1 小时前
UNBIASED WATERMARK:大语言模型的无偏差水印
人工智能·语言模型·自然语言处理
Luchang-Li1 小时前
sglang pytorch NCCL hang分析
pytorch·python·nccl
m0_603888711 小时前
Delta Activations A Representation for Finetuned Large Language Models
人工智能·ai·语言模型·自然语言处理·论文速览
金融小师妹1 小时前
基于哈塞特独立性表态的AI量化研究:美联储政策独立性的多维验证
大数据·人工智能·算法
qinyia2 小时前
Wisdom SSH 是一款创新性工具,通过集成 AI 助手,为服务器性能优化带来极大便利。
服务器·人工智能·ssh
昨日之日20064 小时前
Wan2.2-S2V - 音频驱动图像生成电影级质量的数字人视频 ComfyUI工作流 支持50系显卡 一键整合包下载
人工智能·音视频
SEO_juper7 小时前
大型语言模型SEO(LLM SEO)完全手册:驾驭搜索新范式
人工智能·语言模型·自然语言处理·chatgpt·llm·seo·数字营销
攻城狮7号8 小时前
腾讯混元翻译模型Hunyuan-MT-7B开源,先前拿了30个冠军
人工智能·hunyuan-mt-7b·腾讯混元翻译模型·30个冠军
zezexihaha8 小时前
从“帮写文案”到“管生活”:个人AI工具的边界在哪?
人工智能