【深度学习教程——04_序列模型(Sequence)】16_RNN为什么记不住长文本?循环神经网络的缺陷

16_RNN为什么记不住长文本?循环神经网络的缺陷

本章目标 :理解序列数据的本质 ------ 顺序就是信息。掌握 RNN 的核心机制:权重共享时间展开 (Unrolling) 。深入剖析其致命缺陷:梯度消失长距离依赖问题,为下一章 LSTM 的出场做铺垫。


目录

  1. [序列数据:Context Matters](#序列数据:Context Matters)
  2. RNN:自带记忆的神经网络
  3. 核心机制:权重共享 (Weight Sharing)
  4. BPTT:随时间反向传播
  5. [痛点:为什么 RNN 记不住长句子?](#痛点:为什么 RNN 记不住长句子?)
  6. [实战:PyTorch 实现 RNN Cell](#实战:PyTorch 实现 RNN Cell)

1. 序列数据:Context Matters

假如我想预测这句话的下一个词:

  • "The clouds are in the ____." (Answer: sky)
  • "I grew up in France... I speak fluent ____." (Answer: French)

在 CNN 或 MLP 中,输入是独立的。但在序列数据中,当前的输出依赖于之前的输入 。这就需要网络拥有记忆力 (Memory)


2. RNN:自带记忆的神经网络

循环神经网络 (RNN) 的结构其实非常简单,就比 MLP 多了一个环(Loop):

h t = tanh ⁡ ( W x h ⋅ x t + W h h ⋅ h t − 1 + b ) h_t = \tanh(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + b) ht=tanh(Wxh⋅xt+Whh⋅ht−1+b)

  • x t x_t xt: 当前时刻的输入 (Input)。
  • h t − 1 h_{t-1} ht−1: 上一时刻的隐藏状态 (Hidden State),也就是"记忆"。
  • h t h_t ht: 当前时刻更新后的"记忆"。
  • tanh ⁡ \tanh tanh: 激活函数,把值压缩到 ( − 1 , 1 ) (-1, 1) (−1,1)。

看上图:左边是折叠图 (Folded),右边是时间展开图 (Unfolded)。
关键点 :虽然右图看起来很深,但每一层用的 W h h W_{hh} Whh 和 W x h W_{xh} Wxh 都是同一个!


3. 核心机制:权重共享 (Weight Sharing)

这一点和 CNN 很像:

  • CNN : 卷积核在空间上滑动(不管猫在左上角还是右下角都用同一个核)。
  • RNN : 细胞核在时间上滑动(不管是在第1秒还是第100秒都用同一个参数)。

如果你对每个时间步都用不同的权重,那就不是 RNN,而是全连接网络了,而且无法处理变长序列。


4. BPTT:随时间反向传播 (Backpropagation Through Time)

当我们要训练 RNN 时,我们把它按时间轴展开 (Unroll) 成一个深层网络。

  • Forward : x 1 → h 1 → h 2 → ⋯ → h T x_1 \to h_1 \to h_2 \to \dots \to h_T x1→h1→h2→⋯→hT
  • Loss : L = ∑ t = 1 T L t L = \sum_{t=1}^T L_t L=∑t=1TLt (每个时刻都有损失)。
  • Backward : 梯度需要从最后一个时刻 T T T,一直反向传播回 t = 0 t=0 t=0。这条红色的梯度流就是 BPTT。

5. 痛点:为什么 RNN 记不住长句子?

这就是著名的 长距离依赖 (Long-term Dependency) 问题,本质上是 梯度消失

在 BPTT 中,梯度的计算涉及连乘:
∂ h T ∂ h t = ∏ k = t T − 1 ∂ h k + 1 ∂ h k ≈ ∏ W h h ⋅ tanh ⁡ ′ \frac{\partial h_T}{\partial h_t} = \prod_{k=t}^{T-1} \frac{\partial h_{k+1}}{\partial h_k} \approx \prod W_{hh} \cdot \tanh' ∂ht∂hT=k=t∏T−1∂hk∂hk+1≈∏Whh⋅tanh′

  • tanh ⁡ ′ \tanh' tanh′ 的导数最大值是 1(通常远小于 1)。
  • 如果 ∣ W h h ∣ |W_{hh}| ∣Whh∣ 也小于 1(初始化或正则化导致)。
  • 那么几百个小于 1 的数相乘,结果趋近于 0。

结果 :RNN 只能"看"到最近几步的信息。

对于 "I grew up in France... (隔了100个词) ... I speak fluent French "。

RNN 读到 "speak" 时,关于 "France" 的梯度早就消失了,它根本不知道前面提到了法国。


6. 实战:PyTorch 实现 RNN Cell

我们手动实现一个简单的 RNN Cell,感受一下这个循环过程。

python 复制代码
import torch
import torch.nn as nn

class MyRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 输入 -> 隐层
        self.i2h = nn.Linear(input_size, hidden_size)
        # 隐层 -> 隐层 (循环核)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def forward(self, input, hidden):
        # h_t = tanh(W_ih * x_t + W_hh * h_{t-1})
        combined = self.i2h(input) + self.h2h(hidden)
        hidden = torch.tanh(combined)
        return hidden

# 使用示例
input_size = 4
hidden_size = 8
cell = MyRNNCell(input_size, hidden_size)

# 模拟序列: SeqLen=5, Batch=1, Feature=4
inputs = torch.randn(5, 1, input_size)
hidden = torch.zeros(1, hidden_size) # 初始记忆为 0

print("Processing sequence:")
for idx, x in enumerate(inputs):
    # x: [Batch, Input]
    hidden = cell(x[0], hidden)
    print(f"Step {idx+1}: Hidden state norm = {hidden.norm():.4f}")

问题 :由于梯度消失,RNN 无法处理长序列。有什么办法能让记忆保持得更久一点?

这就需要给神经元装上"开关"(门控机制)。

下一章,我们将介绍 LSTM (Long Short-Term Memory)

下一章:17_LSTM的三个门在控制什么?长短期记忆网络详解

相关推荐
jay神19 小时前
基于YOLOv8的无人机识别与检测系统
人工智能·深度学习·yolo·目标检测·毕业设计·无人机
皮卡丘不断更19 小时前
我把传统项目问答升级成了 Agent-RAG:Spring Boot + FastAPI + ChromaDB 工程落地实践
人工智能·spring boot·后端·架构·python3.11
serve the people19 小时前
BERT模型
人工智能·深度学习·bert
木斯佳19 小时前
前端八股文面经大全:得物AI应用开发一面(2026-03-23)·面经深度解析【加精】
前端·人工智能·ai·markdown·chat·rag
绒绒毛毛雨20 小时前
On the Plasticity and Stability for Post-Training Large Language Models
人工智能·机器学习·语言模型
SuniaWang1 天前
《Spring AI + 大模型全栈实战》学习手册系列 · 专题六:《Vue3 前端开发实战:打造企业级 RAG 问答界面》
java·前端·人工智能·spring boot·后端·spring·架构
IDZSY04301 天前
AI社交平台进阶指南:如何用AI社交提升工作学习效率
人工智能·学习
七七powerful1 天前
运维养龙虾--AI 驱动的架构图革命:draw.io MCP 让运维画图效率提升 10 倍,使用codebuddy实战
运维·人工智能·draw.io
水星梦月1 天前
大白话讲解AI/LLM核心概念
人工智能
温九味闻醉1 天前
关于腾讯广告算法大赛2025项目分析1 - dataset.py
人工智能·算法·机器学习