用deepseek学大模型08-循环神经网络

从入门到精通循环神经网络 (RNN)

https://www.dxy.cn/bbs/newweb/pc/post/50883341

https://wenku.csdn.net/column/kbnq75axws

1. RNN 基础

RNN 通过隐藏状态传递序列信息,核心公式:

  • 隐藏状态:
    h t = tanh ⁡ ( W h h h t − 1 + W x h x t + b h ) \mathbf{h}t = \tanh(\mathbf{W}{hh} \mathbf{h}{t-1} + \mathbf{W}{xh} \mathbf{x}_t + \mathbf{b}_h) ht=tanh(Whhht−1+Wxhxt+bh)
  • 输出:
    y t = W h y h t + b y \mathbf{y}t = \mathbf{W}{hy} \mathbf{h}_t + \mathbf{b}_y yt=Whyht+by
2. 目标函数与损失函数
  • 目标函数:最小化预测与真实值的差距。
  • 损失函数 (以 MSE 为例):
    L = 1 2 T ∑ t = 1 T ( y t − y ^ t ) 2 L = \frac{1}{2T} \sum_{t=1}^T (\mathbf{y}_t - \mathbf{\hat{y}}_t)^2 L=2T1∑t=1T(yt−y^t)2
3. 梯度下降与数学推导

标量形式 (以 W h h W_{hh} Whh为例):
∂ L ∂ W h h = ∑ t = 1 T ∂ L ∂ y t ⋅ ∂ y t ∂ h t ⋅ ( ∏ k = 1 t ∂ h k ∂ h k − 1 ) ⋅ ∂ h 1 ∂ W h h \frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{y}_t} \cdot \frac{\partial \mathbf{y}_t}{\partial \mathbf{h}t} \cdot \left( \prod{k=1}^t \frac{\partial \mathbf{h}k}{\partial \mathbf{h}{k-1}} \right) \cdot \frac{\partial \mathbf{h}1}{\partial W{hh}} ∂Whh∂L=t=1∑T∂yt∂L⋅∂ht∂yt⋅(k=1∏t∂hk−1∂hk)⋅∂Whh∂h1

其中, ∂ h k ∂ h k − 1 = W h h T ⋅ diag ( 1 − tanh ⁡ 2 ( ⋅ ) ) \frac{\partial \mathbf{h}k}{\partial \mathbf{h}{k-1}} = \mathbf{W}_{hh}^T \cdot \text{diag}(1 - \tanh^2(\cdot)) ∂hk−1∂hk=WhhT⋅diag(1−tanh2(⋅)),导致梯度消失/爆炸。

矩阵形式
∂ L ∂ W h h = ∑ t = 1 T diag ( 1 − h t 2 ) ⋅ h t − 1 T ⋅ ( W h y T ( y ^ t − y t ) ∏ k = t 1 W h h T diag ( 1 − h k 2 ) ) \frac{\partial L}{\partial \mathbf{W}{hh}} = \sum{t=1}^T \text{diag}(1 - \mathbf{h}t^2) \cdot \mathbf{h}{t-1}^T \cdot \left( \mathbf{W}_{hy}^T (\mathbf{\hat{y}}t - \mathbf{y}t) \prod{k=t}^1 \mathbf{W}{hh}^T \text{diag}(1 - \mathbf{h}_k^2) \right) ∂Whh∂L=t=1∑Tdiag(1−ht2)⋅ht−1T⋅(WhyT(y^t−yt)k=t∏1WhhTdiag(1−hk2))

4. PyTorch 代码案例
python 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 数据生成
seq_len = 20
time = torch.arange(0, seq_len, 0.1)
data = torch.sin(time) + torch.randn(seq_len * 10) * 0.1

# 转换为序列数据
def create_dataset(data, window=5):
    X, y = [], []
    for i in range(len(data)-window):
        X.append(data[i:i+window])
        y.append(data[i+window])
    return torch.stack(X), torch.stack(y)

X, y = create_dataset(data, window=5)
X = X.unsqueeze(-1).float()  # (samples, window, features)
y = y.unsqueeze(-1).float()

# 定义模型
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        out, _ = self.rnn(x)  # out: (batch, seq, hidden)
        out = self.fc(out[:, -1, :])  # 取最后一个时间步
        return out

model = RNN(1, 32, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 训练
epochs = 100
losses = []
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # 梯度裁剪
    optimizer.step()
    losses.append(loss.item())

# 可视化损失
plt.plot(losses)
plt.title('Training Loss')
plt.show()

# 预测
with torch.no_grad():
    pred = model(X)

plt.plot(time[5:], y.numpy(), label='True')
plt.plot(time[5:], pred.numpy(), label='Predicted')
plt.legend()
plt.show()
5. 可视化展示
  • 损失曲线:展示训练过程中损失下降。
  • 预测对比:真实值与预测值的时间序列对比。
  • 隐藏状态可视化(可选):通过 PCA 降维展示隐藏状态变化。
6. 应用场景与优缺点
  • 应用:时间序列预测、文本生成、机器翻译。
  • 优点:处理变长序列,捕捉时序依赖。
  • 缺点:梯度消失/爆炸,长程依赖困难,计算效率低。
7. 改进方法
  • 结构改进 :使用 LSTM/GRU 的门控机制,例如 LSTM 的遗忘门:
    f t = σ ( W f [ h t − 1 , x t ] + b f ) f_t = \sigma(\mathbf{W}f [\mathbf{h}{t-1}, \mathbf{x}_t] + \mathbf{b}_f) ft=σ(Wf[ht−1,xt]+bf)
  • 梯度裁剪:限制梯度最大值,防止爆炸。
  • 优化算法:Adam 自适应学习率。
  • 注意力机制:增强长距离依赖捕捉能力。
8. 数学推导改进(LSTM 示例)

LSTM 通过细胞状态 C t \mathbf{C}_t Ct传递信息,梯度流动更稳定:
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t \mathbf{C}t = f_t \odot \mathbf{C}{t-1} + i_t \odot \tilde{\mathbf{C}}_t Ct=ft⊙Ct−1+it⊙C~t

其中遗忘门 f t f_t ft控制历史信息保留,避免传统 RNN 的连乘梯度,缓解消失问题。


通过上述步骤,您可系统掌握 RNN 的核心理论、实现及优化方法。

相关推荐
@心都15 分钟前
机器学习数学基础:29.t检验
人工智能·机器学习
9命怪猫18 分钟前
DeepSeek底层揭秘——微调
人工智能·深度学习·神经网络·ai·大模型
kcarly2 小时前
KTransformers如何通过内核级优化、多GPU并行策略和稀疏注意力等技术显著加速大语言模型的推理速度?
人工智能·语言模型·自然语言处理
Jackilina_Stone2 小时前
【论文阅读笔记】浅谈深度学习中的知识蒸馏 | 关系知识蒸馏 | CVPR 2019 | RKD
论文阅读·深度学习·蒸馏·rkd
倒霉蛋小马3 小时前
【YOLOv8】损失函数
深度学习·yolo·机器学习
MinIO官方账号3 小时前
使用 AIStor 和 OpenSearch 增强搜索功能
人工智能
江江江江江江江江江4 小时前
深度神经网络终极指南:从数学本质到工业级实现(附Keras版本代码)
人工智能·keras·dnn
Fansv5874 小时前
深度学习-2.机械学习基础
人工智能·经验分享·python·深度学习·算法·机器学习
小怪兽会微笑4 小时前
PyTorch Tensor 形状变化操作详解
人工智能·pytorch·python
Erekys5 小时前
视觉分析之边缘检测算法
人工智能·计算机视觉·音视频