1、推导BPTT
2、推导梯度
3、分析其可能存在梯度爆炸的原因并给出解决方法
为了改善循环神经网络的长程依赖问题,一种非常好的解决方案是在公式(6.50)的基础上引入门控机制来控制信息的累积速度,包括有选择地加入新的信息,并有选择地遗忘之前累积的信息.这一类网络可以称为基于门控的循环神经网络(Gated RNN).
为了便于理解,下面是几张完整的数据处理流程图:
遗忘门:
输入门:
细胞状态:
输出门:
4、设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试.
python
import torch
import numpy as np
class RNNCell:
    """Tanh RNN cell implemented in NumPy with a hand-written backward pass.

    Each forward call records what the matching backward call needs, so
    ``backward`` must be invoked once per forward call, in reverse order
    (last time step first).  Per-step parameter gradients are appended to
    the ``*_stack`` lists and are left for the caller to sum.
    """

    def __init__(self, weight_ih, weight_hh,
                 bias_ih, bias_hh):
        """Store the parameters and create the empty bookkeeping lists.

        The forward pass multiplies by ``weight_ih.T`` / ``weight_hh.T``,
        so the weights are presumably laid out as (hidden, input) and
        (hidden, hidden) -- TODO confirm against the caller.
        """
        # Parameters: input-to-hidden / hidden-to-hidden weights and biases.
        self.weight_ih, self.weight_hh = weight_ih, weight_hh
        self.bias_ih, self.bias_hh = bias_ih, bias_hh

        # Values recorded by the forward pass (popped again by backward).
        self.x_stack = []
        self.prev_hidden_stack = []
        self.next_hidden_stack = []

        # Gradients w.r.t. the inputs, kept in forward (time) order.
        self.dx_list = []

        # Per-step gradients w.r.t. weights and biases, in backward order.
        self.dw_ih_stack = []
        self.dw_hh_stack = []
        self.db_ih_stack = []
        self.db_hh_stack = []

        # Gradient flowing back into the hidden state from the later step.
        self.prev_dh = None

    def __call__(self, x, prev_hidden):
        """Run one forward step and return the new hidden state.

        Computes ``tanh(x @ W_ih.T + prev_hidden @ W_hh.T + b_ih + b_hh)``
        and records ``x``, ``prev_hidden`` and the result for ``backward``.
        """
        self.x_stack.append(x)

        input_term = np.dot(x, self.weight_ih.T)
        recurrent_term = np.dot(prev_hidden, self.weight_hh.T)
        next_h = np.tanh(
            input_term + recurrent_term + self.bias_ih + self.bias_hh)

        self.prev_hidden_stack.append(prev_hidden)
        self.next_hidden_stack.append(next_h)

        # Every forward step resets the carried gradient, so the first
        # backward call (for the last time step) starts from zero.
        self.prev_dh = np.zeros(next_h.shape)
        return next_h

    def backward(self, dh):
        """Back-propagate through the most recent un-processed step.

        ``dh`` is the upstream gradient w.r.t. that step's hidden output.
        Returns ``self.dx_list`` (the same list object on every call),
        which holds the input gradients of all steps processed so far.
        """
        x, prev_hidden, next_hidden = (
            self.x_stack.pop(),
            self.prev_hidden_stack.pop(),
            self.next_hidden_stack.pop(),
        )

        # Total gradient reaching this hidden state: the external one plus
        # the one carried back from the following time step.
        upstream = dh + self.prev_dh
        # d tanh(z) / dz = 1 - tanh(z)^2, and next_hidden is tanh(z).
        d_tanh = upstream * (1 - next_hidden ** 2)

        # Gradient handed on to the previous time step's hidden state.
        self.prev_dh = np.dot(d_tanh, self.weight_hh)

        # Insert at the front so dx_list ends up in forward order even
        # though steps are visited last-to-first.
        grad_x = np.dot(d_tanh, self.weight_ih)
        self.dx_list.insert(0, grad_x)

        grad_w_ih = np.dot(d_tanh.T, x)
        self.dw_ih_stack.append(grad_w_ih)
        grad_w_hh = np.dot(d_tanh.T, prev_hidden)
        self.dw_hh_stack.append(grad_w_hh)

        # Both biases enter the pre-activation additively, so they share
        # the same (unreduced) gradient.
        for bias_stack in (self.db_ih_stack, self.db_hh_stack):
            bias_stack.append(d_tanh)

        return self.dx_list
if __name__ == '__main__':
    # Cross-check the hand-written NumPy BPTT in RNNCell against PyTorch
    # autograd by running both on the same weights, inputs and upstream
    # gradients, then printing every pair of results side by side.

    # Fix both RNG seeds so the run is reproducible: NumPy draws the
    # inputs below, PyTorch draws the initial RNN weights.
    np.random.seed(123)
    torch.random.manual_seed(123)

    # Print 6 decimals; suppress=True keeps small values in fixed-point
    # notation instead of scientific notation.
    np.set_printoptions(precision=6, suppress=True)

    # Reference model: 4 input features, 5 hidden units, float64 so the
    # numbers are directly comparable with NumPy's default dtype.
    rnn_PyTorch = torch.nn.RNN(4, 5).double()

    # Share the PyTorch parameters with the NumPy cell (the arrays are
    # views of the same memory, not copies).
    rnn_numpy = RNNCell(rnn_PyTorch.weight_ih_l0.detach().numpy(),
                        rnn_PyTorch.weight_hh_l0.detach().numpy(),
                        rnn_PyTorch.bias_ih_l0.detach().numpy(),
                        rnn_PyTorch.bias_hh_l0.detach().numpy())

    nums = 3  # number of time steps

    # Inputs: (time steps, batch of 3, 4 features).
    x3_numpy = np.random.random((nums, 3, 4))
    x3_tensor = torch.tensor(x3_numpy, requires_grad=True)
    # Initial hidden state: (1 layer, batch of 3, 5 hidden units).
    h3_numpy = np.random.random((1, 3, 5))
    h3_tensor = torch.tensor(h3_numpy, requires_grad=True)
    # Upstream gradient for every time step's output.  It is only fed to
    # backward(), so it does not need to track gradients itself.
    dh_numpy = np.random.random((nums, 3, 5))
    dh_tensor = torch.tensor(dh_numpy)

    # PyTorch forward pass.  The module returns (output, h_n); keep the
    # per-step output under its own name instead of overwriting the
    # initial hidden-state tensor.
    out_tensor, _ = rnn_PyTorch(x3_tensor, h3_tensor)

    # NumPy forward pass, one time step at a time.
    h_numpy_list = []
    h_numpy = h3_numpy[0]
    for i in range(nums):
        h_numpy = rnn_numpy(x3_numpy[i], h_numpy)
        h_numpy_list.append(h_numpy)

    # Backward passes: autograd in one call, NumPy last step first.
    out_tensor.backward(dh_tensor)
    for i in reversed(range(nums)):
        rnn_numpy.backward(dh_numpy[i])

    # Side-by-side comparison of hidden states and all gradients.
    print("numpy_hidden :\n", np.array(h_numpy_list))
    print("torch_hidden :\n", out_tensor.detach().numpy())
    print('=' * 20)
    print("dx_numpy :\n", np.array(rnn_numpy.dx_list))
    print("dx_torch :\n", x3_tensor.grad.numpy())
    print('=' * 20)
    # The NumPy cell stores one weight gradient per step; sum over steps.
    print("dw_ih_numpy :\n",
          np.sum(rnn_numpy.dw_ih_stack, axis=0))
    print("dw_ih_torch :\n",
          rnn_PyTorch.weight_ih_l0.grad.numpy())
    print('=' * 20)
    print("dw_hh_numpy :\n",
          np.sum(rnn_numpy.dw_hh_stack, axis=0))
    print("dw_hh_torch :\n",
          rnn_PyTorch.weight_hh_l0.grad.numpy())
    print('=' * 20)
    # Bias gradients are stored per step and per sample; sum over both.
    print("db_ih_numpy :\n",
          np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)))
    print("db_ih_torch :\n",
          rnn_PyTorch.bias_ih_l0.grad.numpy())
    print("db_hh_numpy :\n",
          np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)))
    print("db_hh_torch :\n",
          rnn_PyTorch.bias_hh_l0.grad.numpy())
总结:
这次作业算是从RNN到LSTM的一个过渡,总体难度不算太大
RNN反向传播的参数梯度推导没什么,那三个参数的推导过程都是重复的,只要改变符号就行
那个代码,参考别人的,自己写不出来。
参考链接:
NNDL 作业十 RNN-BPTT_rnn计算例题-CSDN博客