TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

GRU(门控循环单元,Gated Recurrent Unit)是一种用于处理序列数据的递归神经网络(RNN)模型。它是为了克服传统RNN在长时间序列中训练时遇到的梯度消失问题(即记忆力衰减)而提出的。GRU相较于LSTM(长短时记忆网络),结构更简单,计算效率更高,但能够实现类似的性能。

GRU的工作原理

GRU通过引入两个重要的门控机制来控制信息的流动:

  1. 更新门(Update Gate):决定了当前时刻的隐藏状态有多少部分应该被更新,多少部分保留旧的隐藏状态信息。这个门类似于LSTM中的遗忘门和输入门的结合。它通过一个Sigmoid激活函数控制当前输入和上一时刻的隐藏状态对当前时刻的影响。

  2. 重置门(Reset Gate):控制遗忘多少旧的隐藏状态信息。它决定了当前时刻输入和过去隐藏状态的结合程度,从而决定了网络保留多少历史信息。通过一个Sigmoid激活函数来计算。

GRU的关键是利用这些门控机制在时间步之间传递信息,从而可以更好地捕获序列数据中的长期依赖性。

复制代码
tf.keras.layers.GRU(
    units,
    activation='tanh',
    recurrent_activation='sigmoid',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    bias_initializer='zeros',
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    time_major=False,
    unroll=False,
    reset_after=True,
    **kwargs
)

核心参数:

  1. units - 核心参数
  • 作用:定义GRU层中隐藏单元的数量

  • 影响:决定模型的容量和复杂度,值越大表示记忆能力越强

  • 选择:通常根据任务复杂度在32-512之间选择

  1. activation - 激活函数
  • 作用:定义输出计算的激活函数

  • 默认值'tanh'

  • 功能:控制候选隐藏状态的计算

  1. recurrent_activation - 循环激活函数
  • 作用:定义门控机制的激活函数

  • 默认值'sigmoid'

  • 功能:控制更新门和重置门的计算

  1. return_sequences - 输出控制
  • 作用:控制是否返回所有时间步的输出

  • 默认值False(只返回最后一个时间步)

  • 应用

    • False:用于序列分类任务

    • True:用于序列标注或多层GRU堆叠

  1. return_state - 状态返回
  • 作用:是否返回最后的隐藏状态

  • 使用场景:编码器-解码器架构或需要状态传递的场景

  1. dropout - 输入丢弃率
  • 作用:输入单元的随机丢弃比例,防止过拟合

  • 范围:0.0-1.0,推荐0.2-0.5

  1. recurrent_dropout - 循环丢弃率
  • 作用:循环连接的随机丢弃比例

  • 功能:专门防止循环连接上的过拟合

  1. reset_after - 重置门位置
  • 作用:控制重置门的应用顺序

  • 默认值True(与CuDNN兼容的模式)

  • 影响:影响计算图和性能,通常保持默认

示例:

复制代码
import tensorflow as tf
from keras import Input, layers
from keras.src.utils import pad_sequences
​
# 1. 加载 IMDB 数据集
max_features = 10000  # 使用词汇表中前 10000 个常见单词
maxlen = 100  # 每条评论的最大长度
​
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
print(x_train.shape, x_test.shape)
print(x_train[0])
print(y_train)
​
# 2. 数据预处理:对每条评论进行填充,使其长度统一
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test = pad_sequences(x_test, maxlen=maxlen)
​
# 3. 构建 RNN 模型
model = tf.keras.models.Sequential([
    Input(shape=(maxlen,)),
    layers.Embedding(input_dim=max_features, output_dim=128),  # 嵌入层,将单词索引映射为向量 output_dim  嵌入向量的维度(即每个输入词的嵌入表示的长度)
    layers.GRU(units=64, dropout=0.2, recurrent_dropout=0.2),
    # GRU 层:包含 64 个神经元,激活函数默认使用 tanh  dropout表示在每个时间步上丢弃20% recurrent_dropout 递归状态(即隐藏状态)的dropout比率为20%
    layers.Dense(1, activation='sigmoid')  # 输出层:用于二分类(正面或负面),激活函数为 sigmoid
])
​
# 4. 模型编译
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
​
# 5. 模型训练
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), verbose=1)
​
# 6. 模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

运行结果:

相关推荐
百锦再1 小时前
第11章 泛型、trait与生命周期
android·网络·人工智能·python·golang·rust·go
zbhbbedp282793cl3 小时前
如何在VSCode中安装Python扩展?
ide·vscode·python
Python私教5 小时前
Python 开发环境安装与配置全指南(2025版)
开发语言·python
百锦再5 小时前
第12章 测试编写
android·java·开发语言·python·rust·go·erlang
熠熠仔5 小时前
QGIS 3.34+ 网络分析基础数据自动化生成:从脚本到应用
python·数据分析
测试19985 小时前
Appium使用指南与自动化测试案例详解
自动化测试·软件测试·python·测试工具·职场和发展·appium·测试用例
神仙别闹6 小时前
基于 C++和 Python 实现计算机视觉
c++·python·计算机视觉
hongjianMa6 小时前
【论文阅读】Hypercomplex Prompt-aware Multimodal Recommendation
论文阅读·python·深度学习·机器学习·prompt·推荐系统
饼干,7 小时前
第23天python内容
开发语言·python
现在,此刻7 小时前
李沐深度学习笔记D3-线性回归
笔记·深度学习·线性回归