Day 19:LSTM与时间序列预测

Day 19:LSTM与时间序列预测

📋 目录

  1. RNN基础与序列数据
  2. 梯度消失与爆炸问题
  3. LSTM原理与门控机制
  4. LSTM vs 简单RNN
  5. 时间序列预测的注意事项

第一部分:RNN基础与序列数据

1.1 什么是序列数据?

序列数据:数据点之间存在顺序依赖关系。

常见序列数据

  • 时间序列(股价、天气、传感器)
  • 文本(单词序列)
  • 音频(信号序列)
  • 视频(帧序列)

1.2 为什么需要RNN?

传统神经网络的局限

  • 假设输入之间独立
  • 无法捕捉序列中的时间依赖

RNN的核心思想:网络具有"记忆",当前输出依赖之前的输入。

1.3 简单RNN结构

text 复制代码
        y₁      y₂      y₃
        ↑       ↑       ↑
        h₁  ←   h₂  ←   h₃
        ↑       ↑       ↑
        x₁      x₂      x₃

数学形式:
ht=tanh⁡(Wxhxt+Whhht−1+bh) h_t = \tanh(W_{xh}x_t + W_{hh}h_{t-1} + b_h) ht=tanh(Wxhxt+Whhht−1+bh)

yt=Whyht+by y_t = W_{hy}h_t + b_y yt=Whyht+by

其中

  • xtx_txt :ttt 时刻的输入
  • hth_tht:ttt 时刻的隐藏状态(记忆)
  • yty_tyt:ttt 时刻的输出

1.4 RNN的单元展开

python 复制代码
# RNN单元的时间展开
# 同一组参数在不同时间步重复使用
cell = RNNCell(weights)
h_0 = initial_state

for t in range(T):
    h_t = cell(x_t, h_{t-1})
    y_t = output_layer(h_t)

第二部分:梯度消失与爆炸问题

2.1 问题本质

长期依赖问题:RNN难以学习长期序列中的依赖关系。

原因:反向传播时,梯度随时间步长指数级衰减或增长。

2.2 梯度消失

数学解释: 在反向传播中,梯度乘以 W_{hh}\^T 多次:

∂Lt∂h1=∂Lt∂ht∏k=2t∂hk∂hk−1=∂Lt∂ht∏k=2tWhhTdiag(tanh⁡′(hk−1)) \frac{\partial L_t}{\partial h_1} = \frac{\partial L_t}{\partial h_t} \prod_{k=2}^{t} \frac{\partial h_k}{\partial h_{k-1}} = \frac{\partial L_t}{\partial h_t} \prod_{k=2}^{t} W_{hh}^T \text{diag}\left( \tanh'(h_{k-1}) \right) ∂h1∂Lt=∂ht∂Ltk=2∏t∂hk−1∂hk=∂ht∂Ltk=2∏tWhhTdiag(tanh′(hk−1))

当 ∥Whh∥<1\| W_{hh} \| < 1∥Whh∥<1:梯度指数级衰减 →\to→ 梯度消失

后果

  • 距离远的输入对当前输出影响小
  • 网络无法学习长期依赖

2.3 梯度爆炸

当 ∥Whh∥>1\| W_{hh} \| > 1∥Whh∥>1:梯度指数级增长 → 梯度爆炸

后果

  • 参数更新过大
  • 训练不稳定
  • Loss变为NaN

2.4 解决方案

问题 解决方案
梯度消失 LSTM, GRU, 残差连接
梯度爆炸 梯度裁剪(Gradient Clipping)

梯度裁剪

python 复制代码
# Keras中的梯度裁剪
optimizer = Adam(clipnorm=1.0)  # 按范数裁剪
optimizer = Adam(clipvalue=0.5) # 按值裁剪

第三部分:LSTM原理与门控机制

3.1 LSTM的核心思想

LSTM (Long Short-Term Memory) 通过门控机制控制信息的流动,解决了简单RNN的梯度消失问题。

关键创新

  • 细胞状态(Cell State):信息高速公路
  • 三个门:遗忘门、输入门、输出门

3.2 LSTM结构图

text 复制代码
        C_{t-1} ─────────────────────────→ C_t
                  ↖                      ↗
                   ×      +      ×
                   ↑      ↑      ↑
                  σ       σ     tanh
                   ↑      ↑      ↑
   h_{t-1} ────→  └──────┴──────┘ ←──── x_t
        ↓                        ↓
        h_t ←────────────────────┘

3.3 遗忘门(Forget Gate)

作用:决定从细胞状态中丢弃哪些信息。

ft=σ(Wf⋅[ht−1,xt]+bf) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)

  • 输出0-1之间
  • 1表示"完全保留",0表示"完全丢弃"

3.4 输入门(Input Gate)

作用:决定将哪些新信息存入细胞状态。

it=σ(Wi⋅[ht−1,xt]+bi) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)

C~t=tanh⁡(WC⋅[ht−1,xt]+bC) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)

3.5 细胞状态更新

作用:更新长期记忆。

Ct=ft⊙Ct−1+it⊙C~t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t

  • 保留旧信息(遗忘门控制)
  • 添加新信息(输入门控制)

3.6 输出门(Output Gate)

作用:决定输出什么信息。

ot=σ(Wo⋅[ht−1,xt]+bo) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)

ht=ot⊙tanh⁡(Ct) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)

3.7 LSTM vs 简单RNN对比

特性 简单RNN LSTM
长期记忆 优秀
梯度消失 严重 显著缓解
参数量 多(4倍)
训练速度
适用场景 短期依赖 长期依赖

第四部分:LSTM vs 简单MLP

4.1 MLP处理序列的问题

MLP无法处理变长序列

  • 输入维度固定
  • 无法捕捉时间顺序

MLP的"时间"处理方式

  • 将序列展平为向量
  • 丢失了时间结构信息

4.2 时序数据建模对比

数据建模方式 输入形状 能否捕捉时间依赖
MLP(展平) (seq_len, features)
RNN/LSTM (seq_len, features)
CNN (seq_len, features) 局部

4.3 预测任务类型

任务类型 说明 示例
多步输入,单步输出 用N天预测明天 股价预测
多步输入,多步输出 用N天预测未来M天 未来5日走势
单步输入,多步输出 用当前预测未来序列 生成序列

第五部分:时间序列预测的注意事项

5.1 数据预处理

python 复制代码
# 1. 标准化(重要!)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)

# 2. 创建滑动窗口
def create_sequences(data, seq_length):
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return np.array(X), np.array(y)

5.2 数据泄露防范

关键原则:不能使用未来信息!

python 复制代码
# 正确做法:只用历史数据预测未来
split_idx = int(len(X) * 0.7)
X_train = X[:split_idx]
X_test = X[split_idx:]

5.3 评估指标

指标 公式 适用场景
MSE 1n∑(y−y^)2\frac 1n \sum (y − \hat{y})^2n1∑(y−y^)2 通用
MAE 1n∑∣y−y^∣\frac{1}{n} \sum |y-\hat{y} |n1∑∣y−y^∣ 通用,鲁棒
MAPE 1n∑∣y−y^y∣\frac{1}{n}\sum |\frac{y-\hat{y}}{y}|n1∑∣yy−y^∣ 相对误差
Direction Accuracy 正确预测趋势次数n\frac{正确预测趋势次数}{n}n正确预测趋势次数 金融涨跌预测
相关推荐
索木木2 小时前
Flash Attention反向梯度优化显存
人工智能·机器学习·大模型·attention·训练·显存优化·aiinfra
mit6.8242 小时前
[CS153]AI基础设施与技术栈
人工智能
量子-Alex2 小时前
【大模型智能体】AutoFlow:大型语言模型代理的自动化工作流生成
人工智能·语言模型·自动化
Wzx1980122 小时前
cozen平台开发智能体
人工智能
GISer_Jing2 小时前
AI原生前端工程化进阶实践:从流式交互架构到端云协同全链路落地
前端·人工智能·后端·学习
EnCi Zheng2 小时前
03ab-PyTorch安装教程 [特殊字符]
人工智能·pytorch·python
SmartBrain2 小时前
从Prompt工程到Harness工程:AI Agent落地之路
人工智能·python·华为·aigc
科技小花9 小时前
全球化深水区,数据治理成为企业出海 “核心竞争力”
大数据·数据库·人工智能·数据治理·数据中台·全球化
zhuiyisuifeng10 小时前
2026前瞻:GPTimage2镜像官网或将颠覆视觉创作
人工智能·gpt