长短期记忆网络(LSTM)

1. LSTM 简介

循环神经网络(RNN)在处理序列数据(如文本、时间序列、语音)时具有天然优势,但传统 RNN 存在 梯度消失和梯度爆炸 的问题,难以捕捉长距离依赖关系。

长短期记忆网络(Long Short-Term Memory,LSTM) 是一种特殊的 RNN 变体,由 Hochreiter 和 Schmidhuber 在 1997 年提出。它通过 门控机制(Gates) 控制信息的保留与遗忘,解决了 RNN 在长序列训练中的问题。

2. LSTM 的核心结构

LSTM 的基本单元包含三个门和一个单元状态(Cell State):

  • 遗忘门(Forget Gate):决定丢弃多少历史信息。

  • 输入门(Input Gate):决定当前输入信息保留多少。

  • 输出门(Output Gate):决定当前时刻的隐藏状态输出。

  • 单元状态(Cell State):类似"传送带",携带长期信息。

结构示意

可以把 LSTM 看作在传统 RNN 的基础上,多了一条 信息高速通道(Cell State),通过门控机制有选择地更新。

3. LSTM 数学公式

设输入为 ,上一时刻的隐藏状态为 ,上一时刻的单元状态为 ,则 LSTM 的计算公式如下:

(1)遗忘门(Forget Gate):

(2)输入门(Input Gate):

(3)候选单元状态(Cell Candidate):

(4)单元状态更新(Cell State):

(5)输出门(Output Gate):

(6)隐藏状态更新(Hidden State):

其中:

  • 表示 Sigmoid 激活函数;

  • 表示双曲正切激活函数;

  • 表示元素逐位相乘。

4. LSTM 的优点

  1. 解决长依赖问题:能够捕捉数百步的时间依赖关系。

  2. 避免梯度消失/爆炸:门控机制使得梯度能在长序列中稳定传播。

  3. 广泛应用:文本生成、机器翻译、语音识别、金融时间序列预测等。

5. LSTM 的缺点

  1. 结构复杂:相较于 RNN,参数更多,计算开销更大。

  2. 训练速度慢:长序列数据下,训练时间成本较高。

  3. 难以并行:依赖序列前后顺序,难以像 Transformer 那样并行化。

6. Python 实现示例(PyTorch)

python 复制代码
import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x shape: (batch_size, seq_len, input_size)
        out, (h_n, c_n) = self.lstm(x)  
        out = self.fc(out[:, -1, :])  # 取最后时间步的隐藏状态
        return out

# 示例
input_size = 10   # 每个时间步输入维度
hidden_size = 50  # 隐藏层维度
output_size = 1   # 输出维度
model = LSTMModel(input_size, hidden_size, output_size)

x = torch.randn(32, 5, 10)  # batch_size=32, seq_len=5, input_size=10
y = model(x)
print(y.shape)  # (32, 1)

7. 总结

  • RNN 善于处理短期依赖,但难以记忆长期信息。

  • LSTM 引入门控机制,通过遗忘门、输入门、输出门和单元状态,有效解决了梯度消失问题,能更好地建模长序列。

  • 但 LSTM 依然存在计算慢、难以并行的缺点,这也是后来 GRUTransformer 出现的原因。

相关推荐
无垠的广袤11 小时前
【LattePanda Mu 开发套件】AI 图像识别网页服务器
服务器·人工智能·python·单片机·嵌入式硬件·物联网
芒果量化12 小时前
ML4T - 第7章第7节 逻辑回归拟合宏观数据Logistic Regression with Macro Data
人工智能·机器学习·逻辑回归·线性回归
西岭千秋雪_12 小时前
RAG核心特性:ETL
数据仓库·人工智能·spring boot·ai编程·etl
无风听海12 小时前
神经网络之Softmax激活函数求导过程
人工智能·深度学习·神经网络
youcans_12 小时前
【Trae】Trae 插件实战手册(1)PyCharm 安装 Trae
人工智能·python·pycharm·ai编程·trae
说私域12 小时前
基于开源AI智能名片链动2+1模式S2B2C商城小程序的引流爆款设计策略研究
人工智能·小程序
张较瘦_12 小时前
[论文阅读] AI + 软件工程 | 从“事后补救”到“实时防控”,SemGuard重塑LLM代码生成质量
论文阅读·人工智能·软件工程
IT古董12 小时前
【第五章:计算机视觉-项目实战之生成对抗网络实战】1.对抗生成网络原理-(1)对抗生成网络算法基础知识:基本思想、GAN的基本架构、应用场景、标注格式
人工智能·生成对抗网络·计算机视觉
MoRanzhi120313 小时前
0. NumPy 系列教程:科学计算与数据分析实战
人工智能·python·机器学习·数据挖掘·数据分析·numpy·概率论
金井PRATHAMA13 小时前
语义网络(Semantic Net)对人工智能中自然语言处理的深层语义分析的影响与启示
人工智能·自然语言处理·知识图谱