pytorch LSTM类解析

1、根据sin(t)预测cos(t)

输入网络的:train_x形状是 (100,) → 重塑后是 (20, 5, 1) 。实际输入LSTM中的维度为(5,20,1)对应 (seq_len, batch_size, input_size),解释:

seq_len=5:每个样本的时间序列长度为 5;

batch_size=train_data_len//5=20:总训练样本数 100,按 seq_len=5 划分后得到 20 个 batch

input_size=2:数据的维度是1。

2、关键细节

复制代码
class LstmRNN(nn.Module):
    def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # LSTM 层
        self.forwardCalculation = nn.Linear(hidden_size, output_size)  # 全连接层

    def forward(self, _x):
        x, _ = self.lstm(_x)  # LSTM 输出: (seq_len, batch, hidden_size)
        s, b, h = x.shape
        x = x.view(s * b, h)  # 展平: (seq_len*batch, hidden_size)
        x = self.forwardCalculation(x)  # 全连接层映射到输出维度
        x = x.view(s, b, -1)  # 恢复时序结构: (seq_len, batch, output_size)
        return x

2.1、核心代码解释

self.lstm = nn.LSTM(input_size, hidden_size, num_layers) # LSTM 层

代码中 LSTM 的具体结构结合代码中的参数 :input_size=1、hidden_size=16、num_layers=1,这个 LSTM 的结构可以拆解为 "单层、单方向、隐藏维度 16 的 LSTM",具体如下:

  1. 输入输出维度对应关系

代码中训练数据的维度:

输入 train_x_tensor 形状:(20, 5, 1) → 因为 batch_first=False,实际输入 LSTM 时维度顺序为 (seq_len=5, batch_size=20, input_size=1);

2.LSTM 前向传播输出:

x, _ = self.lstm(_x) # x 是 LSTM 的输出

输出 x 的形状为:(seq_len=5, batch_size=20, hidden_size=16)。

2.2、网络结构可视化

复制代码
# 输入层:每个时间步输入 1 维特征
[时间步 1: sin(t1)] → [时间步 2: sin(t2)] → ... → [时间步 5: sin(t5)]
        |                   |                          |
        └───────────────────┴──────────────────────────┘
                            ↓
# 单层 LSTM 核心层:隐藏维度 16
[LSTM 单元 (隐藏状态维度 16)]
        |
        ├─ 时间步 1 输出:16 维隐藏状态
        ├─ 时间步 2 输出:16 维隐藏状态
        ├─ 时间步 3 输出:16 维隐藏状态
        ├─ 时间步 4 输出:16 维隐藏状态
        └─ 时间步 5 输出:16 维隐藏状态
                            ↓
# 全连接层:将 16 维隐藏状态映射到 1 维输出(预测 cos(t))
[Linear(16 → 1)]
        |
        ├─ 时间步 1 输出:cos(t1) 预测值
        ├─ 时间步 2 输出:cos(t2) 预测值
        ├─ 时间步 3 输出:cos(t3) 预测值
        ├─ 时间步 4 输出:cos(t4) 预测值
        └─ 时间步 5 输出:cos(t5) 预测值

2.3、关键细节补充

1.隐藏状态 vs 细胞状态

LSTM 的输出 x 是每个时间步的隐藏状态 h_t,而 LSTM 实际内部还有一个细胞状态 c_t(代码中用 _ 舍弃了,因为这个任务不需要直接用细胞状态)。

2.参数数量计算

3、参数调整对模型的影响

如果你修改这些参数,模型结构会发生明显变化:

增大 hidden_size(比如 32):记忆容量提升,预测精度可能更高,但计算量变大,容易过拟合。

增大 num_layers(比如 2):变成堆叠 LSTM,可以捕捉更复杂的时序特征,但训练难度增加,需要调整学习率和正则化。

增大 input_size(比如 2):如果输入同时包含 sin(t) 和 sin(t-1),输入维度变为 2,LSTM 可以同时利用多个特征进行预测。

相关推荐
人工智能培训2 小时前
设备故障?数字孪生提前预警
人工智能·深度学习·神经网络·机器学习·生成对抗网络
风落无尘2 小时前
第十一章《对齐与安全》 完整学习资料
python·安全·机器学习
独隅2 小时前
PyTorch自动微分模块:从原理到实战一
人工智能·pytorch·python
Luhui Dev3 小时前
大角几何 MCP 服务上线:让 AI Agent 直接完成几何作图
人工智能·数学·机器学习·大角几何·luhuidev
wangqiaowq3 小时前
预训练 后预训练 微调
人工智能·深度学习·机器学习
罗西的思考4 小时前
【Agentic RL / 强化学习 / OPD】OpenClaw-RL 源码阅读笔记 --- (4)--- 系统架构
人工智能·算法·机器学习
Rocky Ding*4 小时前
一文读懂HiDream-I1稀疏 DiT 图像生成基础模型
论文阅读·人工智能·深度学习·机器学习·ai作画·aigc·ai-native
大模型最新论文速读8 小时前
05-29 · LLM 最新论文速览
论文阅读·人工智能·深度学习·机器学习·自然语言处理
春日见9 小时前
强化学习方法分类:
人工智能·机器学习·分类·数据挖掘·强化学习