论文辅助笔记:T2VEC一个疑虑:stackingGRUCell和GRU的区别在哪里?

1 stackingGRUCell

python 复制代码
class StackingGRUCell(nn.Module):
    """
    Multi-layer CRU Cell
    """
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(StackingGRUCell, self).__init__()
        self.num_layers = num_layers
        self.grus = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)

        self.grus.append(nn.GRUCell(input_size, hidden_size))
        for i in range(1, num_layers):
            self.grus.append(nn.GRUCell(hidden_size, hidden_size))
python 复制代码
    def forward(self, input, h0):
        """
        Input:
        input (batch, input_size): input tensor
        h0 (num_layers, batch, hidden_size): initial hidden state
        ---
        Output:
        output (batch, hidden_size): the final layer output tensor
        hn (num_layers, batch, hidden_size): the hidden state of each layer
        """
        hn = []
        output = input
        for i, gru in enumerate(self.grus):
            hn_i = gru(output, h0[i])
            #在每一次循环中,输入output会经过一个GRU单元并更新隐藏状态

            hn.append(hn_i)
            if i != self.num_layers - 1:
                output = self.dropout(hn_i)
            else:
                output = hn_i
            #如果不是最后一层,输出会经过一个dropout层。

        hn = torch.stack(hn)
        #将hn列表转变为一个张量
        return output, hn
  • nn.GRU中,hn表示每层的最后一个时间步的隐藏状态。这意味着,对于一个具有seq_len的输入序列,hn会包含每层的seq_len时间步中的最后一个时间步的隐藏状态。
  • StackingGRUCell中,hn是通过每层的GRUCell为给定的单一时间步计算得到的。
  • 所以,**如果seq_len为1,那么nn.GRU的hn和StackingGRUCell的hn应该是相同的?**output更应是如此

2 作为对比的普通GRU

啥也没有的一个普通GRU:

python 复制代码
class StackingGRU_tst(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(StackingGRU_tst, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True)

    def forward(self, input, h0):
        output, hn = self.gru(input, h0)
        return output, hn
python 复制代码
input_size = 5
hidden_size = 10
num_layers = 3
dropout = 0.1
batch_size = 7

3 二者对比前的一些工作

3.1 创建模型

python 复制代码
gru_cell_model = StackingGRUCell(input_size, hidden_size, num_layers, dropout)
gru_cell_model
'''
StackingGRUCell(
  (grus): ModuleList(
    (0): GRUCell(5, 10)
    (1): GRUCell(10, 10)
    (2): GRUCell(10, 10)
  )
  (dropout): Dropout(p=0.1, inplace=False)
)
'''

gru_model = nn.GRU(input_size, hidden_size, num_layers, dropout=dropout)
gru_model
'''
GRU(5, 10, num_layers=3, dropout=0.1)
'''

3.2 参数复制:

python 复制代码
with torch.no_grad():
    for i in range(num_layers):
        # 对于每一层,复制权重和偏置
        getattr(gru_model, 'weight_ih_l' + str(i)).copy_(gru_cell_model.grus[i].weight_ih)
        getattr(gru_model, 'weight_hh_l' + str(i)).copy_(gru_cell_model.grus[i].weight_hh)
        getattr(gru_model, 'bias_ih_l' + str(i)).copy_(gru_cell_model.grus[i].bias_ih)
        getattr(gru_model, 'bias_hh_l' + str(i)).copy_(gru_cell_model.grus[i].bias_hh)

3.3 设置输入和相同的初始hidden state

python 复制代码
input_data = torch.randn(batch_size, input_size)
h0_cell = torch.randn(num_layers, batch_size, hidden_size)
h0_gru = h0_cell.clone()  # 确保从相同的初始状态开始

3.4 分别生成输出结果

由于有dropping的存在,所以每次前向传播之前,都需要设置相同的随机种子

python 复制代码
torch.manual_seed(1215)
output_cell, hn_cell = gru_cell_model(input_data, h0_cell)
torch.manual_seed(1215)
output_gru, hn_gru = gru_model(input_data.unsqueeze(0), h0_gru)

4 比较结果

python 复制代码
torch.allclose(output_cell, output_gru.squeeze(0)),torch.allclose(hn_cell, hn_gru)

#(True, True)

结果是一样的的,所以似乎论文代码里的stackingGRUCell可以被GRU平替?

相关推荐
晓纪同学33 分钟前
QT-简单视觉框架代码
开发语言·qt
威桑33 分钟前
Qt SizePolicy详解:minimum 与 minimumExpanding 的区别
开发语言·qt·扩张策略
飞飞-躺着更舒服36 分钟前
【QT】实现电子飞行显示器(简易版)
开发语言·qt
明月看潮生42 分钟前
青少年编程与数学 02-004 Go语言Web编程 16课题、并发编程
开发语言·青少年编程·并发编程·编程与数学·goweb
明月看潮生1 小时前
青少年编程与数学 02-004 Go语言Web编程 17课题、静态文件
开发语言·青少年编程·编程与数学·goweb
Java Fans1 小时前
C# 中串口读取问题及解决方案
开发语言·c#
盛派网络小助手1 小时前
微信 SDK 更新 Sample,NCF 文档和模板更新,更多更新日志,欢迎解锁
开发语言·人工智能·后端·架构·c#
算法小白(真小白)1 小时前
低代码软件搭建自学第二天——构建拖拽功能
python·低代码·pyqt
唐小旭1 小时前
服务器建立-错误:pyenv环境建立后python版本不对
运维·服务器·python
007php0071 小时前
Go语言zero项目部署后启动失败问题分析与解决
java·服务器·网络·python·golang·php·ai编程