面向初学者的循环神经网络详解,从参数维度到代码实战,一篇文章彻底搞懂!
前言
在深度学习领域,处理序列数据(如文本、语音、时间序列)是一个非常重要的任务。传统的前馈神经网络无法很好地处理序列中的依赖关系,而 ** 循环神经网络(Recurrent Neural Network, RNN)** 及其变体 LSTM、GRU 应运而生,成为处理序列数据的利器。
本文将从最基础的张量维度理解开始,循序渐进地讲解 RNN、LSTM、GRU 的核心原理、网络结构、API 使用和优缺点,配合完整的 PyTorch 代码示例,帮助初学者彻底掌握这三种经典的循环神经网络。
一、RNN 循环神经网络核心知识
1.1 RNN 核心 9 个参数详解
初学者学习 RNN 最大的困惑就是张量维度!下面这 9 个参数是理解 RNN 的核心,一定要记牢:
🔹 输入张量 input 的 3 个维度
Plain
input.shape = (seq_len, batch_size, input_size)
-
seq_len:序列长度,即一个样本包含多少个 token(词 / 时间步)
-
batch_size:批次大小,一次喂给模型多少个样本
-
input_size:输入特征维度,即每个 token 的向量维度(如词向量维度)
⚠️ 重要注意 :PyTorch 中默认
batch_first = False,所以 batch_size 在第二个维度!如果设置
batch_first = True,则形状变为(batch_size, seq_len, input_size )
🔹 隐藏状态 hidden 的 3 个维度
Plain
h0.shape = (num_layers, batch_size, hidden_size)
-
num_layers:隐藏层数,RNN 堆叠的层数
-
batch_size:批次大小,与 input 的 batch_size 保持一致
-
hidden_size:隐藏层特征维度,即每个时间步输出的特征维度
🔹 RNN 层定义的 3 个参数
python
torch.nn.RNN(input_size, hidden_size, num_layers)
-
input_size:输入特征维度,必须与 input 张量的最后一维一致
-
hidden_size:隐藏层特征维度,自定义的超参数
-
num_layers:隐藏层数,默认为 1
✅ 记忆口诀:9 个参数分三组,每组 3 个,input/hidden/RNN 各占 3 个!
1.2 循环网络结构 4 个示例(理解维度)
下面通过 4 个循序渐进的示例,彻底理解 RNN 的维度变化!
示例 1:最简单的情况
配置:1 层 RNN,1 个样本,1 个词,每个词 6 维向量
-
seq_len = 1(1 个词)
-
batch_size = 1(1 个样本)
-
input_size = 6(每个词 6 维)
-
num_layers = 1(1 层)
-
hidden_size = 8(输出 8 维特征)
Plain
input.shape = (1, 1, 6)
h0.shape = (1, 1, 8)
output.shape = (1, 1, 8) # 每个时间步的输出
hn.shape = (1, 1, 8) # 最后一个时间步的隐藏状态
示例 2:序列变长
配置:1 层 RNN,1 个样本,2 个词,每个词 6 维向量
-
seq_len = 2(2 个词)
-
batch_size = 1(1 个样本)
-
input_size = 6
Plain
input.shape = (2, 1, 6)
h0.shape = (1, 1, 8)
output.shape = (2, 1, 8) # 2个时间步,每个都有输出
hn.shape = (1, 1, 8) # 只保留最后一个时间步的隐藏状态
💡 观察:hn 永远等于 output 的最后一个时间步!
示例 3:批次变大
配置:1 层 RNN,2 个样本,每个样本 3 个词,每个词 6 维
-
seq_len = 3(每个样本 3 个词)
-
batch_size = 2(2 个样本并行处理)
-
input_size = 6
Plain
input.shape = (3, 2, 6)
h0.shape = (1, 2, 8)
output.shape = (3, 2, 8)
hn.shape = (1, 2, 8)
💡 观察:batch_size 维度只是并行处理,不影响其他维度!
示例 4:多层 RNN
配置:2 层 RNN,3 个样本,每个样本 4 个词,每个词 6 维
-
seq_len = 4
-
batch_size = 3
-
num_layers = 2(2 层堆叠)
-
input_size = 6
Plain
input.shape = (4, 3, 6)
h0.shape = (2, 3, 8)
output.shape = (4, 3, 8) # 只返回最后一层的所有时间步输出
hn.shape = (2, 3, 8) # 返回每一层最后一个时间步的隐藏状态
✅ 维度总结:无论怎么变,9 个核心参数的对应关系永远不变!
1.3 RNN 网络结构详解
🔹 隐藏状态值 ht 的特性
RNN 的核心就是隐藏状态(Hidden State),它相当于网络的记忆:
-
短期记忆特性
-
由于梯度消失问题,RNN 实际上只能有效记住最近 10-20 个时间步的信息
-
更早的信息会被指数级遗忘,这就是短期记忆的由来
-
-
梯度消失问题
-
反向传播时,梯度会随着时间步不断相乘
-
序列越长,梯度越容易趋近于 0,前面的层无法得到有效更新
-
-
信息混杂问题
-
同一个 ht 向量既要编码当前输入的信息,又要承载历史信息,还要负责输出
-
当序列变长时,它会不堪重负,容易出现信息丢失和混淆
-
🔹 两个输入
每个时间步,RNN 接收两个输入:
-
Xt:当前时间步的输入(当前词的向量)
-
Ht-1:上一个时间步的隐藏状态(历史记忆)
🔹 两个输出
每个时间步,RNN 产生两个输出:
-
Ht:当前时间步的隐藏状态(更新后的记忆)
-
Yt:当前时间步的输出(基于 Ht 计算得到)
📝 简单理解:RNN 就像一个人在读句子,每读一个词(Xt),就结合之前的理解(Ht-1),产生新的理解(Ht),并给出当前的判断(Yt)。
1.4 核心 API
python
torch.nn.RNN(
input_size, # 输入特征维度
hidden_size, # 隐藏层特征维度
num_layers=1, # 隐藏层数
bias=True, # 是否使用偏置
batch_first=False,# 是否将batch放在第一维
dropout=0.0, # dropout比率
bidirectional=False, # 是否双向RNN
proj_size=0 # 投影大小(高级用法)
)
1.5 RNN 优缺点
✅ 优点
-
结构简单轻量化:内部结构简单,参数共享,计算逻辑直观
-
训练速度快:需要资源少,计算量小,训练速度快
-
处理时序数据:能接收任意长度序列,适配文本、语音、时间序列
-
记忆上下文:隐藏态传递历史信息,能捕捉前后依赖关系
-
输入输出灵活:支持一对一、多对一、多对多等多种任务模式
❌ 缺点
-
无法处理长序列:存在严重的梯度消失和梯度爆炸问题
-
长距离依赖失效:无法学习远距离的关联,只能记住近期信息
-
串行计算:无法并行处理序列,训练推理速度慢
-
短期记忆局限:久远信息容易丢失,记忆容量有限
-
对长序列拟合差:复杂长时序任务效果弱
-
梯度优化难度大:训练不稳定,调参成本高
二、LSTM 长短时记忆网络
为了解决 RNN 的短期记忆问题,长短期记忆网络(Long Short-Term Memory, LSTM)在 1997 年被提出。LSTM 通过引入门控机制 和细胞状态,成功解决了梯度消失问题,能够有效处理长序列。
2.1 LSTM 网络结构
LSTM 的核心创新是三门控机制 和细胞状态(Cell State),相当于给网络增加了长期记忆。
🔹 三个门控机制
LSTM 有三个门,用来控制信息的流动:
-
遗忘门(Forget Gate)
-
作用:决定从细胞状态中丢弃哪些信息
-
例如:读到新的主语时,忘记旧的主语
-
输出范围:0~1,0 表示完全遗忘,1 表示完全保留
-
-
输入门(Input Gate)
-
作用:选择性地记忆当前输入中的重要内容
-
分为两步:
-
sigmoid 层决定哪些信息需要更新
-
tanh 层生成候选值,准备加入细胞状态
-
-
例如:新主语的性别、单复数等重要信息
-
-
输出门(Output Gate)
-
作用:基于当前细胞状态,过滤并输出相关信息
-
决定隐藏状态 ht 应该输出什么
-
例如:根据主语决定动词的形式
-
🔹 细胞状态 Ct
细胞状态是 LSTM 的灵魂!
-
它像一条 信息高速公路,在序列中直线传递重要的上下文信息
-
只有少量的线性交互,几乎不受干扰,保证了长期记忆
-
通过门控机制来精细控制:哪些该记住,哪些该遗忘
🔹 三个输入
每个时间步,LSTM 接收三个输入:
-
Xt:当前时间步的输入
-
Ht-1:上一个时间步的隐藏状态(短期记忆)
-
Ct-1:上一个时间步的细胞状态(长期记忆)
🔹 三个输出
每个时间步,LSTM 产生三个输出:
-
Ht:当前时间步的隐藏状态(更新后的短期记忆)
-
Yt:当前时间步的输出
-
Ct:当前时间步的细胞状态(更新后的长期记忆)
📝 形象理解:LSTM 就像一个有笔记本的人,Ct 是笔记本(长期记忆),ht 是工作记忆(短期记忆)。遗忘门就是擦除笔记本上没用的内容,输入门就是写下新的重要信息,输出门就是根据笔记本和当前思考给出回答。
2.2 核心 API
python
torch.nn.LSTM(
input_size, # 输入特征维度
hidden_size, # 隐藏层特征维度
num_layers=1, # 隐藏层数
bias=True, # 是否使用偏置
batch_first=False,# 是否将batch放在第一维
dropout=0.0, # dropout比率
bidirectional=False, # 是否双向
proj_size=0 # 投影大小
)
⚠️ 注意:LSTM 的输入输出与 RNN 略有不同:
输入需要同时提供 h0 和 c0
输出返回 (output, (hn, cn))
2.3 LSTM 优缺点
✅ 优点
-
处理长序列能力强:有效缓解梯度消失,能处理上百个时间步的序列
-
解决长距离依赖:能够学习序列中远距离的关联
-
可自主控制记忆 / 遗忘:三门控机制实现精细的信息管理
-
擅长处理时序数据:在文本、语音、时间序列上表现优秀
-
表达能力更强:比 GRU 功能更丰富,建模能力更高
❌ 缺点
-
结构复杂:三个门 + 细胞状态,参数量大
-
训练速度慢:计算量大,训练时间长
-
仍然串行计算:无法并行处理序列
-
超参数多:调参更麻烦,需要更多经验
-
对极长序列仍有压力:虽然比 RNN 强,但极长序列仍有挑战
-
可解释性差:门控机制的内部工作难以解释
三、GRU 门控循环单元
** 门控循环单元(Gated Recurrent Unit, GRU)** 是 LSTM 的简化版本,2014 年提出。它将 LSTM 的三个门简化为两个门,去掉了细胞状态,在保持性能的同时大幅简化了结构。
3.1 GRU 网络结构
GRU 将 LSTM 的遗忘门和输入门合并为更新门 ,新增重置门,去掉了独立的细胞状态。
🔹 两个门控机制
-
更新门(Update Gate)
-
作用:控制前一时刻的状态信息有多少被保留到当前时刻
-
相当于 LSTM 中遗忘门 + 输入门的组合
-
值越大,保留的历史信息越多
-
-
重置门(Reset Gate)
-
作用:控制忽略前一时刻状态信息的程度
-
值越小,忽略的历史信息越多
-
用于丢弃与未来无关的信息
-
🔹 输入输出结构
GRU 没有独立的细胞状态,隐藏状态 ht 同时承担短期和长期记忆:
两个输入:
-
Xt:当前时间步的输入
-
Ht-1:上一个时间步的隐藏状态
两个输出:
-
Ht:当前时间步的隐藏状态
-
Yt:当前时间步的输出
📝 GRU 的设计哲学:在 LSTM 的基础上做减法,用更少的参数达到相近的效果。实践中,GRU 和 LSTM 性能差异不大,但 GRU 训练更快。
3.2 核心 API
python
torch.nn.GRU(
input_size, # 输入特征维度
hidden_size, # 隐藏层特征维度
num_layers=1, # 隐藏层数
bias=True, # 是否使用偏置
batch_first=False,# 是否将batch放在第一维
dropout=0.0, # dropout比率
bidirectional=False # 是否双向
)
⚠️ 注意:GRU 的输入输出格式和 RNN 完全一样!
只需要提供 h0,不需要 c0
返回 (output, hn),和 RNN 一致
3.3 GRU 优缺点
✅ 优点
-
结构简单:比 LSTM 少一个门,参数量减少约 1/3
-
训练速度快:计算量小,收敛更快
-
调参更容易:超参数少,更容易训练
-
缓解梯度消失:同样能处理较长序列
-
内存占用小:模型更小,部署更友好
❌ 缺点
-
表达能力略弱:建模能力比 LSTM 稍差
-
复杂任务表现稍逊:在非常复杂的序列任务上可能不如 LSTM
-
仍然串行计算:无法并行
-
极长序列仍有限制:和 LSTM 一样,极长序列仍有挑战
四、完整代码示例
4.1 RNN 完整代码(4 个示例)
python
import torch
# ============================================================
# 核心知识点 -> RNN 9大参数
# input: (seq_len样本序列token数, batch_size批次大小, input_size输入特征维度)
# h0: (num_layers隐藏层数, batch_size批次大小, hidden_size输出特征维度)
# RNN层: (input_size输入特征维度, hidden_size输出特征维度, num_layers隐藏层数)
# ============================================================
def demo1_run_api():
"""
示例1:1层RNN,1个样本,每个样本1个词,每个词4维词向量
"""
# input: (seq_len, batch_size, input_size) = (1, 1, 4)
input = torch.randn(1, 1, 4)
# h0: (num_layers, batch_size, hidden_size) = (1, 1, 8)
h0 = torch.zeros(1, 1, 8)
# RNN层: (input_size, hidden_size, num_layers) = (4, 8, 1)
my_rnn = torch.nn.RNN(4, 8, 1)
# 前向传播
output, hn = my_rnn(input, h0)
print("="*50)
print("示例1:1层RNN,1个样本,1个词")
print(f"input形状: {input.shape}")
print(f"h0形状: {h0.shape}")
print(f"output形状: {output.shape}") # (1, 1, 8) - 1个时间步的输出
print(f"hn形状: {hn.shape}") # (1, 1, 8) - 最后一个时间步的隐藏状态
print(f"hn == output[-1]: {torch.allclose(hn, output[-1])}") # True!
def demo2_run_api():
"""
示例2:1层RNN,1个样本,每个样本2个词,每个词4维词向量
"""
# input: (2, 1, 4) - 2个时间步
input = torch.randn(2, 1, 4)
h0 = torch.zeros(1, 1, 8)
my_rnn = torch.nn.RNN(4, 8, 1)
output, hn = my_rnn(input, h0)
print("="*50)
print("示例2:1层RNN,1个样本,2个词")
print(f"input形状: {input.shape}")
print(f"output形状: {output.shape}") # (2, 1, 8) - 2个时间步都有输出
print(f"hn形状: {hn.shape}") # (1, 1, 8) - 只保留最后一个时间步
print(f"hn == output[-1]: {torch.allclose(hn, output[-1])}") # True!
def demo3_run_api():
"""
示例3:1层RNN,2个样本,每个样本3个词,每个词4维词向量
"""
# input: (3, 2, 4) - 3个时间步,2个样本并行
input = torch.randn(3, 2, 4)
h0 = torch.zeros(1, 2, 8)
my_rnn = torch.nn.RNN(4, 8, 1)
output, hn = my_rnn(input, h0)
print("="*50)
print("示例3:1层RNN,2个样本,3个词")
print(f"input形状: {input.shape}")
print(f"output形状: {output.shape}") # (3, 2, 8)
print(f"hn形状: {hn.shape}") # (1, 2, 8)
print(f"hn == output[-1]: {torch.allclose(hn, output[-1])}") # True!
def demo4_run_api():
"""
示例4:3层RNN,3个样本,每个样本4个词,每个词4维词向量
"""
# input: (4, 3, 4)
input = torch.randn(4, 3, 4)
h0 = torch.zeros(3, 3, 8) # 3层!
my_rnn = torch.nn.RNN(4, 8, 3)
output, hn = my_rnn(input, h0)
print("="*50)
print("示例4:3层RNN,3个样本,4个词")
print(f"input形状: {input.shape}")
print(f"output形状: {output.shape}") # (4, 3, 8) - 只返回最后一层的输出
print(f"hn形状: {hn.shape}") # (3, 3, 8) - 返回每一层的最后一个时间步
print(f"hn[-1] == output[-1]: {torch.allclose(hn[-1], output[-1])}") # True!
if __name__ == '__main__':
demo1_run_api()
demo2_run_api()
demo3_run_api()
demo4_run_api()
运行结果说明:
-
所有示例中,
hn永远等于output的最后一个时间步! -
多层 RNN 时,
hn[-1]等于output[-1] -
batch 维度只是并行处理,不影响逻辑
4.2 LSTM 完整代码(4 个示例)
python
import torch
# ============================================================
# 核心知识点 -> LSTM 与 RNN的区别:多了细胞状态c!
# 输入需要 (h0, c0),输出返回 (output, (hn, cn))
# ============================================================
def lstm_demo1():
"""
示例1:1层LSTM,1个样本,1个词,每个词4维
"""
input = torch.randn(1, 1, 4)
h0 = torch.zeros(1, 1, 8)
c0 = torch.zeros(1, 1, 8) # LSTM多了c0!
my_lstm = torch.nn.LSTM(4, 8, 1)
output, (hn, cn) = my_lstm(input, (h0, c0))
print("="*50)
print("LSTM示例1:1层,1个样本,1个词")
print(f"input形状: {input.shape}")
print(f"h0形状: {h0.shape}")
print(f"c0形状: {c0.shape}")
print(f"output形状: {output.shape}")
print(f"hn形状: {hn.shape}")
print(f"cn形状: {cn.shape}")
def lstm_demo2():
"""
示例2:1层LSTM,1个样本,2个词,每个词4维
"""
input = torch.randn(2, 1, 4)
h0 = torch.zeros(1, 1, 8)
c0 = torch.zeros(1, 1, 8)
my_lstm = torch.nn.LSTM(4, 8, 1)
output, (hn, cn) = my_lstm(input, (h0, c0))
print("="*50)
print("LSTM示例2:1层,1个样本,2个词")
print(f"input形状: {input.shape}")
print(f"output形状: {output.shape}") # (2, 1, 8)
print(f"hn形状: {hn.shape}") # (1, 1, 8)
print(f"cn形状: {cn.shape}") # (1, 1, 8)
def lstm_demo3():
"""
示例3:1层LSTM,2个样本,3个词,每个词4维
"""
input = torch.randn(3, 2, 4)
h0 = torch.zeros(1, 2, 8)
c0 = torch.zeros(1, 2, 8)
my_lstm = torch.nn.LSTM(4, 8, 1)
output, (hn, cn) = my_lstm(input, (h0, c0))
print("="*50)
print("LSTM示例3:1层,2个样本,3个词")
print(f"input形状: {input.shape}")
print(f"output形状: {output.shape}") # (3, 2, 8)
print(f"hn形状: {hn.shape}") # (1, 2, 8)
print(f"cn形状: {cn.shape}") # (1, 2, 8)
def lstm_demo4():
"""
示例4:3层LSTM,3个样本,4个词,每个词4维
"""
input = torch.randn(4, 3, 4)
h0 = torch.zeros(3, 3, 8)
c0 = torch.zeros(3, 3, 8)
my_lstm = torch.nn.LSTM(4, 8, 3)
output, (hn, cn) = my_lstm(input, (h0, c0))
print("="*50)
print("LSTM示例4:3层,3个样本,4个词")
print(f"input形状: {input.shape}")
print(f"output形状: {output.shape}") # (4, 3, 8)
print(f"hn形状: {hn.shape}") # (3, 3, 8)
print(f"cn形状: {cn.shape}") # (3, 3, 8)
if __name__ == '__main__':
lstm_demo1()
lstm_demo2()
lstm_demo3()
lstm_demo4()
LSTM 与 RNN 的关键区别:
-
初始化时需要同时提供
h0和c0,形状完全相同 -
前向传播返回
(output, (hn, cn)),多了cn -
其他维度规则与 RNN 完全一致!
五、三者对比总结
对比表格
| 对比维度 | RNN | LSTM | GRU |
|---|---|---|---|
| 门控数量 | 无门控 | 3 个门(遗忘、输入、输出) | 2 个门(更新、重置) |
| 细胞状态 | 无 | 有(独立的长期记忆) | 无(合并到隐藏状态) |
| 参数量 | 最少 | 最多 | 中等(约为 LSTM 的 2/3) |
| 训练速度 | 最快 | 最慢 | 中等 |
| 长序列能力 | 差(<20 步) | 好(~100 步) | 较好 |
| 梯度消失 | 严重 | 大幅缓解 | 缓解 |
| 表达能力 | 最弱 | 最强 | 较强 |
| 调参难度 | 简单 | 复杂 | 中等 |
| 内存占用 | 最小 | 最大 | 中等 |
适用场景选择
🎯 什么时候用 RNN?
-
序列非常短(<20 个时间步)
-
对速度要求极高
-
模型需要极度轻量化
-
简单的序列任务
🎯 什么时候用 LSTM?
-
序列较长,需要捕捉长距离依赖
-
任务复杂,需要更强的建模能力
-
对性能要求高于速度
-
机器翻译、文本生成、语音识别等复杂任务
🎯 什么时候用 GRU?
-
想要比 RNN 强,又不想像 LSTM 那么慢
-
数据集相对较小,防止过拟合
-
需要快速迭代实验
-
大多数情况下是 LSTM 的优秀替代品
初学者学习建议
-
先搞懂维度:9 个参数是基础,维度搞不懂一切都是空谈
-
从 RNN 开始:先理解简单的 RNN,再学 LSTM 和 GRU
-
动手跑代码:一定要亲手运行本文的代码,观察输出形状
-
理解直觉:用记忆的直觉理解门控机制,不要一开始就抠数学公式
-
实践出真知:在真实数据集上训练,感受三者的差异
结语
RNN、LSTM、GRU 是深度学习处理序列数据的基石。虽然现在 Transformer 大行其道,但循环神经网络的思想仍然非常重要,是理解序列建模的必经之路。
希望这篇文章能帮助你彻底搞懂这三种网络!记住:维度是基础,直觉是关键,代码是验证。
祝你学习愉快!🚀