Day 19:LSTM与时间序列预测

Day 19:LSTM与时间序列预测

📋 目录

  1. RNN基础与序列数据
  2. 梯度消失与爆炸问题
  3. LSTM原理与门控机制
  4. LSTM vs 简单RNN
  5. 时间序列预测的注意事项

第一部分:RNN基础与序列数据

1.1 什么是序列数据?

序列数据:数据点之间存在顺序依赖关系。

常见序列数据

  • 时间序列(股价、天气、传感器)
  • 文本(单词序列)
  • 音频(信号序列)
  • 视频(帧序列)

1.2 为什么需要RNN?

传统神经网络的局限

  • 假设输入之间独立
  • 无法捕捉序列中的时间依赖

RNN的核心思想:网络具有"记忆",当前输出依赖之前的输入。

1.3 简单RNN结构

text 复制代码
        y₁      y₂      y₃
        ↑       ↑       ↑
        h₁  ←   h₂  ←   h₃
        ↑       ↑       ↑
        x₁      x₂      x₃

数学形式:
ht=tanh⁡(Wxhxt+Whhht−1+bh) h_t = \tanh(W_{xh}x_t + W_{hh}h_{t-1} + b_h) ht=tanh(Wxhxt+Whhht−1+bh)

yt=Whyht+by y_t = W_{hy}h_t + b_y yt=Whyht+by

其中

  • xtx_txt :ttt 时刻的输入
  • hth_tht:ttt 时刻的隐藏状态(记忆)
  • yty_tyt:ttt 时刻的输出

1.4 RNN的单元展开

python 复制代码
# RNN单元的时间展开
# 同一组参数在不同时间步重复使用
cell = RNNCell(weights)
h_0 = initial_state

for t in range(T):
    h_t = cell(x_t, h_{t-1})
    y_t = output_layer(h_t)

第二部分:梯度消失与爆炸问题

2.1 问题本质

长期依赖问题:RNN难以学习长期序列中的依赖关系。

原因:反向传播时,梯度随时间步长指数级衰减或增长。

2.2 梯度消失

数学解释: 在反向传播中,梯度乘以 W_{hh}\^T 多次:

∂Lt∂h1=∂Lt∂ht∏k=2t∂hk∂hk−1=∂Lt∂ht∏k=2tWhhTdiag(tanh⁡′(hk−1)) \frac{\partial L_t}{\partial h_1} = \frac{\partial L_t}{\partial h_t} \prod_{k=2}^{t} \frac{\partial h_k}{\partial h_{k-1}} = \frac{\partial L_t}{\partial h_t} \prod_{k=2}^{t} W_{hh}^T \text{diag}\left( \tanh'(h_{k-1}) \right) ∂h1∂Lt=∂ht∂Ltk=2∏t∂hk−1∂hk=∂ht∂Ltk=2∏tWhhTdiag(tanh′(hk−1))

当 ∥Whh∥<1\| W_{hh} \| < 1∥Whh∥<1:梯度指数级衰减 →\to→ 梯度消失

后果

  • 距离远的输入对当前输出影响小
  • 网络无法学习长期依赖

2.3 梯度爆炸

当 ∥Whh∥>1\| W_{hh} \| > 1∥Whh∥>1:梯度指数级增长 → 梯度爆炸

后果

  • 参数更新过大
  • 训练不稳定
  • Loss变为NaN

2.4 解决方案

问题 解决方案
梯度消失 LSTM, GRU, 残差连接
梯度爆炸 梯度裁剪(Gradient Clipping)

梯度裁剪

python 复制代码
# Keras中的梯度裁剪
optimizer = Adam(clipnorm=1.0)  # 按范数裁剪
optimizer = Adam(clipvalue=0.5) # 按值裁剪

第三部分:LSTM原理与门控机制

3.1 LSTM的核心思想

LSTM (Long Short-Term Memory) 通过门控机制控制信息的流动,解决了简单RNN的梯度消失问题。

关键创新

  • 细胞状态(Cell State):信息高速公路
  • 三个门:遗忘门、输入门、输出门

3.2 LSTM结构图

text 复制代码
        C_{t-1} ─────────────────────────→ C_t
                  ↖                      ↗
                   ×      +      ×
                   ↑      ↑      ↑
                  σ       σ     tanh
                   ↑      ↑      ↑
   h_{t-1} ────→  └──────┴──────┘ ←──── x_t
        ↓                        ↓
        h_t ←────────────────────┘

3.3 遗忘门(Forget Gate)

作用:决定从细胞状态中丢弃哪些信息。

ft=σ(Wf⋅ht−1,xt+bf) f_t = \sigma(W_f \cdot h_{t-1}, x_t + b_f) ft=σ(Wf⋅ht−1,xt+bf)

  • 输出0-1之间
  • 1表示"完全保留",0表示"完全丢弃"

3.4 输入门(Input Gate)

作用:决定将哪些新信息存入细胞状态。

it=σ(Wi⋅ht−1,xt+bi) i_t = \sigma(W_i \cdot h_{t-1}, x_t + b_i) it=σ(Wi⋅ht−1,xt+bi)

C~t=tanh⁡(WC⋅ht−1,xt+bC) \tilde{C}_t = \tanh(W_C \cdot h_{t-1}, x_t + b_C) C~t=tanh(WC⋅ht−1,xt+bC)

3.5 细胞状态更新

作用:更新长期记忆。

Ct=ft⊙Ct−1+it⊙C~t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t

  • 保留旧信息(遗忘门控制)
  • 添加新信息(输入门控制)

3.6 输出门(Output Gate)

作用:决定输出什么信息。

ot=σ(Wo⋅ht−1,xt+bo) o_t = \sigma(W_o \cdot h_{t-1}, x_t + b_o) ot=σ(Wo⋅ht−1,xt+bo)

ht=ot⊙tanh⁡(Ct) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)

3.7 LSTM vs 简单RNN对比

特性 简单RNN LSTM
长期记忆 优秀
梯度消失 严重 显著缓解
参数量 多(4倍)
训练速度
适用场景 短期依赖 长期依赖

第四部分:LSTM vs 简单MLP

4.1 MLP处理序列的问题

MLP无法处理变长序列

  • 输入维度固定
  • 无法捕捉时间顺序

MLP的"时间"处理方式

  • 将序列展平为向量
  • 丢失了时间结构信息

4.2 时序数据建模对比

数据建模方式 输入形状 能否捕捉时间依赖
MLP(展平) (seq_len, features)
RNN/LSTM (seq_len, features)
CNN (seq_len, features) 局部

4.3 预测任务类型

任务类型 说明 示例
多步输入,单步输出 用N天预测明天 股价预测
多步输入,多步输出 用N天预测未来M天 未来5日走势
单步输入,多步输出 用当前预测未来序列 生成序列

第五部分:时间序列预测的注意事项

5.1 数据预处理

python 复制代码
# 1. 标准化(重要!)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)

# 2. 创建滑动窗口
def create_sequences(data, seq_length):
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return np.array(X), np.array(y)

5.2 数据泄露防范

关键原则:不能使用未来信息!

python 复制代码
# 正确做法:只用历史数据预测未来
split_idx = int(len(X) * 0.7)
X_train = X[:split_idx]
X_test = X[split_idx:]

5.3 评估指标

指标 公式 适用场景
MSE 1n∑(y−y^)2\frac 1n \sum (y − \hat{y})^2n1∑(y−y^)2 通用
MAE 1n∑∣y−y^∣\frac{1}{n} \sum |y-\hat{y} |n1∑∣y−y^∣ 通用,鲁棒
MAPE 1n∑∣y−y^y∣\frac{1}{n}\sum |\frac{y-\hat{y}}{y}|n1∑∣yy−y^∣ 相对误差
Direction Accuracy 正确预测趋势次数n\frac{正确预测趋势次数}{n}n正确预测趋势次数 金融涨跌预测
相关推荐
冬奇Lab2 分钟前
每日一个开源项目(第133篇):EchoBird - 把 AI 工具的安装和部署做成傻瓜操作
人工智能·开源·资讯
IT_陈寒1 小时前
Redis的SETNX并发问题让我加了三天班
前端·人工智能·后端
用户5191495848453 小时前
Windows 渗透测试载荷加载器 POC 工具集
人工智能·aigc
大树883 小时前
金刚石散热越强,管路越先见顶
大数据·运维·服务器·人工智能·ai
通信小呆呆3 小时前
当算法有了“五感”:多模态数据融合如何向人体感官协同学习?
人工智能·学习·算法·机器学习·机器人
施小赞3 小时前
普通 RAG vs GraphRAG 核心对比
人工智能·ai
EAIReport3 小时前
RuoYi-AI 企业级AI开发平台实战详解
人工智能
xiao5kou4chang6kai43 小时前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
HelloWorld__来都来了3 小时前
【每日学术速报】2026-06-15
人工智能·具身智能
H__Rick3 小时前
自动对焦学习-3
人工智能·学习·计算机视觉