动手学深度学习(Pytorch版)代码实践 -循环神经网络-57长短期记忆网络(LSTM)

57长短期记忆网络(LSTM

1.LSTM原理

LSTM是专为解决标准RNN的长时依赖问题而设计的。标准RNN在训练过程中,随着时间步的增加,梯度可能会消失或爆炸,导致模型难以学习和记忆长时间间隔的信息。LSTM通过引入一组称为门的机制来解决这个问题:

  1. 输入门(Input Gate):控制有多少新的信息可以传递到记忆单元中。
  2. 遗忘门(Forget Gate):控制当前记忆单元中有多少信息会被保留。
  3. 输出门(Output Gate):控制记忆单元的输出有多少被传递到下一步。

LSTM还引入了一个称为记忆单元(Cell State)的概念,用于携带长期信息。这些门的组合使得LSTM能够选择性地记住或遗忘信息,从而解决了长时依赖问题。

2.优点
  1. 解决梯度消失问题 :通过门控机制,LSTM能够有效地传递梯度,避免了梯度消失和爆炸的问题。
  2. 捕捉长时依赖LSTM能够记住和利用长时间间隔的信息,这是标准RNN难以做到的。
  3. 灵活性LSTM适用于各种序列数据处理任务,如时间序列预测、语言建模和序列到序列的翻译等。
3.LSTMGRU的区别

GRU(门控循环单元)是另一种解决长时依赖问题的RNN变体。GRULSTM都引入了门控机制,但它们的具体实现有所不同。

  1. 结构简化GRU的结构比LSTM更简单,参数更少,计算效率更高。
  2. 性能对比 :在一些任务上,GRULSTM的性能相当,但在某些情况下,GRU可能表现更好,特别是在较小的数据集或较短的序列上。
  3. 门的数量LSTM有三个门(输入门、遗忘门和输出门),而GRU只有两个门(更新门和重置门)。
4.LSTM代码实践
python 复制代码
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 设置批量大小和序列步数
batch_size, num_steps = 32, 35
# 加载时间机器数据集
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# 初始化LSTM模型参数
def get_lstm_params(vocab_size, num_hiddens, device):
    # 输入输出的维度大小
    num_inputs = num_outputs = vocab_size

    # 正态分布初始化权重
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    # 三个权重参数(用于输入门、遗忘门、输出门和候选记忆元)
    def three():
        return (normal((num_inputs, num_hiddens)),  # 输入到隐藏状态的权重
                normal((num_hiddens, num_hiddens)),  # 隐藏状态到隐藏状态的权重
                torch.zeros(num_hiddens, device=device))  # 偏置

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))  # 隐藏状态到输出的权重
    b_q = torch.zeros(num_outputs, device=device)  # 输出偏置
    # 将所有参数附加到参数列表中
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)  # 设置参数需要梯度
    return params

# 初始化LSTM的隐藏状态
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),  # 隐藏状态
            torch.zeros((batch_size, num_hiddens), device=device))  # 记忆元

# LSTM前向传播
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state  # 隐藏状态和记忆元
    outputs = []
    for X in inputs:
        # 输入门
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        # 遗忘门
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        # 输出门
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        # 候选记忆元
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        # 更新记忆元
        C = F * C + I * C_tilda
        # 更新隐藏状态
        H = O * torch.tanh(C)
        # 计算输出
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)  # 返回输出和状态

# 训练和预测模型
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 创建自定义的LSTM模型
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.3, 34433.0 tokens/sec on cuda:0
# 预测结果示例:time traveller conellace there wardeal that are almost us we hou

# 使用PyTorch的简洁实现
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)  # 创建LSTM层
model = d2l.RNNModel(lstm_layer, len(vocab))  # 创建模型
model = model.to(device)  # 将模型移动到GPU
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
# perplexity 1.0, 317323.7 tokens/sec on cuda:0
# 预测结果示例:time travelleryou can show black is white by argument said filby

自定义的LSTM模型:


简洁实现:

相关推荐
埃菲尔铁塔_CV算法18 分钟前
深度学习神经网络在机器人领域应用的深度剖析:原理、实践与前沿探索
深度学习·神经网络·机器人
墨绿色的摆渡人2 小时前
用 Python 从零开始创建神经网络(三):添加层级(Adding Layers)
人工智能·python·深度学习·神经网络
B站计算机毕业设计超人2 小时前
计算机毕业设计Python+CNN卷积神经网络股票预测系统 股票推荐系统 股票可视化 股票数据分析 量化交易系统 股票爬虫 股票K线图 大数据毕业设计 AI
大数据·爬虫·python·深度学习·机器学习·课程设计·数据可视化
goomind3 小时前
YOLOv11实战PCB电路板缺陷识别
人工智能·python·深度学习·yolo·目标检测·计算机视觉·缺陷检测
yyfhq4 小时前
atttention1111
人工智能·pytorch·python
城市数据研习社4 小时前
【论文分享】三维景观格局如何影响城市居民的情绪
深度学习·机器学习·数据分析
城市数据研习社6 小时前
【论文分享】基于街景图像识别和深度学习的针对不同移动能力老年人的街道步行可达性研究——以南京成贤街社区为例
人工智能·深度学习·数据分析
B站计算机毕业设计超人8 小时前
计算机毕业设计Python+大模型中医养生问答系统 知识图谱 医疗大数据 中医可视化 机器学习 深度学习 人工智能 大数据毕业设计
大数据·人工智能·爬虫·python·深度学习·机器学习·知识图谱
爱数学的程序猿9 小时前
联邦学习的未来:深入剖析FedAvg算法与数据不均衡的解决之道
深度学习·学习·机器学习
开心星人10 小时前
【深度学习】wsl-ubuntu深度学习基本配置
人工智能·深度学习·ubuntu