门控循环单元(GRU)

门控循环单元(GRU)基本原理

一、GRU核心思想与设计动机

目标 :在保留LSTM长程记忆能力的前提下,简化网络结构
核心创新

  • 合并LSTM的输入门和遗忘门为更新门(Update Gate)
  • 去除细胞状态(Cell State),直接通过隐藏状态传递信息
  • 参数数量比LSTM减少1/3,训练速度提升20-30%

二、网络结构分解

1. 核心组件(两个门 + 候选状态)

组件 符号 功能描述
更新门 z t z_t zt 控制历史信息与当前信息的融合比例
重置门 r t r_t rt 决定忽略多少历史信息生成候选状态
候选隐藏状态 h ~ t \tilde{h}_t h~t 包含当前输入与部分历史信息的中间状态

2. 数学公式推导

更新门(Update Gate)

z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz⋅[ht−1,xt]+bz)

  • σ \sigma σ: Sigmoid函数(输出0-1间的保留比例)
重置门(Reset Gate)

r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr⋅[ht−1,xt]+br)

候选隐藏状态

h ~ t = tanh ⁡ ( W ⋅ [ r t ⊙ h t − 1 , x t ] + b ) \tilde{h}t = \tanh(W \cdot [r_t \odot h{t-1}, x_t] + b) h~t=tanh(W⋅[rt⊙ht−1,xt]+b)

  • ⊙ \odot ⊙: Hadamard积(控制历史信息流入量)
最终隐藏状态

h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht=(1−zt)⊙ht−1+zt⊙h~t

  • 动态平衡历史信息保留与更新

三、PyTorch实现(手动版)

1. GRU单元实现

python 复制代码
import torch
import torch.nn as nn

class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 合并计算三个门的参数矩阵
        self.W = nn.Linear(input_size + hidden_size, 3*hidden_size)
        
    def forward(self, x, h_prev):
        # 拼接输入与历史状态
        combined = torch.cat((x, h_prev), dim=1)
        gates = self.W(combined)
        
        # 分割门控计算结果
        z, r, n = torch.split(gates, self.hidden_size, dim=1)
        
        # 激活函数应用
        z = torch.sigmoid(z)  # 更新门
        r = torch.sigmoid(r)  # 重置门
        n = torch.tanh(r * n) # 候选状态
        
        # 最终状态更新
        h = (1 - z) * h_prev + z * n
        
        return h
相关推荐
go546315846511 分钟前
基于深度学习的食管癌右喉返神经旁淋巴结预测系统研究
图像处理·人工智能·深度学习·神经网络·算法
Blossom.11813 分钟前
基于深度学习的图像分类:使用Capsule Networks实现高效分类
人工智能·python·深度学习·神经网络·机器学习·分类·数据挖掘
宇称不守恒4.015 分钟前
2025暑期—05神经网络-卷积神经网络
深度学习·神经网络·cnn
想变成树袋熊1 小时前
【自用】NLP算法面经(6)
人工智能·算法·自然语言处理
格林威1 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现沙滩小人检测识别(C#代码UI界面版)
人工智能·深度学习·数码相机·yolo·计算机视觉
checkcheckck2 小时前
spring ai 适配 流式回答、mcp、milvus向量数据库、rag、聊天会话记忆
人工智能
Microvision维视智造2 小时前
从“人工眼”到‘智能眼’:EZ-Vision视觉系统如何重构生产线视觉检测精度?
图像处理·人工智能·重构·视觉检测
巫婆理发2222 小时前
神经网络(多层感知机)(第二课第二周)
人工智能·深度学习·神经网络
lxmyzzs2 小时前
【打怪升级 - 03】YOLO11/YOLO12/YOLOv10/YOLOv8 完全指南:从理论到代码实战,新手入门必看教程
人工智能·神经网络·yolo·目标检测·计算机视觉
SEO_juper2 小时前
企业级 AI 工具选型报告:9 个技术平台的 ROI 对比与部署策略
人工智能·搜索引擎·百度·llm·工具·geo·数字营销