1. RNN 的基本概念
循环神经网络(Recurrent Neural Network, RNN)是一类 用于处理序列数据 的神经网络。与传统的前馈神经网络(如 MLP、CNN)不同,RNN 在网络结构中引入了 时间上的循环 ,能够记忆历史信息并用于当前的预测。
典型应用包括:
-
自然语言处理(语言建模、机器翻译、情感分析)
-
时间序列预测(股价预测、天气预报)
-
语音识别
RNN 的核心思想是:在处理序列第 步数据时,模型会考虑前一时刻
的隐藏状态,从而引入记忆能力。
2. RNN 的数学原理与公式推导
2.1 基本结构
在 RNN 中,每个时间步的隐藏状态 由当前输入
和前一时刻隐藏状态
共同决定。其计算公式为:
其中:
-
:第 t 个时间步的输入向量
-
:第 t 个时间步的隐藏状态
-
:隐藏层权重矩阵
-
:输入权重矩阵
-
:偏置项
-
:激活函数(常用
或 ReLU)
输出层通常为:
其中 一般为 softmax 或 sigmoid,用于分类或预测。
2.2 序列展开(Unrolling)
RNN 在时间维度上可以展开为一个链式结构:
这样,RNN 就可以对序列进行逐步建模。
2.3 损失函数
对于一个序列长度为 的样本,RNN 的损失通常为所有时间步的损失和:
其中 为交叉熵或均方误差。
3. RNN 的优缺点
✅ 优点
-
能够处理任意长度的序列数据。
-
参数共享:不同时间步使用相同的权重,降低参数数量。
-
在语音、NLP、时间序列预测等任务中表现良好。
❌ 缺点
-
梯度消失 / 梯度爆炸:长序列训练时,梯度可能随时间步衰减或爆炸。
-
并行计算困难:序列依赖关系导致训练速度较慢。
-
长期依赖问题:对远距离的序列信息捕捉能力不足(需 LSTM/GRU 改进)。
4. RNN 的应用场景
-
自然语言处理:机器翻译、语音识别、文本生成。
-
金融预测:股价走势建模。
-
信号处理:传感器数据分析、语音识别。
5. Python 实现(PyTorch)
python
import torch
import torch.nn as nn
# 定义一个简单的RNN模型
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=1):
super(RNNModel, self).__init__()
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
out, h = self.rnn(x) # out: (batch, seq_len, hidden_size)
out = self.fc(out[:, -1, :]) # 只取最后时间步
return out
# 示例输入
batch_size, seq_len, input_size = 16, 10, 8
hidden_size, output_size = 32, 2
x = torch.randn(batch_size, seq_len, input_size)
# 模型
model = RNNModel(input_size, hidden_size, output_size)
y_pred = model(x)
print("输入 shape:", x.shape)
print("输出 shape:", y_pred.shape)
6. 总结
-
RNN 是处理序列数据的经典模型,能够记忆上下文信息。
-
数学本质:通过递归公式
进行时序建模。
-
存在梯度消失问题,通常使用 LSTM 和 GRU 改进。
-
应用广泛,尤其在 NLP、语音识别、时间序列预测领域