RNN循环神经网络概述
RNN(Recurrent Neural Network)是一种处理序列数据的神经网络结构,其核心特点是具有循环连接,允许信息在网络中持久化。这种特性使其适合处理时间序列、自然语言等具有时序关系的数据。
基本结构
RNN的隐藏层神经元不仅接收当前时刻的输入,还接收上一时刻的隐藏状态。数学表达式为: h_t = \\sigma(W_{xh}x_t + W_{hh}h_{t-1} + b_h) y_t = W_{hy}h_t + b_y 其中:
- ( h_t ) 是当前时刻的隐藏状态
- ( x_t ) 是当前输入
- ( W ) 为权重矩阵
- ( \sigma ) 为激活函数(如tanh)
常见变体
LSTM(长短期记忆网络) 通过引入门控机制(输入门、遗忘门、输出门)解决梯度消失问题。计算公式: f_t = \\sigma(W_f \\cdot \[h_{t-1}, x_t + b_f) ] i_t = \\sigma(W_i \\cdot \[h_{t-1}, x_t + b_i) ] \\tilde{C}*t = \\tanh(W_C \\cdot \[h* {t-1}, x_t + b_C) ] C_t = f_t \\odot C_{t-1} + i_t \\odot \\tilde{C}*t o_t = \\sigma(W_o \\cdot \[h*{t-1}, x_t + b_o) ] h_t = o_t \\odot \\tanh(C_t)
GRU(门控循环单元) 简化版LSTM,合并遗忘门和输入门为更新门: z_t = \\sigma(W_z \\cdot \[h_{t-1}, x_t) ] r_t = \\sigma(W_r \\cdot \[h_{t-1}, x_t) ] \\tilde{h}*t = \\tanh(W \\cdot \[r_t \\odot h*{t-1}, x_t) ] h_t = (1 - z_t) \\odot h_{t-1} + z_t \\odot \\tilde{h}_t
应用场景
- 机器翻译:序列到序列建模
- 语音识别:时序信号处理
- 股票预测:时间序列分析
- 文本生成:字符/单词级预测
PyTorch实现示例
python
import torch.nn as nn
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
out, _ = self.rnn(x) # out.shape: (batch, seq_len, hidden_size)
return self.fc(out[:, -1, :]) # 取最后一个时间步输出
# LSTM示例
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
训练技巧
梯度裁剪:防止梯度爆炸
python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
双向RNN:捕获前后文信息
python
nn.LSTM(..., bidirectional=True)
序列填充:处理变长序列
python
from torch.nn.utils.rnn import pad_sequence
padded = pad_sequence(sequences, batch_first=True)
局限性
- 长序列处理能力有限(尽管LSTM/GRU有所改善)
- 并行计算效率低于Transformer
- 对近期输入存在偏置
实际应用中,Transformer架构在多数序列任务中已取代RNN,但在资源受限或需要在线学习的场景中,RNN仍具实用价值。