import torch.nn as nn
import torch
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
# net = nn.RNN(100,10) #100个单词,每个单词10个维度
# print(net._parameters.keys())
#序列时间点预测
num_time_steps =50
input_size =1
hidden_size =16
output_size = 1
lr=0.01
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.rnn = nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True, #[b,seq,feature] batch_first=False [seq,b,feature] ,
)
self.linear = nn.Linear(hidden_size,output_size)
def forward(self,x,hidden_prev):
# hidden_prev=h0 表示最后一个Ht的输出,out是表示[h0,h1,h2,h3....]每一个时间t的输出
out,hidden_prev = self.rnn(x,hidden_prev)
#[1,seq,h] => [seq,h]
out = out.view(-1,hidden_size)
out = self.linear(out) #[seq,h] => [seq,1]
out = out.unsqueeze(dim=0) #=>[1,seq,1]
return out,hidden_prev
model =Net()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr)
hidden_prev = torch.zeros(1,1,hidden_size) #[b,1,10]
for iter in range(6000):
start = np.random.randint(10,size=1)[0]
time_steps = np.linspace(start,start+10,num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps,1)
x = torch.tensor(data[:-1]).float().view(1,num_time_steps-1,1)
y = torch.tensor(data[1:]).float().view(1,num_time_steps-1,1)
output,hidden_prev = model(x,hidden_prev)
hidden_prev =hidden_prev.detach()
loss = criterion(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter%100 == 0:
print("Iteration:{} loss{}".format(iter,loss.item()))
predictions = []
input = x[:,0,:]
for _ in range(x.shape[1]):
input = input.view(1,1,1)
(pred,hidden_prev) = model(input,hidden_prev)
input = pred
predictions.append(pred.detach().numpy().ravel()[0])
x= x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1],x.ravel(),s=90)
plt.plot(time_steps[:-1],predictions)
plt.scatter(time_steps[1:],predictions)
plt.show()
RNN预测正弦时间点
月疯2024-03-10 1:08
相关推荐
千天夜21 分钟前
激活函数解析:神经网络背后的“驱动力”大数据面试宝典22 分钟前
用AI来写SQL:让ChatGPT成为你的数据库助手封步宇AIGC27 分钟前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据m0_5236742129 分钟前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路HappyAcmen39 分钟前
IDEA部署AI代写插件噜噜噜噜鲁先森1 小时前
看懂本文,入门神经网络Neural NetworkInheritGuo2 小时前
It’s All About Your Sketch: Democratising Sketch Control in Diffusion Modelsweixin_307779132 小时前
证明存在常数c, C > 0,使得在一系列特定条件下,某个特定投资时刻出现的概率与天数的对数成反比封步宇AIGC2 小时前
量化交易系统开发-实时行情自动化交易-3.4.1.6.A股宏观经济数据Jack黄从零学c++2 小时前
opencv(c++)图像的灰度转换