2025-11-15 学习记录--Python-LSTM模型定义(PyTorch)

LSTM模型定义(PyTorch

  • LSTM(Long Short-Term Memory)长短期记忆网络
    是 RNN(循环神经网络)的一种改进版本,主要用来解决 时间序列预测需要记住过去信息的任务 ,例如:👇🏻
    • PM2.5 时间序列预测
    • 文本生成
    • 股票预测
    • 温度预测
    • 电力负载预测
  • 普通 RNN 的问题是运行久了就遗忘前面的信息 (梯度消失),而 LSTM 通过 "门结构(gates) " 让网络能够选择:👇🏻
    • 记住(Keep)
    • 忘掉(Forget)
    • 更新(Update)
  • 这些信息。
python 复制代码
# LSTM 模型定义(PyTorch)
# ---------------------------------------------------------
# 本文件实现一个简单的单层 LSTM 回归模型,用于预测下一小时 PM2.5。
# 输入维度: 24(窗口长度)
# 输出维度: 1(预测未来1小时 PM2.5)
# ---------------------------------------------------------

import torch  # 导入 PyTorch 主包,用于张量运算与设备管理
import torch.nn as nn  # 导入神经网络模块的子包,习惯性重命名为 nn

# 定义一个继承自 nn.Module 的 LSTM 模型类
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=1, dropout=0.0):
        super(LSTMModel, self).__init__()  # 调用父类构造函数,初始化模块内部状态

        # 定义一个 LSTM 层
        # input_size: 每个时间步的特征维度(这里每小时只有一个 PM2.5 值,所以是 1)
        # hidden_size: LSTM 隐藏态的维度(即每个时间步输出向量的长度)
        # num_layers: LSTM 堆叠层数(几层 LSTM 单元叠在一起)
        # batch_first=True: 输入/输出张量的形状为 (batch, seq_len, feature)
        # dropout: 当 num_layers>1 时,层间 dropout 的概率
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout
        )

        # 定义一个线性全连接层,把 LSTM 的最后隐藏态映射为预测值
        # 输入维度 hidden_size -> 输出维度 1(回归预测一个数)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # forward 定义前向传播逻辑,x 是模型输入张量
        # 期望 x.shape = (batch_size, seq_len=24, feature=1)
        out, _ = self.lstm(x)  # 把输入传入 LSTM,out 为每个时间步的输出(shape=(batch, seq_len, hidden_size))
                               # 第二个返回值是 (h_n, c_n) ------ 最后一个时间步的隐状态与细胞状态,这里用 _ 忽略它

        # 取 LSTM 输出序列中最后一个时间步的输出作为序列级特征
        # out[:, -1, :] 的形状为 (batch_size, hidden_size)
        out = out[:, -1, :]

        # 把最后时间步的隐藏向量通过全连接层映射为标量预测值
        # 最终 out 的形状为 (batch_size, 1)
        out = self.fc(out)
        return out  # 返回预测结果(未做激活,回归任务通常直接输出实数)
相关推荐
梨落秋霜13 小时前
Python入门篇【文件处理】
android·java·python
Java 码农14 小时前
RabbitMQ集群部署方案及配置指南03
java·python·rabbitmq
张登杰踩15 小时前
VIA标注格式转Labelme标注格式
python
气概15 小时前
法奥机器人学习使用
学习·junit·机器人
Qhumaing15 小时前
C++学习:【PTA】数据结构 7-1 实验7-1(最小生成树-Prim算法)
c++·学习·算法
Learner15 小时前
Python数据类型(四):字典
python
好大哥呀16 小时前
Java Web的学习路径
java·前端·学习
odoo中国16 小时前
Odoo 19 模块结构概述
开发语言·python·module·odoo·核心组件·py文件按
Jelena1577958579216 小时前
Java爬虫api接口测试
python
踩坑记录17 小时前
leetcode hot100 3.无重复字符的最长子串 medium 滑动窗口(双指针)
python·leetcode