nn.LSTM个人记录

简介

nn.LSTM参数

python 复制代码
torch.nn.lstm(input_size,   "输入的嵌入向量维度,例如每个单词用50维向量表示,input_size就是50"
              hidden_size,  "隐藏层节点数量,也是输出的嵌入向量维度"
              num_layers,   "lstm 隐层的层数,默认为1"
              bias,         "隐层是否带 bias,默认为 true"
              batch_first,  "True 或者 False,如果是 True,则 input 为(batchsize, len, input_size),默认值为:False(len, batchsize, input_size)"
              dropout,      "除最后一层,每一层的输出都进行dropout,默认值0"
              bidirectional "如果设置为 True, 则表示双向 LSTM,默认为 False"
              )

维度

batch_first=True,输入维度(batchsize,len,input_size)

batch_first=False,输入维度(len,batchsize, input_size)

batch_first=False,输出维度(len,batchsize,hidden_size)

举例嵌入向量维度为1

假如输入x为(batchsize,len)的序列,即嵌入向量维度为1,进行一个回归预测。

如果将嵌入向量维度维度设为1就不太合理,因为如果len非常长例如几w,那么经过几w的时间步得到的得到的h维度为( batchsize,1**),序列太长丢失很多信息,再输入全连接层预测效果不好。并且lstm实际上将嵌入向量维度从input_size规约到hidden_size。**

所以在这里我们将len作为input_size,嵌入向量维度1作为len(即对调了一下)

添加一个维度:

python 复制代码
x = x.unsqueeze(0)

x维度变为(1,batchsize,len),相当于设置数据的长度为1,嵌入向量维度为len,通过nn.LSTM输入到网络中。

python 复制代码
#lstm为定义的网络
#h[-1]为最后输入到全连接层的嵌入矩阵 但是由于此问题中len为1,所以x等于h[-1]
x, (h, c) = lstm(x)

x维度变为(1,batchsize,hidden_size)

h为每层lstm最后一个时间步的输出一般可以输入到后续的全连接层),维度为(num_layers,batchsize,hidden_size)

c为最后一个时间步 LSTM cell 的状态(记忆单元,一般用不到),维度为(num_layers,batchsize,hidden_size)

移除张量中所有尺寸为 1 的维度,即将第一个维度移除掉:

python 复制代码
lstm_out = x.squeeze(0)

x维度变为(batchsize,hidden_size) ,输入到全连接层(线性层,维度(hidden_size,num_class))中,最终输出维度(batchsize,num_class)

参考:

Pytorch --- LSTM (nn.LSTM & nn.LSTMCell)-CSDN博客

相关推荐
码农小韩14 小时前
Mamba学习(一)——Mamba-V1原理(一)
深度学习·ssm·mamba·状态空间模型·序列模型
买大橘子也用券14 小时前
26软件系统安全赛-Fake Emotion(复盘)
python·深度学习·安全·网络安全
AI人工智能+14 小时前
施工许可证智能识别系统通过融合计算机视觉与自然语言处理技术,实现了建筑行业关键证件的自动化信息提取
人工智能·深度学习·计算机视觉·ocr·施工许可证识别
春日见14 小时前
5分钟入门强化学习之蒙特卡洛(MC)算法与实现
运维·服务器·人工智能·深度学习·算法·机器学习
不会计算机的g_c__b14 小时前
Argoverse API 完全解析:自动驾驶数据集与高精地图开发利器
人工智能·机器学习·自动驾驶
是一个Bug1 天前
Agent(智能体)应用 的入门学习路径
学习·机器学习
盖小雅1 天前
自动化排班如何破解劳动法合规难题:从规则冲突到可追溯的排班表
大数据·运维·机器学习·自动化
2401_876964131 天前
【湖北专升本】2026湖北专升本真题PDF+备考资料汇总
数据结构·人工智能·经验分享·深度学习·算法·计算机视觉
踏歌~1 天前
YA期货准备:0 了解期货
机器学习
数据科学小丫1 天前
特征工程处理
人工智能·算法·机器学习