pytorch LSTM类解析

1、根据sin(t)预测cos(t)

输入网络的:train_x形状是 (100,) → 重塑后是 (20, 5, 1) 。实际输入LSTM中的维度为(5,20,1)对应 (seq_len, batch_size, input_size),解释:

seq_len=5:每个样本的时间序列长度为 5;

batch_size=train_data_len//5=20:总训练样本数 100,按 seq_len=5 划分后得到 20 个 batch

input_size=2:数据的维度是1。

2、关键细节

复制代码
class LstmRNN(nn.Module):
    def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # LSTM 层
        self.forwardCalculation = nn.Linear(hidden_size, output_size)  # 全连接层

    def forward(self, _x):
        x, _ = self.lstm(_x)  # LSTM 输出: (seq_len, batch, hidden_size)
        s, b, h = x.shape
        x = x.view(s * b, h)  # 展平: (seq_len*batch, hidden_size)
        x = self.forwardCalculation(x)  # 全连接层映射到输出维度
        x = x.view(s, b, -1)  # 恢复时序结构: (seq_len, batch, output_size)
        return x

2.1、核心代码解释

self.lstm = nn.LSTM(input_size, hidden_size, num_layers) # LSTM 层

代码中 LSTM 的具体结构结合代码中的参数 :input_size=1、hidden_size=16、num_layers=1,这个 LSTM 的结构可以拆解为 "单层、单方向、隐藏维度 16 的 LSTM",具体如下:

  1. 输入输出维度对应关系

代码中训练数据的维度:

输入 train_x_tensor 形状:(20, 5, 1) → 因为 batch_first=False,实际输入 LSTM 时维度顺序为 (seq_len=5, batch_size=20, input_size=1);

2.LSTM 前向传播输出:

x, _ = self.lstm(_x) # x 是 LSTM 的输出

输出 x 的形状为:(seq_len=5, batch_size=20, hidden_size=16)。

2.2、网络结构可视化

复制代码
# 输入层:每个时间步输入 1 维特征
[时间步 1: sin(t1)] → [时间步 2: sin(t2)] → ... → [时间步 5: sin(t5)]
        |                   |                          |
        └───────────────────┴──────────────────────────┘
                            ↓
# 单层 LSTM 核心层:隐藏维度 16
[LSTM 单元 (隐藏状态维度 16)]
        |
        ├─ 时间步 1 输出:16 维隐藏状态
        ├─ 时间步 2 输出:16 维隐藏状态
        ├─ 时间步 3 输出:16 维隐藏状态
        ├─ 时间步 4 输出:16 维隐藏状态
        └─ 时间步 5 输出:16 维隐藏状态
                            ↓
# 全连接层:将 16 维隐藏状态映射到 1 维输出(预测 cos(t))
[Linear(16 → 1)]
        |
        ├─ 时间步 1 输出:cos(t1) 预测值
        ├─ 时间步 2 输出:cos(t2) 预测值
        ├─ 时间步 3 输出:cos(t3) 预测值
        ├─ 时间步 4 输出:cos(t4) 预测值
        └─ 时间步 5 输出:cos(t5) 预测值

2.3、关键细节补充

1.隐藏状态 vs 细胞状态

LSTM 的输出 x 是每个时间步的隐藏状态 h_t,而 LSTM 实际内部还有一个细胞状态 c_t(代码中用 _ 舍弃了,因为这个任务不需要直接用细胞状态)。

2.参数数量计算

3、参数调整对模型的影响

如果你修改这些参数,模型结构会发生明显变化:

增大 hidden_size(比如 32):记忆容量提升,预测精度可能更高,但计算量变大,容易过拟合。

增大 num_layers(比如 2):变成堆叠 LSTM,可以捕捉更复杂的时序特征,但训练难度增加,需要调整学习率和正则化。

增大 input_size(比如 2):如果输入同时包含 sin(t) 和 sin(t-1),输入维度变为 2,LSTM 可以同时利用多个特征进行预测。

相关推荐
郝学胜-神的一滴17 分钟前
深入理解回归损失函数:MSE、L1 与 Smooth L1 的设计哲学
人工智能·python·程序人生·算法·机器学习·数据挖掘·回归
Godspeed Zhao30 分钟前
具身智能中的传感器技术40.2——事件相机0.2
人工智能·科技·数码相机·机器学习·事件相机
2zcode42 分钟前
基于注意力机制LSTM的温度预测系统设计与实现
人工智能·深度学习·lstm
网络工程小王1 小时前
[RAG 与文本向量化详解]RAG篇
数据库·人工智能·redis·机器学习
小何code1 小时前
人工智能【第13篇】集成学习入门:Bagging与Boosting原理详解
随机森林·机器学习·集成学习·boosting
2zcode10 小时前
基于LSTM神经网络的金属材料机器学习本构模型研究(硕士级别)
神经网络·机器学习·lstm·金属材料
phoenix@Capricornus12 小时前
从贝叶斯决策到最小距离判别法再到Fisher判别分析
机器学习
Chef_Chen14 小时前
论文解读:多模态智能体长期记忆突破:M3-Agent让AI像人一样“看、听、记、想“
人工智能·机器学习·agent·memory
代码飞天14 小时前
机器学习算法和函数整理——助力快速查阅
人工智能·算法·机器学习
绛橘色的日落(。・∀・)ノ15 小时前
机器学习 单变量线性回归模型
人工智能·机器学习