LSTM——长短期记忆神经网络

目录

[1.LSTM 工作原理](#1.LSTM 工作原理)

2.LSTM的代码实现

3.代码详解


LSTM (Long Short-Term Memory)是一种特殊的循环神经网络 (RNN),用于解决长序列中的长期依赖问题。它通过引入门机制,控制信息的流入、保留和输出,从而在避免梯度消失或爆炸的情况下捕获较长序列的依赖关系。以下是LSTM的工作原理和代码实现


1.LSTM 工作原理

LSTM 通过引入 细胞状态(Cell State)门控单元(Gates) 来控制信息流动,具体包含以下几个部分:

  1. 遗忘门(Forget Gate)

    遗忘门决定了上一个时间步的细胞状态是否需要保留或遗忘。遗忘门通过一个 sigmoid 激活函数(输出在 0 和 1 之间)来控制。输入为当前输入 和上一个隐藏状态 ​:

  2. 输入门(Input Gate)

    输入门决定当前时间步的新信息是否要更新到细胞状态中。它包含两个部分:

    • :用于选择要添加的新信息。
    • :候选细胞状态,通过 tanh 函数生成可能的新状态信息。

  3. 细胞状态更新

    细胞状态结合了遗忘门和输入门的输出来更新:

  4. 输出门(Output Gate)

    输出门控制 LSTM 的最终输出,即新的隐藏状态 。它将新的细胞状态 ​ 调整后输出:

2.LSTM的代码实现

以下是使用 PyTorch 实现 LSTM 的代码示例:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 定义 LSTM 模型
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # 通过 LSTM 层
        out, _ = self.lstm(x, (h0, c0))
        
        # 获取最后一个时间步的输出
        out = self.fc(out[:, -1, :])
        return out

# 定义模型参数
input_size = 10    # 输入维度
hidden_size = 20   # 隐藏层维度
output_size = 1    # 输出维度
num_layers = 2     # LSTM 层数

# 初始化模型
model = LSTMModel(input_size, hidden_size, output_size, num_layers)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    # 假设输入数据 x 和标签 y
    x = torch.randn(32, 5, input_size)  # (batch_size, sequence_length, input_size)
    y = torch.randn(32, output_size)
    
    # 前向传播
    outputs = model(x)
    loss = criterion(outputs, y)
    
    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

3.代码详解

  • 输入数据 :这里的 x 是一个三维张量,形状为 (批次大小, 序列长度, 输入维度),其中 序列长度 是 LSTM 模型需要捕获依赖的时间步。
  • 隐藏层和输出层 :LSTM 输出的最后一个时间步的隐藏状态传递给全连接层 fc,用于输出预测结果。
  • 初始化状态 :LSTM 层需要初始化隐藏状态 h0 和细胞状态 c0,这通常在每个新序列的起点进行。
  • 损失函数和优化器:使用均方误差损失函数(MSELoss)和 Adam 优化器来优化模型。

通过调整输入、隐藏和输出维度,这种结构可以适用于各种时间序列预测、自然语言处理等任务。

相关推荐
火车叼位17 小时前
也许你不需要创建.venv, 此规范使python脚本自备依赖
python
YuTaoShao17 小时前
【LeetCode 每日一题】1653. 使字符串平衡的最少删除次数——(解法一)前后缀分解
算法·leetcode·职场和发展
火车叼位17 小时前
脚本伪装:让 Python 与 Node.js 像原生 Shell 命令一样运行
运维·javascript·python
VT.馒头17 小时前
【力扣】2727. 判断对象是否为空
javascript·数据结构·算法·leetcode·职场和发展
民乐团扒谱机17 小时前
【微实验】机器学习之集成学习 GBDT和XGBoost 附 matlab仿真代码 复制即可运行
人工智能·机器学习·matlab·集成学习·xgboost·gbdt·梯度提升树
芷栀夏17 小时前
CANN ops-math:揭秘异构计算架构下数学算子的低延迟高吞吐优化逻辑
人工智能·深度学习·神经网络·cann
goodluckyaa17 小时前
LCR 006. 两数之和 II - 输入有序数组
算法
孤狼warrior17 小时前
YOLO目标检测 一千字解析yolo最初的摸样 模型下载,数据集构建及模型训练代码
人工智能·python·深度学习·算法·yolo·目标检测·目标跟踪
机器学习之心17 小时前
TCN-Transformer-BiGRU组合模型回归+SHAP分析+新数据预测+多输出!深度学习可解释分析
深度学习·回归·transformer·shap分析
Katecat9966317 小时前
YOLO11分割算法实现甲状腺超声病灶自动检测与定位_DWR方法应用
python