2025-11-15 学习记录--Python-LSTM模型定义(PyTorch)

LSTM模型定义(PyTorch

  • LSTM(Long Short-Term Memory)长短期记忆网络
    是 RNN(循环神经网络)的一种改进版本,主要用来解决 时间序列预测需要记住过去信息的任务 ,例如:👇🏻
    • PM2.5 时间序列预测
    • 文本生成
    • 股票预测
    • 温度预测
    • 电力负载预测
  • 普通 RNN 的问题是运行久了就遗忘前面的信息 (梯度消失),而 LSTM 通过 "门结构(gates) " 让网络能够选择:👇🏻
    • 记住(Keep)
    • 忘掉(Forget)
    • 更新(Update)
  • 这些信息。
python 复制代码
# LSTM 模型定义(PyTorch)
# ---------------------------------------------------------
# 本文件实现一个简单的单层 LSTM 回归模型,用于预测下一小时 PM2.5。
# 输入维度: 24(窗口长度)
# 输出维度: 1(预测未来1小时 PM2.5)
# ---------------------------------------------------------

import torch  # 导入 PyTorch 主包,用于张量运算与设备管理
import torch.nn as nn  # 导入神经网络模块的子包,习惯性重命名为 nn

# 定义一个继承自 nn.Module 的 LSTM 模型类
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=1, dropout=0.0):
        super(LSTMModel, self).__init__()  # 调用父类构造函数,初始化模块内部状态

        # 定义一个 LSTM 层
        # input_size: 每个时间步的特征维度(这里每小时只有一个 PM2.5 值,所以是 1)
        # hidden_size: LSTM 隐藏态的维度(即每个时间步输出向量的长度)
        # num_layers: LSTM 堆叠层数(几层 LSTM 单元叠在一起)
        # batch_first=True: 输入/输出张量的形状为 (batch, seq_len, feature)
        # dropout: 当 num_layers>1 时,层间 dropout 的概率
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout
        )

        # 定义一个线性全连接层,把 LSTM 的最后隐藏态映射为预测值
        # 输入维度 hidden_size -> 输出维度 1(回归预测一个数)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # forward 定义前向传播逻辑,x 是模型输入张量
        # 期望 x.shape = (batch_size, seq_len=24, feature=1)
        out, _ = self.lstm(x)  # 把输入传入 LSTM,out 为每个时间步的输出(shape=(batch, seq_len, hidden_size))
                               # 第二个返回值是 (h_n, c_n) ------ 最后一个时间步的隐状态与细胞状态,这里用 _ 忽略它

        # 取 LSTM 输出序列中最后一个时间步的输出作为序列级特征
        # out[:, -1, :] 的形状为 (batch_size, hidden_size)
        out = out[:, -1, :]

        # 把最后时间步的隐藏向量通过全连接层映射为标量预测值
        # 最终 out 的形状为 (batch_size, 1)
        out = self.fc(out)
        return out  # 返回预测结果(未做激活,回归任务通常直接输出实数)
相关推荐
XDHCOM2 小时前
通过手机远程操控电脑,一步步学习便捷方法
学习·智能手机·电脑
百锦再2 小时前
第15章 并发编程
android·java·开发语言·python·rust·django·go
laufing2 小时前
pyinstaller 介绍
python·构建打包
胡楚昊3 小时前
Polar MISC(下)
学习
程序员东岸3 小时前
从零开始学二叉树(上):树的初识 —— 从文件系统到树的基本概念
数据结构·经验分享·笔记·学习·算法
谅望者3 小时前
数据分析笔记09:Python条件语循环
笔记·python·数据分析
Auspemak-Derafru3 小时前
从U盘损坏中恢复视频文件并修复修改日期的完整解决方案
python
Tonya433 小时前
测开学习DAY29
学习
techzhi3 小时前
Intellij idea 注释模版
java·python·intellij-idea