RNN(循环神经网络)

介绍

循环神经网络(Recurrent Neural Network,RNN)是一种强大的神经网络结构,主要用于处理序列数据,例如时间序列、文本、语音等。它通过引入循环结构,能够利用序列中的历史信息来预测当前的输出,非常适合处理具有时间依赖性或顺序关系的数据。
总的来讲:RNN的核心是记忆 每个时间点的输出都收到之前时间点的影响

从输出来考虑,RNN有1-1的,1-多的,多-1的,也有多-多的

基本原理

循环神经网络的核心在于其"循环"的结构。在传统的神经网络中,输入和输出是独立的,而在RNN中,网络的输出不仅取决于当前的输入,还依赖于之前时刻的输出(即隐藏状态)。这种结构使得RNN能够记住之前的信息,并将其用于当前的计算。

假设我们有一个序列 x=(x_1,x_2,...,x_T),RNN会逐个处理序列中的元素,并在每一步更新其隐藏状态 h_t。隐藏状态的更新公式通常为:

其中,f 是一个非线性函数,通常是一个激活函数(如tanh或ReLU)。最终,RNN可以根据隐藏状态生成输出 y_t,输出的计算公式为:

其中,g 是另一个非线性函数。

数学表达

循环神经网络(RNN)的数学表达式主要包括隐藏状态的更新和输出的计算。

以下是RNN的基本数学表达式:

1. 隐藏状态的更新

在RNN中,隐藏状态h_t是通过当前输入x_t和前一时刻的隐藏状态 h_t−1来更新的。具体的更新公式为:

其中:

  • h_t是第 t 时刻的隐藏状态
  • x_t是第 t 时刻的输入
  • W_hh是隐藏状态到隐藏状态的权重矩阵
  • W_xh是输入到隐藏状态的权重矩阵
  • b_h是隐藏状态的偏置项
  • tanh 是激活函数,通常使用双曲正切函数(tanh),也可以使用其他非线性激活函数(如ReLU)

2. 输出的计算

在某些情况下,RNN的输出y_t可以直接从隐藏状态 h_t计算得到。输出的计算公式为:

其中:

  • y_t是第 t 时刻的输出
  • W_hy是隐藏状态到输出的权重矩阵
  • b_y是输出的偏置项

完整的RNN过程就是先完成隐藏状态更新,再进行输出计算

3. 初始化

在开始处理序列之前,需要初始化隐藏状态h_0。通常初始化为零向量:

4. 训练过程

在训练RNN时,通常使用反向传播通过时间(Backpropagation Through Time,BPTT)来计算梯度并更新权重。BPTT将RNN展开为一个时间步长的序列,并对每个时间步长进行反向传播。

梯度消失和梯度爆炸问题

在实际应用中,RNN可能会遇到梯度消失或梯度爆炸的问题,这使得训练长序列变得困难:

梯度消失问题(Vanishing Gradient Problem)

梯度消失问题是指在训练过程中,随着网络层数的增加或时间步长的增加,梯度逐渐变得非常小,导致网络的权重更新非常缓慢,甚至无法更新。这使得网络难以学习到长距离的依赖关系。

原因

在RNN中,隐藏状态的更新公式为:

在反向传播过程中,梯度会通过链式法则逐层传递。对于RNN,梯度会沿着时间步长反向传播,即:

其中:

由于 tanh 函数的导数范围在 (0,1) 之间,当多个这样的导数相乘时,梯度会迅速衰减。例如,假设每个时间步长的梯度乘积为 0.9,那么经过10个时间步长后,梯度会变为 0.910≈0.3487,经过20个时间步长后,梯度会变为 0.920≈0.1216。这种衰减使得网络难以学习到长距离的依赖关系。

梯度爆炸问题(Exploding Gradient Problem)

梯度爆炸问题是指在训练过程中,随着网络层数的增加或时间步长的增加,梯度逐渐变得非常大,导致网络的权重更新过大,甚至导致数值不稳定。这使得网络的训练变得非常困难,甚至无法收敛。

原因

与梯度消失问题相反,梯度爆炸问题通常发生在权重矩阵的范数较大时。在RNN中,如果权重矩阵 W_hh的范数较大,那么在反向传播过程中,梯度会迅速增大。例如,假设每个时间步长的梯度乘积为 1.1,那么经过10个时间步长后,梯度会变为 1.110

≈2.5937,经过20个时间步长后,梯度会变为 1.120≈6.7275。这种快速增大会导致权重更新过大,甚至导致数值不稳定。

梯度爆炸问题的解决方案

梯度裁剪(Gradient Clipping):

在反向传播过程中,对梯度进行裁剪,使其不超过某个阈值,防止梯度过大。

权重正则化:

使用权重正则化(如L2正则化)来限制权重的范数,防止权重过大。

改进的RNN架构:

使用LSTM或GRU等改进的RNN架构,这些架构通过引入门控机制来控制信息的流动,防止梯度爆炸。

RNN创建

dart 复制代码
from torch import nn
rnn_cell = nn.RNNCell(input_size=3, hidden_size=2)
print(rnn_cell)

用如上方法,我们可以很简单就创建一个隐藏向量为2维,输入向量为3维的RNN

现在,我们来用 PyTorch 实现了一个最简单、最朴素的 RNN 前向计算,用来演示 nn.RNN 的输入和输出长什么样。

导入所需模块

dart 复制代码
import torch
import torch.nn as nn

建网络

dart 复制代码
rnn = nn.RNN(5, 6, 2)

输入特征维度 input_size = 5:
每个时间步喂给 RNN 的向量长度是 5。
隐藏层维度 hidden_size = 6:
每个 RNN 单元内部有 6 个隐藏神经元,所以隐藏状态 h_t 是 6 维向量。
层数 num_layers = 2:
纵向堆了 2 层 RNN(第一层输出作为第二层输入)。

随机生成输入

dart 复制代码
input1 = torch.randn(1, 3, 5)

形状 (seq_len, batch, input_size) = (1, 3, 5)
序列长度 seq_len = 1:每个样本只有 1 个时间步。
批大小 batch = 3:每个时间步的行向量长度为 3。
特征维度 input_size = 5:每个时间步的列向量长度为 5(和网络配置一致)。

初始隐藏状态

dart 复制代码
h0 = torch.randn(2, 3, 6)

形状 (num_layers, batch, hidden_size) = (2, 3, 6)
2 层 RNN → 第一层、第二层各需要一个 (batch, hidden_size) 的隐藏状态。
这里直接随机初始化,实际训练时通常用零向量或学习到的初始状态。

前向计算

dart 复制代码
output, hn = rnn(input1, h0)

output:形状 (seq_len, batch, hidden_size) = (1, 3, 6)
保存每一层在最后一个时间步的输出。因为 seq_len = 1,所以只有 1 个时间步,也就是这 1 步的输出。
hn:形状 (num_layers, batch, hidden_size) = (2, 3, 6)
保存最后一层在每个时间步后的隐藏状态。因为是单层序列,所以这里直接就是最后一个时间步的隐藏状态。

网络结构

把它画成"两层、单向、无 bias、tanh 激活"的 RNN,就像下面这张"垂直切片"图:

dart 复制代码
时间步 t(只有 1 个)      输出
┌------------┐
│  x_t(5)    │ ─┐
└------------┘  │  ┌----------------Layer 0----------------┐
                 │  │   (6 个 tanh 单元)                    │
                 │  │   ┌-----┐  ┌-----┐        ┌-----┐   │
                 │  │   │     │  │     │  ...   │     │   │
                 │  │   │ tanh│  │ tanh│        │ tanh│   │
                 └─>│   └-----┘  └-----┘        └-----┘   │
                    │   ↑ 6 维 h_t^{(0)}                   │
                    │   │                                   │
                    │   │                                   │
                    │   │                                   │
                    │   │  ┌----------------Layer 1---------------┐
                    │   └─>│   (6 个 tanh 单元)                   │
                    │      │   ┌-----┐  ┌-----┐        ┌-----┐  │
                    │      │   │     │  │     │  ...   │     │  │
                    │      │   │ tanh│  │ tanh│        │ tanh│  │
                    └----->│   └-----┘  └-----┘        └-----┘  │
                           │   ↑ 6 维 h_t^{(1)}                  │
                           │   │                                   │
                           │   └--------------→ output_t(6)        │
                           └----------------------------------------┘

打印结果

dart 复制代码
print(input1)         #形状(1,3,5)
print(output)         # 形状 (1, 3, 6)
print(output.shape)   # torch.Size([1, 3, 6])
print(hn)             # 形状 (2, 3, 6)
print(hn.shape)       # torch.Size([2, 3, 6])

output:

dart 复制代码
tensor([[[ 0.1578, -0.6580,  1.4173,  0.9465, -0.3168],
         [ 0.9392,  0.6594,  0.2955,  2.3961,  0.9437],
         [-0.3017,  0.6763,  0.2471, -1.1089,  0.3852]]])
tensor([[[ 0.1992,  0.4487,  0.7773,  0.4494,  0.3454, -0.4028],
         [-0.6498, -0.2264,  0.7769,  0.2682,  0.3949,  0.7619],
         [-0.4186, -0.9363, -0.8965, -0.0036, -0.8698, -0.3272]]],
       grad_fn=<StackBackward0>)
torch.Size([1, 3, 6])
tensor([[[-0.8587, -0.0270,  0.3406,  0.6439, -0.2807, -0.5318],
         [-0.1035, -0.4623, -0.5111,  0.0668, -0.4180, -0.1708],
         [-0.2161,  0.5315,  0.7317,  0.3937, -0.7685, -0.6344]],

        [[ 0.1992,  0.4487,  0.7773,  0.4494,  0.3454, -0.4028],
         [-0.6498, -0.2264,  0.7769,  0.2682,  0.3949,  0.7619],
         [-0.4186, -0.9363, -0.8965, -0.0036, -0.8698, -0.3272]]],
       grad_fn=<StackBackward0>)
torch.Size([2, 3, 6])
相关推荐
python_tty9 分钟前
排序算法(二):插入排序
算法·排序算法
然我17 分钟前
面试官:如何判断元素是否出现过?我:三种哈希方法任你选
前端·javascript·算法
F_D_Z44 分钟前
【EM算法】三硬币模型
算法·机器学习·概率论·em算法·极大似然估计
秋说1 小时前
【PTA数据结构 | C语言版】字符串插入操作(不限长)
c语言·数据结构·算法
机器之心2 小时前
马斯克Grok这个二次元「小姐姐」,攻陷了整个互联网
人工智能
凌肖战2 小时前
力扣网编程135题:分发糖果(贪心算法)
算法·leetcode
szxinmai主板定制专家2 小时前
基于光栅传感器+FPGA+ARM的测量控制解决方案
arm开发·人工智能·嵌入式硬件·fpga开发
Guheyunyi2 小时前
电气安全监测系统:筑牢电气安全防线
大数据·运维·网络·人工智能·安全·架构
三桥君2 小时前
在AI应用中Prompt撰写重要却难掌握,‘理解模型与行业知识是关键’:提升迫在眉睫
人工智能·ai·系统架构·prompt·产品经理·三桥君
semantist@语校2 小时前
日本语言学校:签证制度类 Prompt 的结构整理路径与策略
人工智能·百度·ai·语言模型·prompt·github·数据集