Day 19:LSTM与时间序列预测
📋 目录
- RNN基础与序列数据
- 梯度消失与爆炸问题
- LSTM原理与门控机制
- LSTM vs 简单RNN
- 时间序列预测的注意事项
第一部分:RNN基础与序列数据
1.1 什么是序列数据?
序列数据:数据点之间存在顺序依赖关系。
常见序列数据:
- 时间序列(股价、天气、传感器)
- 文本(单词序列)
- 音频(信号序列)
- 视频(帧序列)
1.2 为什么需要RNN?
传统神经网络的局限:
- 假设输入之间独立
- 无法捕捉序列中的时间依赖
RNN的核心思想:网络具有"记忆",当前输出依赖之前的输入。
1.3 简单RNN结构
text
y₁ y₂ y₃
↑ ↑ ↑
h₁ ← h₂ ← h₃
↑ ↑ ↑
x₁ x₂ x₃
数学形式:
ht=tanh(Wxhxt+Whhht−1+bh) h_t = \tanh(W_{xh}x_t + W_{hh}h_{t-1} + b_h) ht=tanh(Wxhxt+Whhht−1+bh)
yt=Whyht+by y_t = W_{hy}h_t + b_y yt=Whyht+by
其中:
- xtx_txt :ttt 时刻的输入
- hth_tht:ttt 时刻的隐藏状态(记忆)
- yty_tyt:ttt 时刻的输出
1.4 RNN的单元展开
python
# RNN单元的时间展开
# 同一组参数在不同时间步重复使用
cell = RNNCell(weights)
h_0 = initial_state
for t in range(T):
h_t = cell(x_t, h_{t-1})
y_t = output_layer(h_t)
第二部分:梯度消失与爆炸问题
2.1 问题本质
长期依赖问题:RNN难以学习长期序列中的依赖关系。
原因:反向传播时,梯度随时间步长指数级衰减或增长。
2.2 梯度消失
数学解释: 在反向传播中,梯度乘以 W_{hh}\^T 多次:
∂Lt∂h1=∂Lt∂ht∏k=2t∂hk∂hk−1=∂Lt∂ht∏k=2tWhhTdiag(tanh′(hk−1)) \frac{\partial L_t}{\partial h_1} = \frac{\partial L_t}{\partial h_t} \prod_{k=2}^{t} \frac{\partial h_k}{\partial h_{k-1}} = \frac{\partial L_t}{\partial h_t} \prod_{k=2}^{t} W_{hh}^T \text{diag}\left( \tanh'(h_{k-1}) \right) ∂h1∂Lt=∂ht∂Ltk=2∏t∂hk−1∂hk=∂ht∂Ltk=2∏tWhhTdiag(tanh′(hk−1))
当 ∥Whh∥<1\| W_{hh} \| < 1∥Whh∥<1:梯度指数级衰减 →\to→ 梯度消失
后果:
- 距离远的输入对当前输出影响小
- 网络无法学习长期依赖
2.3 梯度爆炸
当 ∥Whh∥>1\| W_{hh} \| > 1∥Whh∥>1:梯度指数级增长 → 梯度爆炸
后果:
- 参数更新过大
- 训练不稳定
- Loss变为NaN
2.4 解决方案
| 问题 | 解决方案 |
|---|---|
| 梯度消失 | LSTM, GRU, 残差连接 |
| 梯度爆炸 | 梯度裁剪(Gradient Clipping) |
梯度裁剪:
python
# Keras中的梯度裁剪
optimizer = Adam(clipnorm=1.0) # 按范数裁剪
optimizer = Adam(clipvalue=0.5) # 按值裁剪
第三部分:LSTM原理与门控机制
3.1 LSTM的核心思想
LSTM (Long Short-Term Memory) 通过门控机制控制信息的流动,解决了简单RNN的梯度消失问题。
关键创新:
- 细胞状态(Cell State):信息高速公路
- 三个门:遗忘门、输入门、输出门
3.2 LSTM结构图
text
C_{t-1} ─────────────────────────→ C_t
↖ ↗
× + ×
↑ ↑ ↑
σ σ tanh
↑ ↑ ↑
h_{t-1} ────→ └──────┴──────┘ ←──── x_t
↓ ↓
h_t ←────────────────────┘
3.3 遗忘门(Forget Gate)
作用:决定从细胞状态中丢弃哪些信息。
ft=σ(Wf⋅[ht−1,xt]+bf) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
- 输出0-1之间
- 1表示"完全保留",0表示"完全丢弃"
3.4 输入门(Input Gate)
作用:决定将哪些新信息存入细胞状态。
it=σ(Wi⋅[ht−1,xt]+bi) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
C~t=tanh(WC⋅[ht−1,xt]+bC) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)
3.5 细胞状态更新
作用:更新长期记忆。
Ct=ft⊙Ct−1+it⊙C~t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t
- 保留旧信息(遗忘门控制)
- 添加新信息(输入门控制)
3.6 输出门(Output Gate)
作用:决定输出什么信息。
ot=σ(Wo⋅[ht−1,xt]+bo) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
ht=ot⊙tanh(Ct) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)
3.7 LSTM vs 简单RNN对比
| 特性 | 简单RNN | LSTM |
|---|---|---|
| 长期记忆 | 差 | 优秀 |
| 梯度消失 | 严重 | 显著缓解 |
| 参数量 | 少 | 多(4倍) |
| 训练速度 | 快 | 慢 |
| 适用场景 | 短期依赖 | 长期依赖 |
第四部分:LSTM vs 简单MLP
4.1 MLP处理序列的问题
MLP无法处理变长序列:
- 输入维度固定
- 无法捕捉时间顺序
MLP的"时间"处理方式:
- 将序列展平为向量
- 丢失了时间结构信息
4.2 时序数据建模对比
| 数据建模方式 | 输入形状 | 能否捕捉时间依赖 |
|---|---|---|
| MLP(展平) | (seq_len, features) | 否 |
| RNN/LSTM | (seq_len, features) | 是 |
| CNN | (seq_len, features) | 局部 |
4.3 预测任务类型
| 任务类型 | 说明 | 示例 |
|---|---|---|
| 多步输入,单步输出 | 用N天预测明天 | 股价预测 |
| 多步输入,多步输出 | 用N天预测未来M天 | 未来5日走势 |
| 单步输入,多步输出 | 用当前预测未来序列 | 生成序列 |
第五部分:时间序列预测的注意事项
5.1 数据预处理
python
# 1. 标准化(重要!)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)
# 2. 创建滑动窗口
def create_sequences(data, seq_length):
X, y = [], []
for i in range(len(data) - seq_length):
X.append(data[i:i+seq_length])
y.append(data[i+seq_length])
return np.array(X), np.array(y)
5.2 数据泄露防范
关键原则:不能使用未来信息!
python
# 正确做法:只用历史数据预测未来
split_idx = int(len(X) * 0.7)
X_train = X[:split_idx]
X_test = X[split_idx:]
5.3 评估指标
| 指标 | 公式 | 适用场景 |
|---|---|---|
| MSE | 1n∑(y−y^)2\frac 1n \sum (y − \hat{y})^2n1∑(y−y^)2 | 通用 |
| MAE | 1n∑∣y−y^∣\frac{1}{n} \sum |y-\hat{y} |n1∑∣y−y^∣ | 通用,鲁棒 |
| MAPE | 1n∑∣y−y^y∣\frac{1}{n}\sum |\frac{y-\hat{y}}{y}|n1∑∣yy−y^∣ | 相对误差 |
| Direction Accuracy | 正确预测趋势次数n\frac{正确预测趋势次数}{n}n正确预测趋势次数 | 金融涨跌预测 |