TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 循环神经网络(GRU)示例

GRU(门控循环单元,Gated Recurrent Unit)是一种用于处理序列数据的递归神经网络(RNN)模型。它是为了克服传统RNN在长时间序列中训练时遇到的梯度消失问题(即记忆力衰减)而提出的。GRU相较于LSTM(长短时记忆网络),结构更简单,计算效率更高,但能够实现类似的性能。

GRU的工作原理

GRU通过引入两个重要的门控机制来控制信息的流动:

  1. 更新门(Update Gate):决定了当前时刻的隐藏状态有多少部分应该被更新,多少部分保留旧的隐藏状态信息。这个门类似于LSTM中的遗忘门和输入门的结合。它通过一个Sigmoid激活函数控制当前输入和上一时刻的隐藏状态对当前时刻的影响。

  2. 重置门(Reset Gate):控制遗忘多少旧的隐藏状态信息。它决定了当前时刻输入和过去隐藏状态的结合程度,从而决定了网络保留多少历史信息。通过一个Sigmoid激活函数来计算。

GRU的关键是利用这些门控机制在时间步之间传递信息,从而可以更好地捕获序列数据中的长期依赖性。

复制代码
tf.keras.layers.GRU(
    units,
    activation='tanh',
    recurrent_activation='sigmoid',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    bias_initializer='zeros',
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    time_major=False,
    unroll=False,
    reset_after=True,
    **kwargs
)

核心参数:

  1. units - 核心参数
  • 作用:定义GRU层中隐藏单元的数量

  • 影响:决定模型的容量和复杂度,值越大表示记忆能力越强

  • 选择:通常根据任务复杂度在32-512之间选择

  1. activation - 激活函数
  • 作用:定义输出计算的激活函数

  • 默认值'tanh'

  • 功能:控制候选隐藏状态的计算

  1. recurrent_activation - 循环激活函数
  • 作用:定义门控机制的激活函数

  • 默认值'sigmoid'

  • 功能:控制更新门和重置门的计算

  1. return_sequences - 输出控制
  • 作用:控制是否返回所有时间步的输出

  • 默认值False(只返回最后一个时间步)

  • 应用

    • False:用于序列分类任务

    • True:用于序列标注或多层GRU堆叠

  1. return_state - 状态返回
  • 作用:是否返回最后的隐藏状态

  • 使用场景:编码器-解码器架构或需要状态传递的场景

  1. dropout - 输入丢弃率
  • 作用:输入单元的随机丢弃比例,防止过拟合

  • 范围:0.0-1.0,推荐0.2-0.5

  1. recurrent_dropout - 循环丢弃率
  • 作用:循环连接的随机丢弃比例

  • 功能:专门防止循环连接上的过拟合

  1. reset_after - 重置门位置
  • 作用:控制重置门的应用顺序

  • 默认值True(与CuDNN兼容的模式)

  • 影响:影响计算图和性能,通常保持默认

示例:

复制代码
import tensorflow as tf
from keras import Input, layers
from keras.src.utils import pad_sequences
​
# 1. 加载 IMDB 数据集
max_features = 10000  # 使用词汇表中前 10000 个常见单词
maxlen = 100  # 每条评论的最大长度
​
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
print(x_train.shape, x_test.shape)
print(x_train[0])
print(y_train)
​
# 2. 数据预处理:对每条评论进行填充,使其长度统一
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test = pad_sequences(x_test, maxlen=maxlen)
​
# 3. 构建 RNN 模型
model = tf.keras.models.Sequential([
    Input(shape=(maxlen,)),
    layers.Embedding(input_dim=max_features, output_dim=128),  # 嵌入层,将单词索引映射为向量 output_dim  嵌入向量的维度(即每个输入词的嵌入表示的长度)
    layers.GRU(units=64, dropout=0.2, recurrent_dropout=0.2),
    # GRU 层:包含 64 个神经元,激活函数默认使用 tanh  dropout表示在每个时间步上丢弃20% recurrent_dropout 递归状态(即隐藏状态)的dropout比率为20%
    layers.Dense(1, activation='sigmoid')  # 输出层:用于二分类(正面或负面),激活函数为 sigmoid
])
​
# 4. 模型编译
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
​
# 5. 模型训练
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), verbose=1)
​
# 6. 模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

运行结果:

相关推荐
王六岁6 小时前
# 🐍 前端开发 0 基础学 Python 入门指南: Python 元组和映射类型深入指南
前端·javascript·python
王六岁6 小时前
# 🐍 前端开发 0 基础学 Python 入门指南:常用的数据类型和列表
前端·javascript·python
南枝异客6 小时前
查找算法-顺序查找
python·算法
花开花富贵6 小时前
不敢去表白?来用代码画♥
python
人间乄惊鸿客6 小时前
python-day8
开发语言·python
渡我白衣6 小时前
未来的 AI 操作系统(八)——灵知之门:当智能系统开始理解存在
人工智能·深度学习·opencv·机器学习·计算机视觉·语言模型·人机交互
Mrliu__6 小时前
Python数据结构(七):Python 高级排序算法:希尔 快速 归并
数据结构·python·排序算法
C嘎嘎嵌入式开发6 小时前
(22)100天python从入门到拿捏《【网络爬虫】网络基础与HTTP协议》
网络·爬虫·python
xiaoxiaode_shu7 小时前
神经网络基础
人工智能·深度学习·神经网络