nn.LSTM个人记录

简介

nn.LSTM参数

python 复制代码
torch.nn.lstm(input_size,   "输入的嵌入向量维度,例如每个单词用50维向量表示,input_size就是50"
              hidden_size,  "隐藏层节点数量,也是输出的嵌入向量维度"
              num_layers,   "lstm 隐层的层数,默认为1"
              bias,         "隐层是否带 bias,默认为 true"
              batch_first,  "True 或者 False,如果是 True,则 input 为(batchsize, len, input_size),默认值为:False(len, batchsize, input_size)"
              dropout,      "除最后一层,每一层的输出都进行dropout,默认值0"
              bidirectional "如果设置为 True, 则表示双向 LSTM,默认为 False"
              )

维度

batch_first=True,输入维度(batchsize,len,input_size)

batch_first=False,输入维度(len,batchsize, input_size)

batch_first=False,输出维度(len,batchsize,hidden_size)

举例嵌入向量维度为1

假如输入x为(batchsize,len)的序列,即嵌入向量维度为1,进行一个回归预测。

如果将嵌入向量维度维度设为1就不太合理,因为如果len非常长例如几w,那么经过几w的时间步得到的得到的h维度为( batchsize,1**),序列太长丢失很多信息,再输入全连接层预测效果不好。并且lstm实际上将嵌入向量维度从input_size规约到hidden_size。**

所以在这里我们将len作为input_size,嵌入向量维度1作为len(即对调了一下)

添加一个维度:

python 复制代码
x = x.unsqueeze(0)

x维度变为(1,batchsize,len),相当于设置数据的长度为1,嵌入向量维度为len,通过nn.LSTM输入到网络中。

python 复制代码
#lstm为定义的网络
#h[-1]为最后输入到全连接层的嵌入矩阵 但是由于此问题中len为1,所以x等于h[-1]
x, (h, c) = lstm(x)

x维度变为(1,batchsize,hidden_size)

h为每层lstm最后一个时间步的输出一般可以输入到后续的全连接层),维度为(num_layers,batchsize,hidden_size)

c为最后一个时间步 LSTM cell 的状态(记忆单元,一般用不到),维度为(num_layers,batchsize,hidden_size)

移除张量中所有尺寸为 1 的维度,即将第一个维度移除掉:

python 复制代码
lstm_out = x.squeeze(0)

x维度变为(batchsize,hidden_size) ,输入到全连接层(线性层,维度(hidden_size,num_class))中,最终输出维度(batchsize,num_class)

参考:

Pytorch --- LSTM (nn.LSTM & nn.LSTMCell)-CSDN博客

相关推荐
Donvink28 分钟前
YOLO系列论文综述(从YOLOv1到YOLOv11)【第3篇:YOLOv1——YOLO的开山之作】
人工智能·深度学习·yolo·目标检测
袁袁袁袁满29 分钟前
体验免费开箱即用的AI工具:Blackbox.AI
人工智能·gpt·深度学习·chatgpt·大模型·gpt-4o·blackbox.ai
橙子小哥的代码世界32 分钟前
【大模型】从零样本到少样本学习:一文读懂 Zero-shot、One-shot 和 Few-shot 的核心原理与应用!
深度学习·神经网络·自然语言处理·tensorflow·sklearn
魍魉19882 小时前
神经网络的数学——一个完整的例子
神经网络·决策树·机器学习
IT古董3 小时前
【人工智能】Python常用库-TensorFlow常用方法教程
人工智能·python·机器学习·tensorflow
L_cl3 小时前
NLP 2、机器学习简介
人工智能·机器学习·自然语言处理
爱喝白开水a3 小时前
优化注意力层提升 Transformer 模型效率:通过改进注意力机制降低机器学习成本
人工智能·深度学习·机器学习·自然语言处理·大模型·transformer·大模型微调
IT古董3 小时前
【机器学习】机器学习的基本分类-监督学习(Supervised Learning)
人工智能·学习·机器学习·分类
奔跑的犀牛先生3 小时前
【小白学机器学习37】用numpy计算协方差cov(x,y) 和 皮尔逊相关系数 r(x,y)
人工智能·python·机器学习
池央3 小时前
生成式机器学习:自回归模型
人工智能·机器学习·回归