RNN循环神经网络(一):基础RNN结构、双向RNN

RNN循环神经网络

什么是循环神经网络?

循环神经网络(Recurrent Neural Network, RNN)是一类专门用于处理序列数据的神经网络架构。与传统的前馈神经网络不同,RNN具有"记忆"能力,能够捕捉数据中的时间依赖关系。

核心特点:

  1. 循环连接:RNN单元之间存在循环连接,使得信息能够在网络内部持续传递
  2. 参数共享:相同的权重参数在时间步之间共享,大大减少了模型参数数量
  3. 序列处理:能够处理可变长度的输入序列,适用于时序数据

基本结构:

RNN的基本单元包含一个隐藏状态(hidden state),它在每个时间步都会被更新:

  • 新隐藏状态 = f(当前输入, 前一个隐藏状态)

举一个简单的例子:

简单的循环神经网络例子(多对多)

我们来做一个简单的循环神经网络,其实也就是跟上图一致。

python 复制代码
import torch
from torch import nn

class RNNCell(nn.Module):
    def __init__(self,input_size,hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.w_hidden = torch.randn(hidden_size,hidden_size)
        self.w_input = torch.randn(input_size,hidden_size)
        self.tanh = nn.Tanh()

    def forward(self,x,hidden_state=None):
        N,input_size = x.shape
        if hidden_state is None:
            hidden_state = torch.zeros(N,self.hidden_size)
        hidden_state = self.tanh(hidden_state @ self.w_hidden + x @ self.w_input)
        return hidden_state


class RNN(nn.Module):
    def __init__(self,input_size,hidden_size):
        super().__init__()
        self.cell = RNNCell(input_size,hidden_size)
        self.w_output = torch.randn(hidden_size,hidden_size)

    def forward(self,x,hidden_state=None):
        N,L,input_size = x.shape

        outputs = []

        for i in range(L):
            x_i = x[:,i]
            hidden_state = self.cell(x_i,hidden_state)
            out = hidden_state @ self.w_output
            outputs.append(out)

        outputs = torch.stack(outputs,dim=1)
        return outputs,hidden_state

if __name__ == "__main__":
    x = torch.randn(5,3,10)
    model = RNN(10,20)
    y,h = model(x)
    print(y.shape)
    print(h.shape)

双向循环神经网络

双向RNN其实也就是两层RNN的叠加,分别更新的是两层隐藏状态以及两层输出。

python 复制代码
import torch
from torch import nn

class BiRNN(nn.Module):
    def __init__(self,input_size,hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
		#前向RNN和线性层
        self.forward_cell = nn.RNNCell(input_size,hidden_size)
        self.backward_cell = nn.RNNCell(input_size,hidden_size)
		#反向RNN和线性层
        self.forward_Linear = nn.Linear(hidden_size,hidden_size)
        self.backward_Linear = nn.Linear(hidden_size,hidden_size)

    def forward(self,x,hidden = None):
        N,L,input_size = x.shape
        if hidden is None:
            #堆叠两层隐藏层
            hidden = torch.zeros(2,N,self.hidden_size)
        h_forward = hidden[0]
        out_forward = []
        for i in range(L):
            h_forward = self.forward_cell(x[:,i],h_forward)
            out = self.forward_Linear(h_forward)
            out_forward.append(out)

        out_forward = torch.stack(out_forward,dim=1)

        x = torch.flip(x,dims=[1])
        h_backward = hidden[1]
        out_backward = []
        for i in range(L):
            h_backward = self.backward_cell(x[:,i],h_backward)
            out = self.backward_Linear(h_backward)
            out_backward.append(out)

        out_backward = torch.stack(out_backward,dim=1)
        
        outputs = torch.concat((out_forward,out_backward),dim=-1)

        hidden = torch.stack([h_forward,h_backward])
        return outputs,hidden

if __name__ == '__main__':
    x = torch.randn((5,3,10))
    model = BiRNN(10,20)
    outputs,hidden = model(x)
    print(outputs.shape)
    print(hidden.shape)
相关推荐
互联网Ai好者2 分钟前
MiyoAI数参首发体验——不止于监控,更是你的智能决策参谋
人工智能
island13142 分钟前
CANN HIXL 通信库深度解析:单边点对点数据传输、异步模型与异构设备间显存直接访问
人工智能·深度学习·神经网络
心疼你的一切6 分钟前
解锁CANN仓库核心能力:从零搭建AIGC轻量文本生成实战(附代码+流程图)
数据仓库·深度学习·aigc·流程图·cann
初恋叫萱萱7 分钟前
CANN 生态中的图优化引擎:深入 `ge` 项目实现模型自动调优
人工智能
不爱学英文的码字机器8 分钟前
深度解读CANN生态核心仓库——catlass,打造高效可扩展的分类器技术底座
人工智能·cann
Kiyra9 分钟前
作为后端开发你不得不知的 AI 知识——RAG
人工智能·语言模型
共享家952712 分钟前
Vibe Coding 与 LangChain、LangGraph 的协同进化
人工智能
dvlinker14 分钟前
2026远程桌面安全白皮书:ToDesk/TeamViewer/向日葵核心安全性与合规性横向测评
人工智能
2的n次方_16 分钟前
CANN ascend-transformer-boost 深度解析:针对大模型的高性能融合算子库与算力优化机制
人工智能·深度学习·transformer
熊猫_豆豆16 分钟前
YOLOP车道检测
人工智能·python·算法