RNN、LSTM、GRU技术博客

面向初学者的循环神经网络详解,从参数维度到代码实战,一篇文章彻底搞懂!


前言

在深度学习领域,处理序列数据(如文本、语音、时间序列)是一个非常重要的任务。传统的前馈神经网络无法很好地处理序列中的依赖关系,而 ** 循环神经网络(Recurrent Neural Network, RNN)** 及其变体 LSTM、GRU 应运而生,成为处理序列数据的利器。

本文将从最基础的张量维度理解开始,循序渐进地讲解 RNN、LSTM、GRU 的核心原理、网络结构、API 使用和优缺点,配合完整的 PyTorch 代码示例,帮助初学者彻底掌握这三种经典的循环神经网络。


一、RNN 循环神经网络核心知识

1.1 RNN 核心 9 个参数详解

初学者学习 RNN 最大的困惑就是张量维度!下面这 9 个参数是理解 RNN 的核心,一定要记牢:

🔹 输入张量 input 的 3 个维度
Plain 复制代码
input.shape = (seq_len, batch_size, input_size)
  • seq_len:序列长度,即一个样本包含多少个 token(词 / 时间步)

  • batch_size:批次大小,一次喂给模型多少个样本

  • input_size:输入特征维度,即每个 token 的向量维度(如词向量维度)

⚠️ 重要注意 :PyTorch 中默认 batch_first = False,所以 batch_size 在第二个维度!

如果设置 batch_first = True,则形状变为 (batch_size, seq_len, input_size )

🔹 隐藏状态 hidden 的 3 个维度
Plain 复制代码
h0.shape = (num_layers, batch_size, hidden_size)
  • num_layers:隐藏层数,RNN 堆叠的层数

  • batch_size:批次大小,与 input 的 batch_size 保持一致

  • hidden_size:隐藏层特征维度,即每个时间步输出的特征维度

🔹 RNN 层定义的 3 个参数
python 复制代码
torch.nn.RNN(input_size, hidden_size, num_layers)
  • input_size:输入特征维度,必须与 input 张量的最后一维一致

  • hidden_size:隐藏层特征维度,自定义的超参数

  • num_layers:隐藏层数,默认为 1

记忆口诀:9 个参数分三组,每组 3 个,input/hidden/RNN 各占 3 个!


1.2 循环网络结构 4 个示例(理解维度)

下面通过 4 个循序渐进的示例,彻底理解 RNN 的维度变化!

示例 1:最简单的情况

配置:1 层 RNN,1 个样本,1 个词,每个词 6 维向量

  • seq_len = 1(1 个词)

  • batch_size = 1(1 个样本)

  • input_size = 6(每个词 6 维)

  • num_layers = 1(1 层)

  • hidden_size = 8(输出 8 维特征)

Plain 复制代码
input.shape  = (1, 1, 6)
h0.shape     = (1, 1, 8)
output.shape = (1, 1, 8)  # 每个时间步的输出
hn.shape     = (1, 1, 8)  # 最后一个时间步的隐藏状态
示例 2:序列变长

配置:1 层 RNN,1 个样本,2 个词,每个词 6 维向量

  • seq_len = 2(2 个词)

  • batch_size = 1(1 个样本)

  • input_size = 6

Plain 复制代码
input.shape  = (2, 1, 6)
h0.shape     = (1, 1, 8)
output.shape = (2, 1, 8)  # 2个时间步,每个都有输出
hn.shape     = (1, 1, 8)  # 只保留最后一个时间步的隐藏状态

💡 观察:hn 永远等于 output 的最后一个时间步!

示例 3:批次变大

配置:1 层 RNN,2 个样本,每个样本 3 个词,每个词 6 维

  • seq_len = 3(每个样本 3 个词)

  • batch_size = 2(2 个样本并行处理)

  • input_size = 6

Plain 复制代码
input.shape  = (3, 2, 6)
h0.shape     = (1, 2, 8)
output.shape = (3, 2, 8)
hn.shape     = (1, 2, 8)

💡 观察:batch_size 维度只是并行处理,不影响其他维度!

示例 4:多层 RNN

配置:2 层 RNN,3 个样本,每个样本 4 个词,每个词 6 维

  • seq_len = 4

  • batch_size = 3

  • num_layers = 2(2 层堆叠)

  • input_size = 6

Plain 复制代码
input.shape  = (4, 3, 6)
h0.shape     = (2, 3, 8)
output.shape = (4, 3, 8)  # 只返回最后一层的所有时间步输出
hn.shape     = (2, 3, 8)  # 返回每一层最后一个时间步的隐藏状态

维度总结:无论怎么变,9 个核心参数的对应关系永远不变!


1.3 RNN 网络结构详解

🔹 隐藏状态值 ht 的特性

RNN 的核心就是隐藏状态(Hidden State),它相当于网络的记忆:

  1. 短期记忆特性

    • 由于梯度消失问题,RNN 实际上只能有效记住最近 10-20 个时间步的信息

    • 更早的信息会被指数级遗忘,这就是短期记忆的由来

  2. 梯度消失问题

    • 反向传播时,梯度会随着时间步不断相乘

    • 序列越长,梯度越容易趋近于 0,前面的层无法得到有效更新

  3. 信息混杂问题

    • 同一个 ht 向量既要编码当前输入的信息,又要承载历史信息,还要负责输出

    • 当序列变长时,它会不堪重负,容易出现信息丢失和混淆

🔹 两个输入

每个时间步,RNN 接收两个输入:

  1. Xt:当前时间步的输入(当前词的向量)

  2. Ht-1:上一个时间步的隐藏状态(历史记忆)

🔹 两个输出

每个时间步,RNN 产生两个输出:

  1. Ht:当前时间步的隐藏状态(更新后的记忆)

  2. Yt:当前时间步的输出(基于 Ht 计算得到)

📝 简单理解:RNN 就像一个人在读句子,每读一个词(Xt),就结合之前的理解(Ht-1),产生新的理解(Ht),并给出当前的判断(Yt)。


1.4 核心 API

python 复制代码
torch.nn.RNN(
    input_size,       # 输入特征维度
    hidden_size,      # 隐藏层特征维度
    num_layers=1,     # 隐藏层数
    bias=True,        # 是否使用偏置
    batch_first=False,# 是否将batch放在第一维
    dropout=0.0,      # dropout比率
    bidirectional=False, # 是否双向RNN
    proj_size=0       # 投影大小(高级用法)
)

1.5 RNN 优缺点

✅ 优点
  1. 结构简单轻量化:内部结构简单,参数共享,计算逻辑直观

  2. 训练速度快:需要资源少,计算量小,训练速度快

  3. 处理时序数据:能接收任意长度序列,适配文本、语音、时间序列

  4. 记忆上下文:隐藏态传递历史信息,能捕捉前后依赖关系

  5. 输入输出灵活:支持一对一、多对一、多对多等多种任务模式

❌ 缺点
  1. 无法处理长序列:存在严重的梯度消失和梯度爆炸问题

  2. 长距离依赖失效:无法学习远距离的关联,只能记住近期信息

  3. 串行计算:无法并行处理序列,训练推理速度慢

  4. 短期记忆局限:久远信息容易丢失,记忆容量有限

  5. 对长序列拟合差:复杂长时序任务效果弱

  6. 梯度优化难度大:训练不稳定,调参成本高


二、LSTM 长短时记忆网络

为了解决 RNN 的短期记忆问题,长短期记忆网络(Long Short-Term Memory, LSTM)在 1997 年被提出。LSTM 通过引入门控机制细胞状态,成功解决了梯度消失问题,能够有效处理长序列。

2.1 LSTM 网络结构

LSTM 的核心创新是三门控机制细胞状态(Cell State),相当于给网络增加了长期记忆。

🔹 三个门控机制

LSTM 有三个门,用来控制信息的流动:

  1. 遗忘门(Forget Gate)

    • 作用:决定从细胞状态中丢弃哪些信息

    • 例如:读到新的主语时,忘记旧的主语

    • 输出范围:0~1,0 表示完全遗忘,1 表示完全保留

  2. 输入门(Input Gate)

    • 作用:选择性地记忆当前输入中的重要内容

    • 分为两步:

      • sigmoid 层决定哪些信息需要更新

      • tanh 层生成候选值,准备加入细胞状态

    • 例如:新主语的性别、单复数等重要信息

  3. 输出门(Output Gate)

    • 作用:基于当前细胞状态,过滤并输出相关信息

    • 决定隐藏状态 ht 应该输出什么

    • 例如:根据主语决定动词的形式

🔹 细胞状态 Ct

细胞状态是 LSTM 的灵魂!

  • 它像一条 信息高速公路,在序列中直线传递重要的上下文信息

  • 只有少量的线性交互,几乎不受干扰,保证了长期记忆

  • 通过门控机制来精细控制:哪些该记住,哪些该遗忘

🔹 三个输入

每个时间步,LSTM 接收三个输入:

  1. Xt:当前时间步的输入

  2. Ht-1:上一个时间步的隐藏状态(短期记忆)

  3. Ct-1:上一个时间步的细胞状态(长期记忆)

🔹 三个输出

每个时间步,LSTM 产生三个输出:

  1. Ht:当前时间步的隐藏状态(更新后的短期记忆)

  2. Yt:当前时间步的输出

  3. Ct:当前时间步的细胞状态(更新后的长期记忆)

📝 形象理解:LSTM 就像一个有笔记本的人,Ct 是笔记本(长期记忆),ht 是工作记忆(短期记忆)。遗忘门就是擦除笔记本上没用的内容,输入门就是写下新的重要信息,输出门就是根据笔记本和当前思考给出回答。


2.2 核心 API

python 复制代码
torch.nn.LSTM(
    input_size,       # 输入特征维度
    hidden_size,      # 隐藏层特征维度
    num_layers=1,     # 隐藏层数
    bias=True,        # 是否使用偏置
    batch_first=False,# 是否将batch放在第一维
    dropout=0.0,      # dropout比率
    bidirectional=False, # 是否双向
    proj_size=0       # 投影大小
)

⚠️ 注意:LSTM 的输入输出与 RNN 略有不同:

  • 输入需要同时提供 h0 和 c0

  • 输出返回 (output, (hn, cn))


2.3 LSTM 优缺点

✅ 优点
  1. 处理长序列能力强:有效缓解梯度消失,能处理上百个时间步的序列

  2. 解决长距离依赖:能够学习序列中远距离的关联

  3. 可自主控制记忆 / 遗忘:三门控机制实现精细的信息管理

  4. 擅长处理时序数据:在文本、语音、时间序列上表现优秀

  5. 表达能力更强:比 GRU 功能更丰富,建模能力更高

❌ 缺点
  1. 结构复杂:三个门 + 细胞状态,参数量大

  2. 训练速度慢:计算量大,训练时间长

  3. 仍然串行计算:无法并行处理序列

  4. 超参数多:调参更麻烦,需要更多经验

  5. 对极长序列仍有压力:虽然比 RNN 强,但极长序列仍有挑战

  6. 可解释性差:门控机制的内部工作难以解释


三、GRU 门控循环单元

** 门控循环单元(Gated Recurrent Unit, GRU)** 是 LSTM 的简化版本,2014 年提出。它将 LSTM 的三个门简化为两个门,去掉了细胞状态,在保持性能的同时大幅简化了结构。

3.1 GRU 网络结构

GRU 将 LSTM 的遗忘门和输入门合并为更新门 ,新增重置门,去掉了独立的细胞状态。

🔹 两个门控机制
  1. 更新门(Update Gate)

    • 作用:控制前一时刻的状态信息有多少被保留到当前时刻

    • 相当于 LSTM 中遗忘门 + 输入门的组合

    • 值越大,保留的历史信息越多

  2. 重置门(Reset Gate)

    • 作用:控制忽略前一时刻状态信息的程度

    • 值越小,忽略的历史信息越多

    • 用于丢弃与未来无关的信息

🔹 输入输出结构

GRU 没有独立的细胞状态,隐藏状态 ht 同时承担短期和长期记忆:

两个输入

  1. Xt:当前时间步的输入

  2. Ht-1:上一个时间步的隐藏状态

两个输出

  1. Ht:当前时间步的隐藏状态

  2. Yt:当前时间步的输出

📝 GRU 的设计哲学:在 LSTM 的基础上做减法,用更少的参数达到相近的效果。实践中,GRU 和 LSTM 性能差异不大,但 GRU 训练更快。


3.2 核心 API

python 复制代码
torch.nn.GRU(
    input_size,       # 输入特征维度
    hidden_size,      # 隐藏层特征维度
    num_layers=1,     # 隐藏层数
    bias=True,        # 是否使用偏置
    batch_first=False,# 是否将batch放在第一维
    dropout=0.0,      # dropout比率
    bidirectional=False # 是否双向
)

⚠️ 注意:GRU 的输入输出格式和 RNN 完全一样!

  • 只需要提供 h0,不需要 c0

  • 返回 (output, hn),和 RNN 一致


3.3 GRU 优缺点

✅ 优点
  1. 结构简单:比 LSTM 少一个门,参数量减少约 1/3

  2. 训练速度快:计算量小,收敛更快

  3. 调参更容易:超参数少,更容易训练

  4. 缓解梯度消失:同样能处理较长序列

  5. 内存占用小:模型更小,部署更友好

❌ 缺点
  1. 表达能力略弱:建模能力比 LSTM 稍差

  2. 复杂任务表现稍逊:在非常复杂的序列任务上可能不如 LSTM

  3. 仍然串行计算:无法并行

  4. 极长序列仍有限制:和 LSTM 一样,极长序列仍有挑战


四、完整代码示例

4.1 RNN 完整代码(4 个示例)

python 复制代码
import torch

# ============================================================
# 核心知识点 -> RNN 9大参数
# input:    (seq_len样本序列token数, batch_size批次大小, input_size输入特征维度)
# h0:       (num_layers隐藏层数, batch_size批次大小, hidden_size输出特征维度)
# RNN层:    (input_size输入特征维度, hidden_size输出特征维度, num_layers隐藏层数)
# ============================================================

def demo1_run_api():
    """
    示例1:1层RNN,1个样本,每个样本1个词,每个词4维词向量
    """
    # input: (seq_len, batch_size, input_size) = (1, 1, 4)
    input = torch.randn(1, 1, 4)
    
    # h0: (num_layers, batch_size, hidden_size) = (1, 1, 8)
    h0 = torch.zeros(1, 1, 8)
    
    # RNN层: (input_size, hidden_size, num_layers) = (4, 8, 1)
    my_rnn = torch.nn.RNN(4, 8, 1)
    
    # 前向传播
    output, hn = my_rnn(input, h0)
    
    print("="*50)
    print("示例1:1层RNN,1个样本,1个词")
    print(f"input形状:  {input.shape}")
    print(f"h0形状:     {h0.shape}")
    print(f"output形状: {output.shape}")  # (1, 1, 8) - 1个时间步的输出
    print(f"hn形状:     {hn.shape}")      # (1, 1, 8) - 最后一个时间步的隐藏状态
    print(f"hn == output[-1]: {torch.allclose(hn, output[-1])}")  # True!


def demo2_run_api():
    """
    示例2:1层RNN,1个样本,每个样本2个词,每个词4维词向量
    """
    # input: (2, 1, 4) - 2个时间步
    input = torch.randn(2, 1, 4)
    h0 = torch.zeros(1, 1, 8)
    my_rnn = torch.nn.RNN(4, 8, 1)
    output, hn = my_rnn(input, h0)
    
    print("="*50)
    print("示例2:1层RNN,1个样本,2个词")
    print(f"input形状:  {input.shape}")
    print(f"output形状: {output.shape}")  # (2, 1, 8) - 2个时间步都有输出
    print(f"hn形状:     {hn.shape}")      # (1, 1, 8) - 只保留最后一个时间步
    print(f"hn == output[-1]: {torch.allclose(hn, output[-1])}")  # True!


def demo3_run_api():
    """
    示例3:1层RNN,2个样本,每个样本3个词,每个词4维词向量
    """
    # input: (3, 2, 4) - 3个时间步,2个样本并行
    input = torch.randn(3, 2, 4)
    h0 = torch.zeros(1, 2, 8)
    my_rnn = torch.nn.RNN(4, 8, 1)
    output, hn = my_rnn(input, h0)
    
    print("="*50)
    print("示例3:1层RNN,2个样本,3个词")
    print(f"input形状:  {input.shape}")
    print(f"output形状: {output.shape}")  # (3, 2, 8)
    print(f"hn形状:     {hn.shape}")      # (1, 2, 8)
    print(f"hn == output[-1]: {torch.allclose(hn, output[-1])}")  # True!


def demo4_run_api():
    """
    示例4:3层RNN,3个样本,每个样本4个词,每个词4维词向量
    """
    # input: (4, 3, 4)
    input = torch.randn(4, 3, 4)
    h0 = torch.zeros(3, 3, 8)  # 3层!
    my_rnn = torch.nn.RNN(4, 8, 3)
    output, hn = my_rnn(input, h0)
    
    print("="*50)
    print("示例4:3层RNN,3个样本,4个词")
    print(f"input形状:  {input.shape}")
    print(f"output形状: {output.shape}")  # (4, 3, 8) - 只返回最后一层的输出
    print(f"hn形状:     {hn.shape}")      # (3, 3, 8) - 返回每一层的最后一个时间步
    print(f"hn[-1] == output[-1]: {torch.allclose(hn[-1], output[-1])}")  # True!


if __name__ == '__main__':
    demo1_run_api()
    demo2_run_api()
    demo3_run_api()
    demo4_run_api()

运行结果说明

  • 所有示例中,hn 永远等于 output 的最后一个时间步!

  • 多层 RNN 时,hn[-1] 等于 output[-1]

  • batch 维度只是并行处理,不影响逻辑


4.2 LSTM 完整代码(4 个示例)

python 复制代码
import torch

# ============================================================
# 核心知识点 -> LSTM 与 RNN的区别:多了细胞状态c!
# 输入需要 (h0, c0),输出返回 (output, (hn, cn))
# ============================================================

def lstm_demo1():
    """
    示例1:1层LSTM,1个样本,1个词,每个词4维
    """
    input = torch.randn(1, 1, 4)
    h0 = torch.zeros(1, 1, 8)
    c0 = torch.zeros(1, 1, 8)  # LSTM多了c0!
    
    my_lstm = torch.nn.LSTM(4, 8, 1)
    output, (hn, cn) = my_lstm(input, (h0, c0))
    
    print("="*50)
    print("LSTM示例1:1层,1个样本,1个词")
    print(f"input形状:  {input.shape}")
    print(f"h0形状:     {h0.shape}")
    print(f"c0形状:     {c0.shape}")
    print(f"output形状: {output.shape}")
    print(f"hn形状:     {hn.shape}")
    print(f"cn形状:     {cn.shape}")


def lstm_demo2():
    """
    示例2:1层LSTM,1个样本,2个词,每个词4维
    """
    input = torch.randn(2, 1, 4)
    h0 = torch.zeros(1, 1, 8)
    c0 = torch.zeros(1, 1, 8)
    
    my_lstm = torch.nn.LSTM(4, 8, 1)
    output, (hn, cn) = my_lstm(input, (h0, c0))
    
    print("="*50)
    print("LSTM示例2:1层,1个样本,2个词")
    print(f"input形状:  {input.shape}")
    print(f"output形状: {output.shape}")  # (2, 1, 8)
    print(f"hn形状:     {hn.shape}")      # (1, 1, 8)
    print(f"cn形状:     {cn.shape}")      # (1, 1, 8)


def lstm_demo3():
    """
    示例3:1层LSTM,2个样本,3个词,每个词4维
    """
    input = torch.randn(3, 2, 4)
    h0 = torch.zeros(1, 2, 8)
    c0 = torch.zeros(1, 2, 8)
    
    my_lstm = torch.nn.LSTM(4, 8, 1)
    output, (hn, cn) = my_lstm(input, (h0, c0))
    
    print("="*50)
    print("LSTM示例3:1层,2个样本,3个词")
    print(f"input形状:  {input.shape}")
    print(f"output形状: {output.shape}")  # (3, 2, 8)
    print(f"hn形状:     {hn.shape}")      # (1, 2, 8)
    print(f"cn形状:     {cn.shape}")      # (1, 2, 8)


def lstm_demo4():
    """
    示例4:3层LSTM,3个样本,4个词,每个词4维
    """
    input = torch.randn(4, 3, 4)
    h0 = torch.zeros(3, 3, 8)
    c0 = torch.zeros(3, 3, 8)
    
    my_lstm = torch.nn.LSTM(4, 8, 3)
    output, (hn, cn) = my_lstm(input, (h0, c0))
    
    print("="*50)
    print("LSTM示例4:3层,3个样本,4个词")
    print(f"input形状:  {input.shape}")
    print(f"output形状: {output.shape}")  # (4, 3, 8)
    print(f"hn形状:     {hn.shape}")      # (3, 3, 8)
    print(f"cn形状:     {cn.shape}")      # (3, 3, 8)


if __name__ == '__main__':
    lstm_demo1()
    lstm_demo2()
    lstm_demo3()
    lstm_demo4()

LSTM 与 RNN 的关键区别

  1. 初始化时需要同时提供 h0c0,形状完全相同

  2. 前向传播返回 (output, (hn, cn)),多了 cn

  3. 其他维度规则与 RNN 完全一致!


五、三者对比总结

对比表格

对比维度 RNN LSTM GRU
门控数量 无门控 3 个门(遗忘、输入、输出) 2 个门(更新、重置)
细胞状态 有(独立的长期记忆) 无(合并到隐藏状态)
参数量 最少 最多 中等(约为 LSTM 的 2/3)
训练速度 最快 最慢 中等
长序列能力 差(<20 步) 好(~100 步) 较好
梯度消失 严重 大幅缓解 缓解
表达能力 最弱 最强 较强
调参难度 简单 复杂 中等
内存占用 最小 最大 中等

适用场景选择

🎯 什么时候用 RNN?
  • 序列非常短(<20 个时间步)

  • 对速度要求极高

  • 模型需要极度轻量化

  • 简单的序列任务

🎯 什么时候用 LSTM?
  • 序列较长,需要捕捉长距离依赖

  • 任务复杂,需要更强的建模能力

  • 对性能要求高于速度

  • 机器翻译、文本生成、语音识别等复杂任务

🎯 什么时候用 GRU?
  • 想要比 RNN 强,又不想像 LSTM 那么慢

  • 数据集相对较小,防止过拟合

  • 需要快速迭代实验

  • 大多数情况下是 LSTM 的优秀替代品

初学者学习建议

  1. 先搞懂维度:9 个参数是基础,维度搞不懂一切都是空谈

  2. 从 RNN 开始:先理解简单的 RNN,再学 LSTM 和 GRU

  3. 动手跑代码:一定要亲手运行本文的代码,观察输出形状

  4. 理解直觉:用记忆的直觉理解门控机制,不要一开始就抠数学公式

  5. 实践出真知:在真实数据集上训练,感受三者的差异


结语

RNN、LSTM、GRU 是深度学习处理序列数据的基石。虽然现在 Transformer 大行其道,但循环神经网络的思想仍然非常重要,是理解序列建模的必经之路。

希望这篇文章能帮助你彻底搞懂这三种网络!记住:维度是基础,直觉是关键,代码是验证

祝你学习愉快!🚀

相关推荐
MediaTea10 小时前
DL:循环神经网络的基本原理与 PyTorch 实现
人工智能·pytorch·rnn·深度学习·神经网络
udc小白10 小时前
Excel实现LSTM示例
人工智能·深度学习·神经网络·机器学习·excel·lstm
ZHW_AI课题组1 天前
基于LSTM的天气预测
人工智能·rnn·lstm
啦啦啦_99991 天前
RNN 入门
人工智能·rnn·深度学习
风落无尘2 天前
第九章《语言与理解》 完整学习资料
gpt·rnn·语言模型·transformer
初心未改HD2 天前
深度学习之LSTM与GRU门控循环单元详解
深度学习·gru·lstm
初心未改HD2 天前
深度学习之RNN循环神经网络详解
人工智能·rnn·深度学习
Yunzenn3 天前
深度分析字节最新研究cola-DLM第 01 章:语言生成的三次范式之争 —— 从 RNN 到 AR 到扩散
linux·人工智能·rnn·深度学习·机器学习·架构·transformer
郑同学zxc3 天前
机器学习20-RNN
人工智能·rnn·机器学习