【深度学习教程——04_序列模型(Sequence)】16_RNN为什么记不住长文本?循环神经网络的缺陷

16_RNN为什么记不住长文本?循环神经网络的缺陷

本章目标 :理解序列数据的本质 ------ 顺序就是信息。掌握 RNN 的核心机制:权重共享时间展开 (Unrolling) 。深入剖析其致命缺陷:梯度消失长距离依赖问题,为下一章 LSTM 的出场做铺垫。


目录

  1. [序列数据:Context Matters](#序列数据:Context Matters)
  2. RNN:自带记忆的神经网络
  3. 核心机制:权重共享 (Weight Sharing)
  4. BPTT:随时间反向传播
  5. [痛点:为什么 RNN 记不住长句子?](#痛点:为什么 RNN 记不住长句子?)
  6. [实战:PyTorch 实现 RNN Cell](#实战:PyTorch 实现 RNN Cell)

1. 序列数据:Context Matters

假如我想预测这句话的下一个词:

  • "The clouds are in the ____." (Answer: sky)
  • "I grew up in France... I speak fluent ____." (Answer: French)

在 CNN 或 MLP 中,输入是独立的。但在序列数据中,当前的输出依赖于之前的输入 。这就需要网络拥有记忆力 (Memory)


2. RNN:自带记忆的神经网络

循环神经网络 (RNN) 的结构其实非常简单,就比 MLP 多了一个环(Loop):

h t = tanh ⁡ ( W x h ⋅ x t + W h h ⋅ h t − 1 + b ) h_t = \tanh(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + b) ht=tanh(Wxh⋅xt+Whh⋅ht−1+b)

  • x t x_t xt: 当前时刻的输入 (Input)。
  • h t − 1 h_{t-1} ht−1: 上一时刻的隐藏状态 (Hidden State),也就是"记忆"。
  • h t h_t ht: 当前时刻更新后的"记忆"。
  • tanh ⁡ \tanh tanh: 激活函数,把值压缩到 ( − 1 , 1 ) (-1, 1) (−1,1)。

看上图:左边是折叠图 (Folded),右边是时间展开图 (Unfolded)。
关键点 :虽然右图看起来很深,但每一层用的 W h h W_{hh} Whh 和 W x h W_{xh} Wxh 都是同一个!


3. 核心机制:权重共享 (Weight Sharing)

这一点和 CNN 很像:

  • CNN : 卷积核在空间上滑动(不管猫在左上角还是右下角都用同一个核)。
  • RNN : 细胞核在时间上滑动(不管是在第1秒还是第100秒都用同一个参数)。

如果你对每个时间步都用不同的权重,那就不是 RNN,而是全连接网络了,而且无法处理变长序列。


4. BPTT:随时间反向传播 (Backpropagation Through Time)

当我们要训练 RNN 时,我们把它按时间轴展开 (Unroll) 成一个深层网络。

  • Forward : x 1 → h 1 → h 2 → ⋯ → h T x_1 \to h_1 \to h_2 \to \dots \to h_T x1→h1→h2→⋯→hT
  • Loss : L = ∑ t = 1 T L t L = \sum_{t=1}^T L_t L=∑t=1TLt (每个时刻都有损失)。
  • Backward : 梯度需要从最后一个时刻 T T T,一直反向传播回 t = 0 t=0 t=0。这条红色的梯度流就是 BPTT。

5. 痛点:为什么 RNN 记不住长句子?

这就是著名的 长距离依赖 (Long-term Dependency) 问题,本质上是 梯度消失

在 BPTT 中,梯度的计算涉及连乘:
∂ h T ∂ h t = ∏ k = t T − 1 ∂ h k + 1 ∂ h k ≈ ∏ W h h ⋅ tanh ⁡ ′ \frac{\partial h_T}{\partial h_t} = \prod_{k=t}^{T-1} \frac{\partial h_{k+1}}{\partial h_k} \approx \prod W_{hh} \cdot \tanh' ∂ht∂hT=k=t∏T−1∂hk∂hk+1≈∏Whh⋅tanh′

  • tanh ⁡ ′ \tanh' tanh′ 的导数最大值是 1(通常远小于 1)。
  • 如果 ∣ W h h ∣ |W_{hh}| ∣Whh∣ 也小于 1(初始化或正则化导致)。
  • 那么几百个小于 1 的数相乘,结果趋近于 0。

结果 :RNN 只能"看"到最近几步的信息。

对于 "I grew up in France... (隔了100个词) ... I speak fluent French "。

RNN 读到 "speak" 时,关于 "France" 的梯度早就消失了,它根本不知道前面提到了法国。


6. 实战:PyTorch 实现 RNN Cell

我们手动实现一个简单的 RNN Cell,感受一下这个循环过程。

python 复制代码
import torch
import torch.nn as nn

class MyRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 输入 -> 隐层
        self.i2h = nn.Linear(input_size, hidden_size)
        # 隐层 -> 隐层 (循环核)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def forward(self, input, hidden):
        # h_t = tanh(W_ih * x_t + W_hh * h_{t-1})
        combined = self.i2h(input) + self.h2h(hidden)
        hidden = torch.tanh(combined)
        return hidden

# 使用示例
input_size = 4
hidden_size = 8
cell = MyRNNCell(input_size, hidden_size)

# 模拟序列: SeqLen=5, Batch=1, Feature=4
inputs = torch.randn(5, 1, input_size)
hidden = torch.zeros(1, hidden_size) # 初始记忆为 0

print("Processing sequence:")
for idx, x in enumerate(inputs):
    # x: [Batch, Input]
    hidden = cell(x[0], hidden)
    print(f"Step {idx+1}: Hidden state norm = {hidden.norm():.4f}")

问题 :由于梯度消失,RNN 无法处理长序列。有什么办法能让记忆保持得更久一点?

这就需要给神经元装上"开关"(门控机制)。

下一章,我们将介绍 LSTM (Long Short-Term Memory)

下一章:17_LSTM的三个门在控制什么?长短期记忆网络详解

相关推荐
个入资料2 小时前
阿里云ecs+飞书搭建openclaw
人工智能
CoovallyAIHub2 小时前
OpenClaw一脚踩碎传统CV?机器终于不再只是看世界
深度学习·算法·计算机视觉
CoovallyAIHub2 小时前
仅凭单目相机实现3D锥桶定位?UNet-RKNet破解自动驾驶锥桶检测难题
深度学习·算法·计算机视觉
孤烟3 小时前
【RAG 实战系列 02】检索精度翻倍!混合检索(稀疏 + 稠密)实战教程
人工智能·llm
明明如月学长3 小时前
OpenClaw 帮我睡后全自动完成了老板交代的任务
人工智能
uuware3 小时前
Lupine.Press + AI 助您分分钟搞定技术项目的文档网站
人工智能·前端框架
海上日出3 小时前
使用 QuantStats 进行投资组合绩效分析:Python 量化实战指南
人工智能
Qinana3 小时前
150行代码搞定私有知识库!Node.js + LangChain 打造最小化 RAG 系统全流程
人工智能·程序员·node.js
猿猿长成记3 小时前
AI专栏 | AI大法则之思维链、自洽性、思维树
人工智能
用户5191495848453 小时前
CrushFTP 条件竞争认证绕过漏洞利用工具 (CVE-2025-54309)
人工智能·aigc