跟着问题学15——GRU网络结构详解及代码实战

1 RNN的缺陷------长期依赖的问题 (The Problem of Long-Term Dependencies)

前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的状态作为当前状态的输入。

但也有一些情况,我们需要更"长期"的上下文信息。比如预测最后一个单词"我在中国长大......我说一口流利的**。""短期"的信息显示,下一个单词很可能是一种语言的名字,但如果我们想缩小范围,我们需要更长期语境------"我在中国长大",但这个相关信息与需要它的点之间的距离完全有可能变得非常大。

不幸的是,随着这种距离的扩大,RNN无法学会连接这些信息。

从理论上讲,RNN绝对有能力处理这种"长期依赖性"。人们可以为他们精心选择参数,以解决这种形式的问题。遗憾的是,在实践中,RNN似乎无法学习它们。

幸运的是,GRU也没有这个问题!

2、GRU

什么是GRU

GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。

GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。

用论文中的话说,相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

2.1总体结构框架

前面我们讲到,神经网络的各种结构都是为了挖掘变换数据特征的,所以下面我们也将结合数据特征的维度来对比介绍一下RNN&&LSTM的网络结构。

多层感知机(线性连接层)结构

从特征角度考虑:

输入特征:是n*1的单维向量(这也是为什么卷积神经网络在linear层前要把所有特征层展平),

隐藏层:然后根据隐藏层神经元的数量m将前层输入的特征用m*1的单维向量进行表示(对特征进行了提取变换,隐藏层的数据特征),单个隐藏层的神经元数量就代表网络参数,可以设置多个隐藏层;

输出特征:最终根据输出层的神经元数量y输出y*1的单维向量。

卷积神经网络结构

从特征角度考虑:

输入特征:是(batch)*channel*width*height的张量,

卷积层(等):然后根据输入通道channel的数量c_in和输出通道channel的数量c_out会有c_out*c_in*k*k个卷积核将前层输入的特征进行卷积(对特征进行了提取变换,k为卷积核尺寸),卷积核的大小和数量c_out*c_in*k*k就代表网络参数,可以设置多个卷积层;每一个channel都代表提取某方面的一种特征,该特征用width*height的二维张量表示,不同特征层之间是相互独立的(可以进行融合)。

输出特征:根据场景的需要设置后面的输出,可以是多分类的单维向量等等。

循环神经网络RNN系列结构

从特征角度考虑:

输入特征:是(batch)*T_seq*feature_size的张量(T_seq代表序列长度,注意不是batch_size).

我们来详细对比一下卷积神经网络的输入特征,

(batch)*T_seq*feature_size

(batch)*channel*width*height,

逐个进行分析,RNN系列的基础输入特征表示是feature_size*1的单维向量,比如一个单词的词向量,比如一个股票价格的影响因素向量,而CNN系列的基础输入特征是width*height的二维张量;

再来看一下序列T_seq和通道channel,RNN系列的序列T_seq是指一个连续的输入,比如一句话,一周的股票信息,而且这个序列是有时间先后顺序且互相关联的,而CNN系列的通道channel则是指不同角度的特征,比如彩色图像的RGB三色通道,过程中每个通道代表提取了每个方面的特征,不同通道之间是没有强相关性的,不过也可以进行融合。

最后就是batch,两者都有,在RNN系列,batch就是有多个句子,在CNN系列,就是有多张图片(每个图片可以有多个通道)

隐藏层:明确了输入特征之后,我们再来看看隐藏层代表着什么。隐藏层有T_seq个隐状态H_t(和输入序列长度相同),每个隐状态H_t类似于一个channel,对应着T_seq中的t时刻的输入特征;而每个隐状态H_t是用hidden_size*1的单维向量表示的,所以一个隐含层是T_seq*hidden_size的张量;对应时刻t的输入特征由feature_size*1变为hidden_size*1的向量。如图中所示,同一个隐含层不同时刻的参数W_ih和W_hh是共享的;隐藏层可以有num_layers个(图中只有1个)

以t时刻具体阐述一下:

X_t是t时刻的输入,是一个feature_size*1的向量

W_ih是输入层到隐藏层的权重矩阵

H_t是t时刻的隐藏层的值,是一个hidden_size*1的向量

W_hh是上一时刻的隐藏层的值传入到下一时刻的隐藏层时的权重矩阵

Ot是t时刻RNN网络的输出

从上右图中可以看出这个RNN网络在t时刻接受了输入Xt之后,隐藏层的值是St,输出的值是Ot。但是从结构图中我们可以发现St并不单单只是由Xt决定,还与t-1时刻的隐藏层的值St-1有关。

2.2 GRU的输入输出结构

GRU的输入输出结构与普通的RNN是一样的。有一个当前的输入xt,和上一个节点传递下来的隐状态(hidden state)ht-1 ,这个隐状态包含了之前节点的相关信息。结合xt和 ht-1,GRU会得到当前隐藏节点的输出yt 和传递给下一个节点的隐状态 ht。

图 GRU的输入输出结构

那么,GRU到底有什么特别之处呢?下面来对它的内部结构进行分析!

2.3 GRU的内部结构

不同于LSTM有3个门控,GRU仅有2个门控,

第一个是"重置门"(reset gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,用于将上一时刻隐状态ht-1重置为ht-1',即ht-1'=ht-1*r。

再将ht-1'与输入xt进行拼接,再通过一个tanh激活函数来将数据放缩到**-1~1**的范围内。即得到如下图2-3所示的h'。

第一个是"更新门"(update gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,

最终的隐状态ht的更新表达式即为:

再次强调一下,门控信号(这里的z)的范围为0~1。门控信号越接近1,代表"记忆"下来的数据越多;而越接近0则代表"遗忘"的越多。

2.4 小结

GRU很聪明的一点就在于,使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控) 。与LSTM相比,GRU内部少了一个"门控",参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力时间成本,因而很多时候我们也就会选择更加"实用"的GRU。

3代码

python 复制代码
import torch
import torch.nn as nn


def my_gru(input,initial_states,w_ih,w_hh,b_ih,b_hh):
    h_prev=initial_states
    batch_size,T_seq,feature_size=input.shape
    hidden_size=w_ih.shape[0]//3

    batch_w_ih=w_ih.unsqueeze(0).tile(batch_size,1,1)
    batch_w_hh=w_hh.unsqueeze(0).tile(batch_size,1,1)

    output=torch.zeros(batch_size,T_seq,hidden_size)

    for t in range(T_seq):
        x=input[:,t,:]
        w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1))
        w_times_x=w_times_x.squeeze(-1)

       # print(batch_w_hh.shape,h_prev.shape)
        # 计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m)
        # 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,
        # 对于剩下的则不做要求,输出维度 (b,h,m)
        # batch_w_hh=batch_size*(3*hidden_size)*hidden_size
        # h_prev=batch_size*hidden_size*1
        # w_times_x=batch_size*hidden_size*1
        ##squeeze,在给定维度(维度值必须为1)上压缩维度,负数代表从后开始数
        w_times_h_prev=torch.bmm(batch_w_hh,h_prev.unsqueeze(-1))
        w_times_h_prev=w_times_h_prev.squeeze(-1)

        r_t=torch.sigmoid(w_times_x[:,:hidden_size]+w_times_h_prev[:,:hidden_size]+b_ih[:hidden_size]
                          +b_hh[:hidden_size])
        z_t=torch.sigmoid(w_times_x[:,hidden_size:2*hidden_size]+w_times_h_prev[:,hidden_size:2*hidden_size]
                          +b_ih[hidden_size:2*hidden_size]+b_hh[hidden_size:2*hidden_size])
        n_t=torch.tanh(w_times_x[:,2*hidden_size:3*hidden_size]+w_times_h_prev[:,2*hidden_size:3*hidden_size]
                          +b_ih[2*hidden_size:3*hidden_size]+b_hh[2*hidden_size:3*hidden_size])

        h_prev=(1-z_t)*n_t+z_t*h_prev
        output[:,t,:]=h_prev

    return output,h_prev


if __name__=="__main__":

    fc=nn.Linear(12,6)
   

    batch_size=2
    T_seq=5
    feature_size=4

    hidden_size=3
   # output_feature_size=3

    input=torch.randn(batch_size,T_seq,feature_size)
    h_prev=torch.randn(batch_size,hidden_size)

    gru_layer=nn.GRU(feature_size,hidden_size,batch_first=True)
    output,h_final=gru_layer(input,h_prev.unsqueeze(0))
    # for k,v in gru_layer.named_parameters():
    #     print(k,v.shape)
    # print(output,h_final)

    my_output, my_h_final=my_gru(input,h_prev,gru_layer.weight_ih_l0,gru_layer.weight_hh_l0,gru_layer.bias_ih_l0,gru_layer.bias_hh_l0)

    # print(my_output, my_h_final)
    # print(torch.allclose(output,my_output))

参考资料

https://zhuanlan.zhihu.com/p/32481747

https://speech.ee.ntu.edu.tw/\~tlkagk/courses/MLDS_2018/Lecture/Seq (v2).pdf

https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788\&vd_source=cf7630d31a6ad93edecfb6c5d361c659

相关推荐
不爱原创的Yoga2 分钟前
自动驾驶汽车需要哪些传感器来感知环境
人工智能·自动驾驶·汽车
Golinie8 分钟前
2025年最新深度学习环境搭建:Win11+ cuDNN + CUDA + Pytorch +深度学习环境配置保姆级教程
人工智能·pytorch·深度学习
周杰伦_Jay12 分钟前
Ollama能本地部署Llama 3等大模型的原因解析(ollama核心架构、技术特性、实际应用)
数据结构·人工智能·深度学习·架构·transformer·llama
几道之旅26 分钟前
论文阅读笔记:AI+RPA
人工智能
池央26 分钟前
GAN - 生成对抗网络:生成新的数据样本
人工智能·神经网络·生成对抗网络
golitter.37 分钟前
vscode导入模块不显示类型注解
python
马红权40 分钟前
pyautogui自动化鼠标键盘操作
前端·python
鸭鸭鸭进京赶烤44 分钟前
OpenAI秘密重塑机器人军团: 实体AGI的崛起!
人工智能·opencv·机器学习·ai·机器人·agi·机器翻译引擎
ZHOU_WUYI1 小时前
lightrag源码 : Generate chunks from document
人工智能·rag
cfjybgkmf1 小时前
Python数据类型间的转换及eval函数
开发语言·python