pytorch LSTM 结构详解

最近项目用到了LSTM ,但是对LSTM 的输入输出不是很理解,对此,我详细查找了lstm 的资料

复制代码
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=50, num_layers=2):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)  # 1 表示预测输出变量为1

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out # out 形状为(batch_size,1)
  • input_size=1:输入特征的维度,适用于单变量时间序列。

  • hidden_size=50:LSTM 隐藏层的维度,决定了模型的记忆能力。

  • num_layers=2:堆叠的 LSTM 层数,增加层数可以提升模型的表达能力。

  • batch_first=True :指定输入和输出的张量形状为 (batch_size, seq_len, input_size)

  • self.fc:一个全连接层,将 LSTM 的输出映射到最终的预测值。

  • batch_size 表示批次、seq_len 表示窗口大小、input_size 表示输入尺寸,单变量输入为1 ,多变量要基于个数变化

  • 初始化隐藏状态和细胞状态

    • h0c0 分别表示初始的隐藏状态和细胞状态,形状为 (num_layers, batch_size, hidden_size)

    • 在每次前向传播时,初始化为零张量。

  • LSTM 层处理

    • self.lstm(x, (h0, c0)):将输入 x 和初始状态传入 LSTM 层,输出 out 和新的状态。

    • out 的形状为 (batch_size, seq_len, hidden_size),包含了每个时间步的输出。

  • 全连接层映射

    • out[:, -1, :]:提取序列中最后一个时间步的输出。

    • self.fc(...):将提取的输出通过全连接层,得到最终的预测结果。

相关推荐
吴佳浩 Alben7 小时前
GPU 生产环境实践:硬件拓扑、显存管理与完整运维体系
运维·人工智能·pytorch·语言模型·transformer·vllm
湘美书院--湘美谈教育9 小时前
湘美谈教育湘美书院网文研究:人工智能与微型小说选集
人工智能·深度学习·神经网络·机器学习·ai写作
梦醒过后说珍重10 小时前
炼丹笔记:感知超分辨率模型中复合损失权重的科学调参SOP
深度学习
CoovallyAIHub10 小时前
Pipecat:构建实时语音 AI Agent 的开源编排框架,500ms 级端到端延迟
深度学习·算法·计算机视觉
CoovallyAIHub11 小时前
Energies | 8版YOLO对8版Transformer实测光伏缺陷检测,RF-DETR-Small综合胜出
深度学习·算法·计算机视觉
zh路西法11 小时前
【宇树机器人强化学习】(七):复杂地形的生成与训练
python·深度学习·机器学习·机器人
逄逄不是胖胖12 小时前
《动手学深度学习》-69预训练bert数据集实现
人工智能·深度学习·bert
CoovallyAIHub12 小时前
2.5GB 塞进浏览器:Mistral 开源实时语音识别,延迟不到半秒
深度学习·算法·计算机视觉
mygugu12 小时前
详细分析swanlab集成mmengine底层实现机制--源码分析
python·深度学习·可视化
Hello.Reader12 小时前
词语没有位置感?用“音乐节拍“给 Transformer 装上时钟——Positional Encoding 图解
人工智能·深度学习·transformer