pytorch深度学习-LSTM—循环神经网络变种

先用通俗的话介绍一下什么是LSTM

咱们用一个生活化的例子,把 LSTM 说清楚。

想象你在看一部很长的电影,需要记住剧情来理解后面的发展。普通的 "循环神经网络(RNN)" 就像一个记性不好的人:前面的剧情看了没多久就忘了,看到后面关键情节时,可能已经不记得前面的伏笔(比如 "前面出现的那个配角,其实是后面的反派"),导致理解出错。

而 LSTM(长短期记忆网络)就像一个 "有好记性的助手",它能主动判断:哪些剧情该记、哪些该忘、现在需要用哪些记忆。具体怎么做到的?它靠三个 "功能开关"(专业叫 "门")来管理记忆,就像你整理抽屉的过程:

1. 第一个开关:"该扔啥"(遗忘门)

你的抽屉(可以理解为 "长期记忆库")里堆了很多旧东西(前面的剧情、信息)。有些东西早就没用了(比如电影开头一闪而过的路人),留着占地方。这时候 "遗忘门" 就像你的手,会主动把这些没用的东西扔掉,给新东西腾地方。

2. 第二个开关:"该存啥"(输入门)

新剧情来了(比如电影里主角拿到了一把钥匙),这个信息可能很重要(后面要用来开锁)。"输入门" 会判断:这个新信息值得放进抽屉吗?如果值得,就把它整齐地放进去,和剩下的有用旧东西(比如主角的身份)放在一起。

3. 第三个开关:"该用啥"(输出门)

现在看到电影的某个片段(比如主角走到一扇锁着的门),你需要从抽屉里找有用的记忆来理解:"哦,前面拿到的钥匙应该就是开这扇门的!" 这时候 "输出门" 会从抽屉里精准地把 "钥匙" 这个记忆拿出来,帮你理解当前的情节。

总结一下

LSTM 就像一个 "智能抽屉":

  • 抽屉本身(细胞状态)用来存长期重要的记忆;
  • 三个开关(遗忘门、输入门、输出门)负责 "扔垃圾、存新货、取要用的东西";
  • 这样一来,不管剧情多长(序列多复杂),它都能记住关键信息,不会像普通 RNN 那样 "记了后面忘前面"。

它能干啥?

比如:

  • 理解一句话:"他昨天买了个西瓜,今天它坏了"------LSTM 能记住 "它" 指的是 "西瓜"(前面的信息);
  • 预测明天的天气:需要记住前几天的温度、湿度变化(长期趋势);
  • 语音转文字:能把 "我明天...(中间有杂音)... 班" 补全成 "我明天上班"(靠前后文记忆)。

简单说,LSTM 就是个 "会聪明记东西的网络",专门解决 "记不住长远信息" 的问题~

专业术语介绍

核心结构:一个 "仓库"+ 三个 "管理员"

LSTM 的核心是 "细胞状态(Cell State)" (可以理解为 "长期记忆仓库"),以及三个控制它的 "门(Gate)"(相当于仓库管理员)。

1. 细胞状态(Cell State)------ 仓库本身

类比:就像你家那个大衣柜,专门存长期重要的东西(比如过冬的羽绒服、毕业证这种不常用但必须留着的)。

专业说:它是一条贯穿整个序列(比如一句话、一串时间点)的 "信息高速公路",用来存储长期需要记住的核心信息,状态变化缓慢(不会随便丢东西)。

2. 遗忘门(Forget Gate)------ 扔垃圾的管理员

类比:衣柜里堆了去年穿破的旧 T 恤、过期的优惠券,这个管理员会检查:"这些还有用吗?没用就扔了腾地方。"

专业说:它通过一个 sigmoid 函数 输出 0-1 之间的数值(0 = 全扔,1 = 全留),决定细胞状态中哪些旧信息应该被遗忘(比如句子里早就没用的代词、时间序列里的噪音)。

3. 输入门(Input Gate)------ 存新货的管理员

类比:你新买了一件冲锋衣,这个管理员会先判断:"这件需要放进衣柜长期保存吗?" 如果需要,就把它整理好(比如叠整齐),再放进衣柜和剩下的有用旧东西(比如羽绒服)放在一起。

专业说:它分两步:

  • 第一步:用 sigmoid 函数筛选 "新输入的信息里哪些值得存"(0 = 不存,1 = 存);
  • 第二步:把选中的新信息用 tanh 函数 处理成 "候选细胞状态(Candidate Cell State)"(相当于 "叠整齐的新衣服");
  • 最后,把候选状态 "加" 到细胞状态里(更新仓库库存)。
4. 输出门(Output Gate)------ 取东西的管理员

类比:现在要出门露营,你需要从衣柜里找冲锋衣。这个管理员会从衣柜里挑出 "冲锋衣",递给你(同时可能也会把 "要不要带羽绒服" 的信息记下来,给下次参考)。

专业说:它也分两步:

  • 第一步:用 sigmoid 函数从细胞状态里筛选 "当前需要用哪些信息";
  • 第二步:把选中的信息用 tanh 函数处理后,生成 隐藏状态(Hidden State)(相当于 "递给你的冲锋衣");
  • 隐藏状态既是当前时刻的输出(比如理解句子的当前词、预测当前时间点的数值),也会传给下一个时刻的 LSTM 单元(作为 "下次参考的信息")。

串起来再看:

比如处理句子 "小明买了个苹果,他吃了它":

  • 细胞状态(仓库):会记住 "苹果" 这个核心信息(长期记忆);
  • 遗忘门:看到 "他" 时,会忘记前面可能无关的信息(比如 "买" 这个动作的细节);
  • 输入门:把 "苹果" 这个新信息存入细胞状态;
  • 输出门:看到 "它" 时,从细胞状态里取出 "苹果",确定 "它" 指的是苹果(隐藏状态输出这个结果)。

术语小结表:

专业术语 类比角色 核心功能
细胞状态(Cell State) 长期记忆仓库 存储贯穿序列的核心信息
遗忘门(Forget Gate) 扔垃圾的管理员 筛选并丢弃细胞状态中的无用旧信息
输入门(Input Gate) 存新货的管理员 筛选并存储新信息到细胞状态
候选细胞状态 整理好的新货 待存入仓库的新信息(预处理后)
输出门(Output Gate) 取东西的管理员 从细胞状态中筛选信息生成隐藏状态
隐藏状态(Hidden State) 当前要用的东西 作为当前输出,同时传给下一个时刻

这样一来,每个术语对应的 "功能" 和 "类比角色" 就清晰了~ 核心就是:细胞状态存长期记忆,三个门控(遗忘、输入、输出)负责 "删、存、取",解决了普通 RNN "记不住长远信息" 的问题。

简易代码实战(简模型构建)

一、初始化函数:创建四个线性层和激活函数

复制代码
def __init__(self, input_size, hidden_size):
    super(FixedLSTMCell, self).__init__()
    self.hidden_size = hidden_size
    
    # 为每个门单独设置线性层
    self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
    self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
    self.cell_gate = nn.Linear(input_size + hidden_size, hidden_size)
    self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)
    
    self.sigmoid = nn.Sigmoid()  # 门控函数(输出0-1之间的值)
    self.tanh = nn.Tanh()        # 状态激活函数(输出-1到1之间的值)

关键点:

  • 线性层的输入维度input_size + hidden_size,因为要拼接当前输入和上一时刻的隐藏状态
  • 线性层的输出维度hidden_size,与隐藏状态维度一致
  • sigmoid 函数:将值压缩到 [0,1] 区间,用于控制门的 "开关程度"
  • tanh 函数:将值压缩到 [-1,1] 区间,用于生成新信息和规范化细胞状态

二、前向传播:五步核心操作

1. 输入拼接
复制代码
combined = torch.cat((input, hidden), 1)
  • 操作 :将当前输入 input(形状:[batch_size, input_size])和上一时刻的隐藏状态 hidden(形状:[batch_size, hidden_size])在维度 1 上拼接
  • 结果combined 的形状为 [batch_size, input_size + hidden_size]
2. 计算遗忘门
复制代码
f_t = self.sigmoid(self.forget_gate(combined))
  • 数学公式
  • 分步操作
    1. self.forget_gate(combined):线性变换,计算
    2. self.sigmoid(...):应用 sigmoid 函数,将结果压缩到 [0,1] 区间
  • 作用:决定上一时刻的细胞状态 \(C_{t-1}\) 中哪些信息需要被遗忘
3. 计算输入门和候选状态
复制代码
i_t = self.sigmoid(self.input_gate(combined))
c_tilde = self.tanh(self.cell_gate(combined))
  • 输入门数学公式
    • 作用:决定候选状态 \(\tilde{C}_t\) 中哪些信息需要被加入到细胞状态
  • 候选状态数学公式
    • 作用:基于当前输入生成的新信息,范围在 [-1,1] 之间
4. 更新细胞状态
复制代码
cell = f_t * cell + i_t * c_tilde
  • 数学公式
  • 分步操作
    1. 遗忘旧信息f_t * cell,将上一时刻的细胞状态按元素与遗忘门输出相乘
      • 如果 f_t某维度为 0,则对应维度的旧信息被完全遗忘
    2. 添加新信息i_t * c_tilde,将候选状态按元素与输入门输出相乘
      • 如果 i_t 某维度为 1,则对应维度的新信息被完全保留
    3. 合并操作:将两部分相加,得到新的细胞状态
5. 计算输出门和更新隐藏状态
复制代码
o_t = self.sigmoid(self.output_gate(combined))
hidden = o_t * self.tanh(cell)
  • 输出门数学公式
    • 作用:决定细胞状态 C_t中哪些信息需要被输出为隐藏状态 h_t
  • 隐藏状态更新公式
    • 分步操作
      1. self.tanh(cell):将细胞状态的值压缩到 [-1,1] 区间
      2. o_t * ...:按元素与输出门相乘,控制输出哪些信息

三、用具体数字示例

假设:

  • input_size = 3(当前输入有 3 个特征)
  • hidden_size = 4(隐藏状态有 4 个维度)
  • batch_size = 1(单个样本)
1. 输入和隐藏状态
复制代码
input = torch.tensor([[0.2, 0.5, 0.3]])  # 形状: [1, 3]
hidden = torch.tensor([[0.1, 0.4, 0.3, 0.2]])  # 形状: [1, 4]
2. 拼接后的结果
复制代码
combined = torch.cat((input, hidden), 1)
# combined = [[0.2, 0.5, 0.3, 0.1, 0.4, 0.3, 0.2]]  # 形状: [1, 7]
3. 计算遗忘门(简化示例)

假设 forget_gate 的权重矩阵 W_f和偏置 b_f为:

复制代码
W_f = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
       [0.2, 0.1, 0.4, 0.3, 0.6, 0.5, 0.8],
       [0.3, 0.4, 0.1, 0.2, 0.7, 0.8, 0.5],
       [0.4, 0.3, 0.2, 0.1, 0.8, 0.7, 0.6]]  # 形状: [4, 7]
b_f = [0.1, 0.2, 0.3, 0.4]  # 形状: [4]

则:

复制代码
# 线性变换: W_f * combined + b_f
linear_output = [[1.05, 1.25, 1.35, 1.45]]  # 形状: [1, 4]

# 应用sigmoid
f_t = sigmoid(linear_output) = [[0.74, 0.78, 0.80, 0.81]]  # 形状: [1, 4]

四、完整代码 + 注释

复制代码
class FixedLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(FixedLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        
        # 四个线性层,分别对应四个门
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)  # 遗忘门
        self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)   # 输入门
        self.cell_gate = nn.Linear(input_size + hidden_size, hidden_size)    # 候选状态
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)  # 输出门
        
        # 激活函数
        self.sigmoid = nn.Sigmoid()  # 将值压缩到[0,1],用于门控
        self.tanh = nn.Tanh()        # 将值压缩到[-1,1],用于生成新信息和规范化细胞状态
    
    def forward(self, input, hidden, cell):
        """
        input: 当前输入,形状为 [batch_size, input_size]
        hidden: 上一时刻的隐藏状态,形状为 [batch_size, hidden_size]
        cell: 上一时刻的细胞状态,形状为 [batch_size, hidden_size]
        """
        # 1. 拼接当前输入和上一时刻的隐藏状态
        combined = torch.cat((input, hidden), 1)  # 形状: [batch_size, input_size + hidden_size]
        
        # 2. 计算四个门控信号
        f_t = self.sigmoid(self.forget_gate(combined))  # 遗忘门: [batch_size, hidden_size]
        i_t = self.sigmoid(self.input_gate(combined))   # 输入门: [batch_size, hidden_size]
        c_tilde = self.tanh(self.cell_gate(combined))   # 候选状态: [batch_size, hidden_size]
        o_t = self.sigmoid(self.output_gate(combined))  # 输出门: [batch_size, hidden_size]
        
        # 3. 更新细胞状态
        # - f_t * cell: 选择性遗忘旧信息
        # - i_t * c_tilde: 选择性添加新信息
        cell = f_t * cell + i_t * c_tilde  # 形状: [batch_size, hidden_size]
        
        # 4. 更新隐藏状态
        # - self.tanh(cell): 规范化细胞状态的值到[-1,1]
        # - o_t * ...: 选择性输出细胞状态的信息
        hidden = o_t * self.tanh(cell)  # 形状: [batch_size, hidden_size]
        
        return hidden, cell  # 返回新的隐藏状态和细胞状态

五、关键总结

  1. 四个线性层对应四个门,每个门学习不同的权重矩阵
  2. 三个门控机制(遗忘门、输入门、输出门)通过 sigmoid 函数实现 "开关" 控制
  3. 细胞状态通过选择性遗忘和添加实现长期记忆
  4. 隐藏状态是细胞状态的选择性输出,也是 LSTM 的最终输出
相关推荐
机器学习之心6 小时前
三种深度学习模型(LSTM、CNN-LSTM、贝叶斯优化的CNN-LSTM/BO-CNN-LSTM)对北半球光伏数据进行时间序列预测
深度学习·cnn·lstm·cnn-lstm·贝叶斯优化的cnn-lstm
技术猿188702783516 小时前
实现“micro 关键字搜索全覆盖商品”并通过 API 接口提供实时数据(一个方法)
开发语言·网络·python·深度学习·测试工具
停走的风7 小时前
(李宏毅)deep learning(五)--learning rate
人工智能·深度学习·机器学习
fishjar1007 小时前
LLaMA-Factory安装部署
人工智能·深度学习
feifeikon7 小时前
模型篇(Bert llama deepseek)
人工智能·深度学习·自然语言处理
DogDaoDao9 小时前
2025年 GitHub 主流开源视频生成模型介绍
人工智能·深度学习·开源·大模型·github·音视频·视频生成
大千AI助手9 小时前
OpenAI GPT-4o技术详解:全能多模态模型的架构革新与生态影响
人工智能·深度学习·机器学习·自然语言处理·大模型·openai·gpt-4o
昵称是6硬币9 小时前
(RT-DETR)DETRs Beat YOLOs on Real-time Object Detection论文精读(逐段解析)
图像处理·人工智能·深度学习·目标检测·计算机视觉·transformer
天选之女wow12 小时前
【Gaussian Haircut论文】在Deepseek和Chatgpt的帮助下慢速了解核心方法
人工智能·深度学习·计算机视觉·3d