目录
RNN 学习笔记
python
import torch
def rnn(inputs, state, params):
    """Run a vanilla RNN forward over a sequence, one time step at a time.

    Args:
        inputs: tensor of shape (num_steps, batch_size, input_size), where
            input_size is the vocabulary size in these notes. Each slice along
            dim 0 is one time step's (batch_size, input_size) input.
        state: initial hidden state of shape (batch_size, hidden_size), given
            either as a bare tensor or as a 1-tuple ``(H,)``. The tuple form is
            what this function returns, so its output state can be fed back in.
        params: tuple ``(W_xh, W_hh, b_h, W_hq, b_q)``.

    Returns:
        ``(outputs, (H,))``:
        - ``outputs`` concatenates every step's output along dim 0, giving
          shape (num_steps * batch_size, output_size).
        - ``H`` is the final hidden state.
        For a zero-length sequence, ``outputs`` is an empty
        (0, output_size) tensor and ``H`` is the initial state.
    """
    W_xh, W_hh, b_h, W_hq, b_q = params
    # Accept both a bare tensor and the (H,) tuple returned below, so the
    # state can be chained across calls (e.g. when processing a long sequence
    # in chunks).
    H = state[0] if isinstance(state, (tuple, list)) else state
    outputs = []
    for X in inputs:
        # H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h)
        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        # Y_t = H_t W_hq + b_q
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    if not outputs:
        # torch.cat raises on an empty list; return an empty result instead.
        return W_hq.new_empty((0, W_hq.shape[1])), (H,)
    return torch.cat(outputs, dim=0), (H,)
# Example parameter initialisation (adjust to your use case).
input_size = 10  # vocabulary size
hidden_size = 20  # number of hidden units
output_size = 5  # output dimension

# Model parameters, drawn from a standard normal distribution.
W_xh = torch.randn(input_size, hidden_size)
W_hh = torch.randn(hidden_size, hidden_size)
b_h = torch.randn(hidden_size)
W_hq = torch.randn(hidden_size, output_size)
b_q = torch.randn(output_size)
params = (W_xh, W_hh, b_h, W_hq, b_q)

# Example input of shape (time_steps, batch_size, input_size).
time_steps = 3
batch_size = 4
inputs = torch.randn(time_steps, batch_size, input_size)

# Initial hidden state: one zero row per sequence in the batch. It is derived
# from batch_size (instead of a hard-coded 4) so the two cannot drift apart.
# It is passed as a bare tensor; note that (x) without a trailing comma is
# not a tuple.
state = torch.zeros(batch_size, hidden_size)

# Run the RNN.
outputs, new_state = rnn(inputs, state, params)
print(outputs)
print(new_state)
LSTM 学习笔记
python
import torch
import torch.nn as nn
def lstm(inputs, state, params):
    """Run an LSTM forward over a sequence, one time step at a time.

    Args:
        inputs: tensor of shape (num_steps, batch_size, input_size), where
            input_size is the vocabulary size in these notes. Each slice along
            dim 0 is one time step's (batch_size, input_size) input.
        state: tuple ``(H, C)`` holding the initial hidden state and memory
            cell, each of shape (batch_size, hidden_size).
        params: tuple ``(W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
            W_xc, W_hc, b_c, W_hq, b_q)``. These are the input-to-hidden
            weight, hidden-to-hidden weight and bias for the input, forget
            and output gates and the candidate cell, followed by the output
            layer's weight and bias.

    Returns:
        ``(outputs, (H, C))``:
        - ``outputs`` concatenates every step's output along dim 0, giving
          shape (num_steps * batch_size, output_size).
        - ``(H, C)`` is the final state, in the same form as ``state`` so it
          can be passed to the next call.
        For a zero-length sequence, ``outputs`` is an empty
        (0, output_size) tensor and the initial state is returned unchanged.
    """
    (W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hq, b_q) = params
    H, C = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)  # input gate
        F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)  # forget gate
        O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)  # output gate
        C_tilda = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)  # candidate cell
        # F scales how much of the previous cell is kept; I scales how much
        # of the candidate is written in.
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    if not outputs:
        # torch.cat raises on an empty list; return an empty result instead.
        return W_hq.new_empty((0, W_hq.shape[1])), (H, C)
    return torch.cat(outputs, dim=0), (H, C)
# Example configuration (adjust to your use case).
input_size = 10  # vocabulary size
hidden_size = 20  # number of hidden units
output_size = 5  # output dimension


def _gate_params(n_in, n_hidden):
    """Create one gate's (W_x, W_h, b): random input/hidden weights, zero bias."""
    w_x = torch.randn(n_in, n_hidden)
    w_h = torch.randn(n_hidden, n_hidden)
    return w_x, w_h, torch.zeros(n_hidden)


# Gates are created in the fixed order input, forget, output, candidate, so
# the sequence of torch.randn calls (and thus the sampled values) stays the
# same as drawing each tensor one by one.
W_xi, W_hi, b_i = _gate_params(input_size, hidden_size)
W_xf, W_hf, b_f = _gate_params(input_size, hidden_size)
W_xo, W_ho, b_o = _gate_params(input_size, hidden_size)
W_xc, W_hc, b_c = _gate_params(input_size, hidden_size)
W_hq = torch.randn(hidden_size, output_size)
b_q = torch.zeros(output_size)

# Example input of shape (time_steps, batch_size, input_size).
time_steps, batch_size = 3, 4
inputs = torch.randn(time_steps, batch_size, input_size)

params = (
    W_xi, W_hi, b_i,
    W_xf, W_hf, b_f,
    W_xo, W_ho, b_o,
    W_xc, W_hc, b_c,
    W_hq, b_q,
)
# Initial (hidden state, memory cell): two independent all-zero tensors.
state = tuple(torch.zeros(batch_size, hidden_size) for _ in range(2))

# Run the LSTM.
outputs, new_state = lstm(inputs, state, params)
print(outputs)
print(new_state)