一、循环神经网络(RNN)
1.1 基本原理
循环神经网络之所以得名,是因为它在处理序列数据时,隐藏层的节点之间存在循环连接。这意味着网络能够记住之前时间步的信息,并利用这些信息来处理当前的输入。 想象一下,我们正在处理一段文本,每个单词就是一个时间步的输入。RNN 在读取每个单词时,不仅会考虑当前单词的含义,还会结合之前已经读过的单词信息,从而更好地理解整个句子的语境。 用数学公式来表示,假设我们有一个输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 , x 2 , . . . , x T x_1,x_2,...,x_T </math>x1,x2,...,xT,在时间步t,RNN 的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_ t </math>ht的计算方式如下: <math xmlns="http://www.w3.org/1998/Math/MathML"> h t = σ ( W x h x t + W h h h t − 1 + b h ) h_t=σ(W_{xh}x_t+W_{hh}h_{t−1}+b_h) </math>ht=σ(Wxhxt+Whhht−1+bh) 其中,σ是激活函数(通常为 tanh 或 sigmoid), <math xmlns="http://www.w3.org/1998/Math/MathML"> W x h W_{xh} </math>Wxh是输入到隐藏层的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> W h h W_{hh} </math>Whh是隐藏层到隐藏层的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> b h b_h </math>bh是偏置项。输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> y t y_t </math>yt通常通过以下公式计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y t = W h y h t + b y y_t =W_{hy}h_t+b_y </math>yt=Whyht+by
这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> W h y W_{hy} </math>Why是隐藏层到输出层的权重矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> b y b_y </math>by是输出层的偏置项。
1.2 前向传播过程
以前文提到的文本处理为例,假设我们有一个简单的句子 "我喜欢深度学习",我们将每个单词通过词向量表示后作为输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt依次输入到 RNN 中。在第一个时间步,输入 "我" 对应的词向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 x_1 </math>x1,结合初始隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h 0 h_0 </math>h0(通常初始化为零向量),通过上述公式计算得到隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h 1 h_1 </math>h1。接着,输入 "喜欢" 对应的词向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 2 x_2 </math>x2,此时结合 <math xmlns="http://www.w3.org/1998/Math/MathML"> h 1 h_1 </math>h1计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> h 2 h_2 </math>h2,以此类推,直到处理完整个句子。最终的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h T h_T </math>hT可以用于预测句子的情感倾向(比如是积极还是消极)等任务。
1.3 训练过程
RNN 的训练通常使用反向传播通过时间(Backpropagation Through Time, BPTT)算法。BPTT 算法本质上是标准反向传播算法在时间序列上的扩展。它通过计算损失函数(比如交叉熵损失)关于网络参数( <math xmlns="http://www.w3.org/1998/Math/MathML"> W x h W_{xh} </math>Wxh , <math xmlns="http://www.w3.org/1998/Math/MathML"> W h h W_{hh} </math>Whh , <math xmlns="http://www.w3.org/1998/Math/MathML"> W h y W_{hy} </math>Why , <math xmlns="http://www.w3.org/1998/Math/MathML"> b h b_h </math>bh , <math xmlns="http://www.w3.org/1998/Math/MathML"> b y b_y </math>by等)的梯度,然后使用梯度下降等优化算法来更新参数,使得损失函数逐渐减小。在计算梯度时,由于隐藏层状态在时间步之间的循环连接,梯度会在时间维度上进行反向传播,这也是 BPTT 名称的由来。
1.4 面临的挑战
RNN 虽然具有记忆能力,但在处理长序列时,会面临梯度消失或梯度爆炸的问题。简单来说,当反向传播的时间步数增多时,梯度在传递过程中可能会变得非常小(梯度消失),导致前面时间步的信息对当前时间步的影响几乎可以忽略不计,使得模型难以学习到长距离的依赖关系;或者梯度变得非常大(梯度爆炸),导致参数更新不稳定,模型无法正常训练。
挑战类型 | 描述 | 对模型的影响 |
---|---|---|
梯度消失 | 反向传播时梯度逐渐变小 | 难以学习长距离依赖关系,模型性能下降 |
梯度爆炸 | 反向传播时梯度逐渐变大 | 参数更新不稳定,模型无法正常训练 |
二、长短期记忆网络(LSTM)
2.1 结构与原理
为了解决 RNN 的梯度问题,LSTM 应运而生。LSTM 引入了一种特殊的结构 ------ 细胞状态(Cell State),它就像一条信息高速公路,能够让信息在序列中相对轻松地流动,从而有效捕捉长期依赖关系。 LSTM 通过三个门来控制细胞状态中的信息:遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate)。 遗忘门决定从上一个时间步的细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t − 1 C_{t−1} </math>Ct−1中丢弃哪些信息,其计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f _ t = σ ( W _ x f x _ t + W _ h f h _ t − 1 + b _ f ) f\t=σ(W\{xf}x\t+W\{hf}h\_{t−1}+b\_f) </math>f_t=σ(W_xfx_t+W_hfh_t−1+b_f)
这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> W x f , W h f , b f W_{xf} ,W_{hf} ,b_f </math>Wxf,Whf,bf分别是遗忘门对应的权重矩阵和偏置项。 输入门决定将哪些新信息添加到细胞状态中,它由两部分组成。首先是输入门值 <math xmlns="http://www.w3.org/1998/Math/MathML"> i t i_t </math>it,计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i _ t = σ ( W _ x i x _ t + W _ h i h _ t − 1 + b _ i ) i\t=σ(W\{xi}x\t+W\{hi}h\_{t−1}+b\_i) </math>i_t=σ(W_xix_t+W_hih_t−1+b_i)
然后是候选细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C ~ t \tilde C_t </math>C~t,计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C ~ _ t = t a n h ( W _ x C x _ t + W _ h C h _ t − 1 + b _ C ) \tilde C\t =tanh(W\{xC}x\t+W\{hC}h\_{t−1}+b\_C) </math>C~_t=tanh(W_xCx_t+W_hCh_t−1+b_C)
最终更新后的细胞状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> C t C_t </math>Ct为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> C _ t = f _ t ⊙ C _ t − 1 + i _ t ⊙ C ~ _ t C\_t =f\t ⊙C\{t−1} +i\_t⊙ \tilde C\_t </math>C_t=f_t⊙C_t−1+i_t⊙C~_t
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ⊙ ⊙ </math>⊙表示逐元素相乘。 输出门决定细胞状态的哪些部分将作为当前时间步的输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht,计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> o _ t = σ ( W _ x o x _ t + W _ h o h _ t − 1 + b _ o ) o\t=σ(W\{xo}x\t+W\{ho}h\_{t−1}+b\_o) </math>o_t=σ(W_xox_t+W_hoh_t−1+b_o)
然后当前时间步的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h _ t = o _ t ⊙ t a n h ( C _ t ) h\_t=o\_t⊙tanh(C\_t) </math>h_t=o_t⊙tanh(C_t)
LSTM 结构示意图:

2.2 工作流程
在每个时间步,LSTM 首先通过遗忘门决定保留或丢弃上一个时间步细胞状态中的哪些信息。然后通过输入门和候选细胞状态决定添加哪些新信息到细胞状态中。更新完细胞状态后,再通过输出门决定输出哪些信息作为当前时间步的隐藏状态。这个过程不断重复,使得 LSTM 能够有效处理长序列数据。
2.3 应用案例 - 股价预测
假设我们要预测某只股票未来的价格走势。我们将过去一段时间(比如 100 天)的股票价格作为输入序列,通过 LSTM 模型进行训练。在训练过程中,LSTM 可以学习到股票价格之间的长期依赖关系,比如某些宏观经济因素对股价的长期影响。当训练完成后,我们可以输入最近一段时间的股价,让模型预测未来几天的股价。与传统的时间序列预测方法相比,LSTM 能够更好地捕捉股价波动中的复杂模式,从而提高预测的准确性。
三、门控循环单元(GRU)
3.1 结构与原理
GRU 可以看作是 LSTM 的简化版本。它将 LSTM 中的遗忘门和输入门合并为一个更新门(Update Gate),同时取消了单独的细胞状态,直接通过隐藏状态传递信息。 更新门 <math xmlns="http://www.w3.org/1998/Math/MathML"> z t z_t </math>zt的计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> z _ t = σ ( W _ x z x _ t + W _ h z h _ t − 1 + b _ z ) z\t=σ(W\{xz}x\t+W\{hz}h\_{t−1}+b\_z) </math>z_t=σ(W_xzx_t+W_hzh_t−1+b_z)
重置门(Reset Gate) <math xmlns="http://www.w3.org/1998/Math/MathML"> r t r_t </math>rt的计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> r _ t = σ ( W _ x r x _ t + W _ h r h _ t − 1 + b _ r ) r\t=σ(W\{xr}x\t +W\{hr} h\_{t−1} +b\_r) </math>r_t=σ(W_xrx_t+W_hrh_t−1+b_r)
候选隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h ~ t \tilde h_t </math>h~t的计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h ~ _ t = t a n h ( W _ x h x _ t + r _ t ⊙ ( W _ h h h _ t − 1 ) + b _ h ) \tilde h\t=tanh(W\{xh}x\t +r\t ⊙(W\{hh}h\{t−1})+b\_h) </math>h~_t=tanh(W_xhx_t+r_t⊙(W_hhh_t−1)+b_h)
最终的隐藏状态 <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht 为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h _ t = ( 1 − z _ t ) ⊙ h _ t − 1 + z _ t ⊙ h ~ _ t h\_t =(1−z\t )⊙h\{t−1} +z\_t ⊙\tilde h\_t </math>h_t=(1−z_t)⊙h_t−1+z_t⊙h~_t
GRU 结构示意图:

3.2 与 LSTM 的比较
与 LSTM 相比,GRU 结构更简单,参数更少,因此训练速度更快。在一些对实时性要求较高或者数据量较小的场景中,GRU 可能会表现得更好。但在处理非常复杂的长序列数据时,LSTM 由于其更精细的门控机制,可能会取得更好的效果。
模型 | 结构特点 | 参数数量 | 训练速度 | 适用场景 |
---|---|---|---|---|
LSTM | 有细胞状态,三个门控 | 较多 | 较慢 | 复杂长序列数据 |
GRU | 无细胞状态,两个门控 | 较少 | 较快 | 实时性要求高或数据量小 |
3.3 应用案例 - 实时语音识别
在实时语音识别系统中,需要快速处理连续的语音流数据。GRU 由于其简单高效的结构,能够在保证一定准确率的前提下,快速对语音数据进行处理和识别。它可以实时地将输入的语音信号转换为文字,满足人们在语音交互场景中的需求。
总结
循环神经网络(RNN)为处理序列数据提供了基础框架,但其在长序列处理上的局限性促使了长短期记忆网络(LSTM)和门控循环单元(GRU)的诞生。LSTM 通过精细的门控机制和细胞状态,有效地解决了梯度问题,能够处理复杂的长序列数据。GRU 则在保持一定性能的同时,通过简化结构提高了训练效率。在实际应用中,我们需要根据具体任务的特点和需求,选择合适的模型。希望通过本文的介绍,你对 RNN、LSTM 和 GRU 有了更深入的理解,并能够在自己的项目中灵活运用它们。