循环神经网络(RNN)

循环神经网络(RNN)基本原理

一、RNN核心思想

目标 :处理序列数据(如文本、时间序列),通过循环连接 传递隐藏状态,捕捉序列的动态依赖关系。
核心特性

  • 参数共享:所有时间步共享同一组权重。
  • 记忆能力:隐藏状态 h t h_t ht 编码历史信息。

二、网络结构与数学公式

1. RNN展开结构

  • 输入 :时间步 t t t 的输入 x t x_t xt(如词向量)。
  • 隐藏状态 : h t h_t ht 融合当前输入与历史信息。
  • 输出 : y t y_t yt 基于 h t h_t ht 生成预测。

2. 数学公式

  • 隐藏状态更新
    h t = tanh ⁡ ( W h h h t − 1 + W x h x t + b h ) h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h) ht=tanh(Whhht−1+Wxhxt+bh)

    • W h h ∈ R d h × d h W_{hh} \in \mathbb{R}^{d_h \times d_h} Whh∈Rdh×dh: 隐藏状态权重
    • W x h ∈ R d x × d h W_{xh} \in \mathbb{R}^{d_x \times d_h} Wxh∈Rdx×dh: 输入权重
    • tanh ⁡ \tanh tanh: 激活函数(压缩到[-1,1])
  • 输出计算
    y t = W h y h t + b y y_t = W_{hy} h_t + b_y yt=Whyht+by

    • W h y ∈ R d h × d y W_{hy} \in \mathbb{R}^{d_h \times d_y} Why∈Rdh×dy: 输出权重

三、PyTorch代码实现

1. RNN模型定义

python 复制代码
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        
        # 定义权重参数
        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))
        self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))
        self.b_h = nn.Parameter(torch.zeros(hidden_size))
        self.b_y = nn.Parameter(torch.zeros(output_size))

    def forward(self, x_seq):
        # x_seq形状: (seq_length, batch_size, input_size)
        batch_size = x_seq.size(1)
        h = torch.zeros(batch_size, self.hidden_size)  # 初始隐藏状态
        
        outputs = []
        for x_t in x_seq:  # 按时间步迭代
            # 更新隐藏状态
            h = torch.tanh(
                torch.mm(h, self.W_hh) + 
                torch.mm(x_t, self.W_xh) + 
                self.b_h
            )
            # 计算输出
            y_t = torch.mm(h, self.W_hy) + self.b_y
            outputs.append(y_t)
        
        return torch.stack(outputs), h
相关推荐
数据皮皮侠2 分钟前
最新上市公司业绩说明会文本数据(2017.02-2025.08)
大数据·数据库·人工智能·笔记·物联网·小程序·区块链
智算菩萨7 分钟前
【计算机视觉与深度学习实战】05计算机视觉与深度学习在蚊子检测中的应用综述与假设
人工智能·深度学习·计算机视觉
hllqkbb8 分钟前
人体姿态估计-动手学计算机视觉14
人工智能·opencv·计算机视觉·分类
XiongLiding15 分钟前
我的第一个MCP,以及开发过程中的经验感悟
人工智能
三花AI30 分钟前
阿里 20B 参数 Qwen-Image-Edit 全能图像编辑模型
人工智能
EthanLifeGreat42 分钟前
ParallelWaveGAN-KaldiFree:纯Pytorch的PWG
人工智能·pytorch·深度学习·音频·语音识别
盏灯1 小时前
据说,80%的人都搞不懂MCP底层?
人工智能·aigc·mcp
机器之心1 小时前
机器人也会「摸鱼」了?宇树G1赛后葛优瘫刷美女视频,网友:比人还懂享受生活
人工智能·openai
胡耀超1 小时前
从哲学(业务)视角看待数据挖掘:从认知到实践的螺旋上升
人工智能·python·数据挖掘·大模型·特征工程·crisp-dm螺旋认知·批判性思维
新智元1 小时前
Meta没做的,英伟达做了!全新架构吞吐量狂飙6倍,20万亿Token训练
人工智能·openai