TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

GRU(门控循环单元,Gated Recurrent Unit)是一种用于处理序列数据的递归神经网络(RNN)模型。它是为了克服传统RNN在长时间序列中训练时遇到的梯度消失问题(即记忆力衰减)而提出的。GRU相较于LSTM(长短时记忆网络),结构更简单,计算效率更高,但能够实现类似的性能。

GRU的工作原理

GRU通过引入两个重要的门控机制来控制信息的流动:

  1. 更新门(Update Gate):决定了当前时刻的隐藏状态有多少部分应该被更新,多少部分保留旧的隐藏状态信息。这个门类似于LSTM中的遗忘门和输入门的结合。它通过一个Sigmoid激活函数控制当前输入和上一时刻的隐藏状态对当前时刻的影响。

  2. 重置门(Reset Gate):控制遗忘多少旧的隐藏状态信息。它决定了当前时刻输入和过去隐藏状态的结合程度,从而决定了网络保留多少历史信息。通过一个Sigmoid激活函数来计算。

GRU的关键是利用这些门控机制在时间步之间传递信息,从而可以更好地捕获序列数据中的长期依赖性。

复制代码
tf.keras.layers.GRU(
    units,
    activation='tanh',
    recurrent_activation='sigmoid',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    bias_initializer='zeros',
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    time_major=False,
    unroll=False,
    reset_after=True,
    **kwargs
)

核心参数:

  1. units - 核心参数
  • 作用:定义GRU层中隐藏单元的数量

  • 影响:决定模型的容量和复杂度,值越大表示记忆能力越强

  • 选择:通常根据任务复杂度在32-512之间选择

  1. activation - 激活函数
  • 作用:定义输出计算的激活函数

  • 默认值'tanh'

  • 功能:控制候选隐藏状态的计算

  1. recurrent_activation - 循环激活函数
  • 作用:定义门控机制的激活函数

  • 默认值'sigmoid'

  • 功能:控制更新门和重置门的计算

  1. return_sequences - 输出控制
  • 作用:控制是否返回所有时间步的输出

  • 默认值False(只返回最后一个时间步)

  • 应用

    • False:用于序列分类任务

    • True:用于序列标注或多层GRU堆叠

  1. return_state - 状态返回
  • 作用:是否返回最后的隐藏状态

  • 使用场景:编码器-解码器架构或需要状态传递的场景

  1. dropout - 输入丢弃率
  • 作用:输入单元的随机丢弃比例,防止过拟合

  • 范围:0.0-1.0,推荐0.2-0.5

  1. recurrent_dropout - 循环丢弃率
  • 作用:循环连接的随机丢弃比例

  • 功能:专门防止循环连接上的过拟合

  1. reset_after - 重置门位置
  • 作用:控制重置门的应用顺序

  • 默认值True(与CuDNN兼容的模式)

  • 影响:影响计算图和性能,通常保持默认

示例:

复制代码
import tensorflow as tf
from keras import Input, layers
from keras.src.utils import pad_sequences
​
# 1. 加载 IMDB 数据集
max_features = 10000  # 使用词汇表中前 10000 个常见单词
maxlen = 100  # 每条评论的最大长度
​
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
print(x_train.shape, x_test.shape)
print(x_train[0])
print(y_train)
​
# 2. 数据预处理:对每条评论进行填充,使其长度统一
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test = pad_sequences(x_test, maxlen=maxlen)
​
# 3. 构建 RNN 模型
model = tf.keras.models.Sequential([
    Input(shape=(maxlen,)),
    layers.Embedding(input_dim=max_features, output_dim=128),  # 嵌入层,将单词索引映射为向量 output_dim  嵌入向量的维度(即每个输入词的嵌入表示的长度)
    layers.GRU(units=64, dropout=0.2, recurrent_dropout=0.2),
    # GRU 层:包含 64 个神经元,激活函数默认使用 tanh  dropout表示在每个时间步上丢弃20% recurrent_dropout 递归状态(即隐藏状态)的dropout比率为20%
    layers.Dense(1, activation='sigmoid')  # 输出层:用于二分类(正面或负面),激活函数为 sigmoid
])
​
# 4. 模型编译
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
​
# 5. 模型训练
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), verbose=1)
​
# 6. 模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

运行结果:

相关推荐
codists6 小时前
2025年11月文章一览
python
生而为虫6 小时前
31.Python语言进阶
python·scrapy·django·flask·fastapi·pygame·tornado
言之。6 小时前
Claude Code 实用开发手册
python
计算机毕设小月哥6 小时前
【Hadoop+Spark+python毕设】中国租房信息可视化分析系统、计算机毕业设计、包括数据爬取、Spark、数据分析、数据可视化、Hadoop
后端·python·mysql
2***c4356 小时前
Redis——使用 python 操作 redis 之从 hmse 迁移到 hset
数据库·redis·python
二川bro8 小时前
模型部署实战:Python结合ONNX与TensorRT
开发语言·python
秋邱8 小时前
AI + 社区服务:智慧老年康养助手(轻量化落地方案)
人工智能·python·重构·ar·推荐算法·agi
rising start8 小时前
三、FastAPI :POST 请求、用户接口设计与 Requests 测试
python·网络协议·http·fastapi
CM莫问8 小时前
详解机器学习经典模型(原理及应用)——岭回归
人工智能·python·算法·机器学习·回归
SunnyRivers8 小时前
Python打包指南:编写你的pyproject.toml
python·打包·toml