LSTM(长短期记忆网络)是一种递归神经网络,设计用来解决梯度消失和长期依赖问题。
梯度消失:在反向传播过程中,由于链式法则,较早层的梯度小于1,连乘后数次迭代会导致梯度趋于0,使得网络很难学习早期信息。
长期依赖问题:传统神经网络在处理长序列数据时,梯度更新往往受限于短期依赖,难以有效学习长期依赖关系。
LSTM通过增加一个"遗忘门"、"输入门"和"输出门"来解决这些问题。它使用一个称为"单元状态"的隐藏状态,该状态可以记住长期信息。
以下是一个简单的LSTM单元的Python代码示例,使用PyTorch框架:
import torch
import torch.nn as nn
class LSTMCell(nn.Module):
def init(self, input_size, hidden_size):
super(LSTMCell, self).init()
self.hidden_size = hidden_size
self.input2hidden = nn.Linear(input_size + hidden_size, hidden_size)
self.input2cell = nn.Linear(input_size, hidden_size)
self.hidden2cell = nn.Linear(hidden_size, hidden_size)
def forward(self, input, hidden):
h, c = hidden
combined = torch.cat((input, h), dim=1) # concatenate along dimension 1 (channel dimension)
Input Gate
i = torch.sigmoid(self.input2hidden(combined))
Forget Gate
f = torch.sigmoid(self.input2cell(input) + self.hidden2cell(h))
New Cell State
new_c = f * c + i * torch.tanh(self.input2cell(combined))
Output Gate
o = torch.sigmoid(self.input2hidden(combined))
New Hidden State
new_h = o * torch.tanh(new_c)
return new_h, (new_h, new_c)
Example usage
input_size = 10
hidden_size = 20
lstm_cell = LSTMCell(input_size, hidden_size)
input = torch.randn(5, 3, input_size) # seq_len = 5, batch_size = 3
h0 = torch.randn(3, hidden_size)
c0 = torch.randn(3, hidden_size)
hidden_state = (h0, c0)
for input_step in input:
hidden_state = lstm_cell(input_step, hidden_state)
Output is the new hidden state
print(hidden_state[0])
这段代码定义了一个基本的LSTM单元,它接受一个输入序列和一个初始隐藏状态。然后,它遍历输入序列,逐个步骤地计算新的隐藏状态。这个例子中没有使用PyTorch提供的nn.LSTMCell模块,而是手动实现了LSTM单元的基本组成部分,以便更好地理解LSTM的工作原理。