【现代深度学习技术】现代循环神经网络03:深度循环神经网络

【作者主页】Francek Chen

【专栏介绍】⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。

【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning

文章目录


到目前为止,我们只讨论了具有一个单向隐藏层的循环神经网络。其中,隐变量和观测值与具体的函数形式的交互方式是相当随意的。只要交互类型建模具有足够的灵活性,这就不是一个大问题。然而,对一个单层来说,这可能具有相当的挑战性。之前在线性模型中,我们通过添加更多的层来解决这个问题。而在循环神经网络中,我们首先需要确定如何添加更多的层,以及在哪里添加额外的非线性,因此这个问题有点棘手。

事实上,我们可以将多层循环神经网络堆叠在一起,通过对几个简单层的组合,产生了一个灵活的机制。特别是,数据可能与不同层的堆叠有关。例如,我们可能希望保持有关金融市场状况(熊市或牛市)的宏观数据可用,而微观数据只记录较短期的时间动态。

图1描述了一个具有 L L L个隐藏层的深度循环神经网络,每个隐状态都连续地传递到当前层的下一个时间步和下一层的当前时间步。


图1 深度循环神经网络结构

一、函数依赖关系

我们可以将深度架构中的函数依赖关系形式化,这个架构是由图1中描述了 L L L个隐藏层构成。后续的讨论主要集中在经典的循环神经网络模型上,但是这些讨论也适应于其他序列模型。

假设在时间步 t t t有一个小批量的输入数据 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} Xt∈Rn×d(样本数: n n n,每个样本中的输入数: d d d)。同时,将 l t h l^\mathrm{th} lth隐藏层( l = 1 , ... , L l=1,\ldots,L l=1,...,L)的隐状态设为 H t ( l ) ∈ R n × h \mathbf{H}_t^{(l)} \in \mathbb{R}^{n \times h} Ht(l)∈Rn×h(隐藏单元数: h h h),输出层变量设为 O t ∈ R n × q \mathbf{O}_t \in \mathbb{R}^{n \times q} Ot∈Rn×q(输出数: q q q)。设置 H t ( 0 ) = X t \mathbf{H}t^{(0)} = \mathbf{X}t Ht(0)=Xt,第 l l l个隐藏层的隐状态使用激活函数 ϕ l \phi_l ϕl,则
H t ( l ) = ϕ l ( H t ( l − 1 ) W x h ( l ) + H t − 1 ( l ) W h h ( l ) + b h ( l ) ) (1) \mathbf{H}t^{(l)} = \phi_l(\mathbf{H}t^{(l-1)} \mathbf{W}{xh}^{(l)} + \mathbf{H}{t-1}^{(l)} \mathbf{W}
{hh}^{(l)} + \mathbf{b}h^{(l)}) \tag{1} Ht(l)=ϕl(Ht(l−1)Wxh(l)+Ht−1(l)Whh(l)+bh(l))(1) 其中,权重 W x h ( l ) ∈ R h × h \mathbf{W}{xh}^{(l)} \in \mathbb{R}^{h \times h} Wxh(l)∈Rh×h, W h h ( l ) ∈ R h × h \mathbf{W}
{hh}^{(l)} \in \mathbb{R}^{h \times h} Whh(l)∈Rh×h和偏置 b h ( l ) ∈ R 1 × h \mathbf{b}_h^{(l)} \in \mathbb{R}^{1 \times h} bh(l)∈R1×h都是第 l l l个隐藏层的模型参数。

最后,输出层的计算仅基于第 l l l个隐藏层最终的隐状态:
O t = H t ( L ) W h q + b q (2) \mathbf{O}_t = \mathbf{H}t^{(L)} \mathbf{W}{hq} + \mathbf{b}q \tag{2} Ot=Ht(L)Whq+bq(2) 其中,权重 W h q ∈ R h × q \mathbf{W}{hq} \in \mathbb{R}^{h \times q} Whq∈Rh×q和偏置 b q ∈ R 1 × q \mathbf{b}_q \in \mathbb{R}^{1 \times q} bq∈R1×q都是输出层的模型参数。

与多层感知机一样,隐藏层数目 L L L和隐藏单元数目 h h h都是超参数。也就是说,它们可以由我们调整的。另外,用门控循环单元或长短期记忆网络的隐状态来代替式(1)中的隐状态进行计算,可以很容易地得到深度门控循环神经网络或深度长短期记忆神经网络。

二、简洁实现

实现多层循环神经网络所需的许多逻辑细节在高级API中都是现成的。简单起见,我们仅示范使用此类内置函数的实现方式。以长短期记忆网络模型为例,该代码与之前在长短期记忆网络(LSTM)中使用的代码非常相似,实际上唯一的区别是我们指定了层的数量,而不是使用单一层这个默认值。像往常一样,我们从加载数据集开始。

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

选择超参数这类架构决策也跟长短期记忆网络(LSTM)中的决策非常相似。因为我们有不同的词元,所以输入和输出都选择相同数量,即vocab_size。隐藏单元的数量仍然是 256 256 256。唯一的区别是,我们现在通过num_layers的值来设定隐藏层数。

python 复制代码
vocab_size, num_hiddens, num_layers = len(vocab), 256, 2
num_inputs = vocab_size
device = d2l.try_gpu()
lstm_layer = nn.LSTM(num_inputs, num_hiddens, num_layers)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)

三、训练与预测

由于使用了长短期记忆网络模型来实例化两个层,因此训练速度被大大降低了。

python 复制代码
num_epochs, lr = 500, 2
d2l.train_ch8(model, train_iter, vocab, lr*1.0, num_epochs, device)


小结

  • 在深度循环神经网络中,隐状态的信息被传递到当前层的下一时间步和下一层的当前时间步。
  • 有许多不同风格的深度循环神经网络,如长短期记忆网络、门控循环单元、或经典循环神经网络。这些模型在深度学习框架的高级API中都有涵盖。
  • 总体而言,深度循环神经网络需要大量的调参(如学习率和截断)来确保合适的收敛,模型的初始化也需要谨慎。
相关推荐
Elastic 中国社区官方博客3 分钟前
Elastic Platform 8.18 和 9.0:ES|QL Lookup Joins 功能现已推出,Lucene 10!
大数据·人工智能·sql·elasticsearch·搜索引擎·全文检索·lucene
Tech Synapse21 分钟前
树莓派智能摄像头实战指南:基于TensorFlow Lite的端到端AI部署
人工智能·python·tensorflow·mobilenetv2·tensorflow lite
点云SLAM32 分钟前
张正友相机标定算法(Zhang’s camera calibration method)原理和过程推导(附OpenCV代码示例)
人工智能·opencv·计算机视觉·相机标定·张正友相机标定算法·内外参标定
是Yu欸32 分钟前
阿里云 OpenManus 实战:高效AI协作体系
人工智能·阿里云·langchain·prompt·aigc·ai写作·openmanus
闭月之泪舞34 分钟前
深度学习-神经网络参数优化的约束与迭代策略
人工智能·深度学习·神经网络
北京阿尔泰科技厂家35 分钟前
PXI总线开关卡80个交叉点组成的中密度 PXI矩阵开关模块
人工智能·科技·工业自动化·矩阵开关·pxi总线
哈全网络1 小时前
如何使用 DeepSeek 帮助自己的工作?
人工智能·算法·ai编程·ai写作
资讯分享周1 小时前
TCL中环深化全球布局,技术迭代应对行业调整
大数据·人工智能
心灵彼岸-诗和远方1 小时前
走进AI的奇妙世界:探索历史、革命与未来机遇
人工智能
程序员阿龙1 小时前
【精选】基于数据挖掘的广州招聘可视化分析系统(大数据组件+Spark+Hive+MySQL+AI智能云+DeepSeek人工智能+深度学习之LSTM算法)
大数据·人工智能·hadoop·数据挖掘·spark·数据分析与可视化·用户兴趣分析