机器学习笔记:门控循环单元的建立

目录

介绍

结构

模型原理

重置门与更新门

候选隐状态

输出隐状态

模型实现

引入数据

初始化参数

定义模型

训练与预测

简洁实现GRU

思考


介绍

门控循环单元(Gated Recurrent Unit,简称GRU)是循环神经网络一种较为复杂的构成形式,其用途也是处理时序数据,相比具有单隐藏状态的RNN,GRU具有忘记的能力,可以忘记无用的数据。

结构

与传统RNN相比,GRU的结构引入了 的概念,比RNN复杂许多,不过可以看出,其输入仍然是和上一时间步隐状态,输出仍然是本时间步隐状态。区别在于"细胞"内部结构,RNN只需要将H和X分别处理,之后结合在一起,激活函数激活后将其输出即可。而GRU内部处理十分复杂。

模型原理

我们以处理的顺序来依次讲解各个组成部分的模型原理。

重置门与更新门

首先介绍重置门 (reset gate)更新门 (update gate)。 我们把它们设计成(0,1)区间中的向量。 重置门允许我们控制"可能还想记住"的过去状态的数量; 更新门将允许我们控制新状态中有多少个是旧状态的副本。后面还会再提到两个门的具体作用。

重置门和更新门的计算公式如下所示,由于使用sigmoid函数,的值在(0,1)区间内。

候选隐状态

候选隐状态的计算公式如下,是RNN中计算公式的升级版。(是哈达玛积)

当重置门R的值接近1时,则候选隐状态的计算与RNN一致,当重置门R的值接近0时,则候选隐状态计算时会完全"忘记"之前的值。

输出隐状态

输出隐状态需要更新门,候选隐状态和上一阶段隐状态共同计算得到。

由公式可以看出,当接近0时,隐状态即为候选隐状态,当接近1时,隐状态即为上一阶段隐状态,更新门决定隐状态中有多少部分进行更新。

模型实现

引入数据

我们从零开始实现一个GRU,首先引入相关的库,并定义相关的一系列超参数。

python 复制代码
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l

npx.set_np()

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化参数

将需要学习的参数进行初始化。

python 复制代码
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return np.random.normal(scale=0.01, size=shape, ctx=device)

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                np.zeros(num_hiddens, ctx=device))

    W_xz, W_hz, b_z = three()  # 更新门参数
    W_xr, W_hr, b_r = three()  # 重置门参数
    W_xh, W_hh, b_h = three()  # 候选隐状态参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = np.zeros(num_outputs, ctx=device)
    # 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params

定义模型

定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

python 复制代码
def init_gru_state(batch_size, num_hiddens, device):
    return (np.zeros(shape=(batch_size, num_hiddens), ctx=device), )
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
        H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = np.dot(H, W_hq) + b_q
        outputs.append(Y)
    return np.concatenate(outputs, axis=0), (H,)

训练与预测

python 复制代码
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果如下:

复制代码
perplexity 1.1, 10510.3 tokens/sec on gpu(0)
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby

简洁实现GRU

mxnet框架中自带GRU的API,可以直接调用。GRU唯一需要的参数就是隐藏单元的数量。

接下来根据上一篇文章中定义好的train_ch8进行反向计算更新参数并进行预测即可。

python 复制代码
gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果如下:

复制代码
perplexity 1.1, 183591.3 tokens/sec on gpu(0)
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

思考

  1. 如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?

  2. 比较rnn.RNNrnn.GRU的不同实现对运行时间、困惑度和输出字符串的影响。

相关推荐
葫三生8 分钟前
如何评价《论三生原理》在科技界的地位?
人工智能·算法·机器学习·数学建模·量子计算
m0_751336391 小时前
突破性进展:超短等离子体脉冲实现单电子量子干涉,为飞行量子比特奠定基础
人工智能·深度学习·量子计算·材料科学·光子器件·光子学·无线电电子
美狐美颜sdk4 小时前
跨平台直播美颜SDK集成实录:Android/iOS如何适配贴纸功能
android·人工智能·ios·架构·音视频·美颜sdk·第三方美颜sdk
DeepSeek-大模型系统教程5 小时前
推荐 7 个本周 yyds 的 GitHub 项目。
人工智能·ai·语言模型·大模型·github·ai大模型·大模型学习
有Li5 小时前
通过具有一致性嵌入的大语言模型实现端到端乳腺癌放射治疗计划制定|文献速递-最新论文分享
论文阅读·深度学习·分类·医学生
郭庆汝5 小时前
pytorch、torchvision与python版本对应关系
人工智能·pytorch·python
IT古董5 小时前
【第二章:机器学习与神经网络概述】03.类算法理论与实践-(3)决策树分类器
神经网络·算法·机器学习
小雷FansUnion7 小时前
深入理解MCP架构:智能服务编排、上下文管理与动态路由实战
人工智能·架构·大模型·mcp
资讯分享周7 小时前
扣子空间PPT生产力升级:AI智能生成与多模态创作新时代
人工智能·powerpoint
叶子爱分享8 小时前
计算机视觉与图像处理的关系
图像处理·人工智能·计算机视觉