RNN预测正弦时间点

复制代码
import torch.nn as nn
import torch
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
# net = nn.RNN(100,10) #100个单词,每个单词10个维度
# print(net._parameters.keys())
#序列时间点预测

num_time_steps =50
input_size =1
hidden_size =16
output_size = 1
lr=0.01
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,  #[b,seq,feature]   batch_first=False [seq,b,feature] ,
        )
        self.linear = nn.Linear(hidden_size,output_size)

    def forward(self,x,hidden_prev):
        # hidden_prev=h0 表示最后一个Ht的输出,out是表示[h0,h1,h2,h3....]每一个时间t的输出
        out,hidden_prev = self.rnn(x,hidden_prev)
        #[1,seq,h] => [seq,h]
        out = out.view(-1,hidden_size)
        out = self.linear(out) #[seq,h] => [seq,1]
        out = out.unsqueeze(dim=0)  #=>[1,seq,1]

        return out,hidden_prev

model =Net()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr)

hidden_prev = torch.zeros(1,1,hidden_size) #[b,1,10]

for iter in range(6000):
    start = np.random.randint(10,size=1)[0]
    time_steps = np.linspace(start,start+10,num_time_steps)
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps,1)
    x = torch.tensor(data[:-1]).float().view(1,num_time_steps-1,1)
    y = torch.tensor(data[1:]).float().view(1,num_time_steps-1,1)

    output,hidden_prev = model(x,hidden_prev)
    hidden_prev =hidden_prev.detach()

    loss = criterion(output,y)
    model.zero_grad()
    loss.backward()
    optimizer.step()


    if iter%100 == 0:
        print("Iteration:{} loss{}".format(iter,loss.item()))

predictions = []
input = x[:,0,:]
for _ in range(x.shape[1]):
    input = input.view(1,1,1)
    (pred,hidden_prev) = model(input,hidden_prev)
    input = pred
    predictions.append(pred.detach().numpy().ravel()[0])

x= x.data.numpy().ravel()
y = y.data.numpy()
plt.scatter(time_steps[:-1],x.ravel(),s=90)
plt.plot(time_steps[:-1],predictions)

plt.scatter(time_steps[1:],predictions)
plt.show()
相关推荐
深圳市青牛科技实业有限公司 小芋圆8 小时前
30V N 沟道 MOSFET SP30N06NK 全面解析:参数、特性与应用场景
人工智能·单片机·嵌入式硬件·无人机·高频dc-dc谐振变换器·笔记本电脑开合检测
leafff1239 小时前
AI数据库研究:RAG 架构运行算力需求?
数据库·人工智能·语言模型·自然语言处理·架构
陈辛chenxin9 小时前
【大数据技术01】数据科学的基础理论
大数据·人工智能·python·深度学习·机器学习·数据挖掘·数据分析
极客BIM工作室9 小时前
扩散模型核心机制解析:U-Net调用逻辑、反向传播时机与步骤对称性
人工智能·深度学习·机器学习
从零开始的奋豆9 小时前
计算机视觉(三):特征检测与光流法
人工智能·计算机视觉
一只小风华~9 小时前
HarmonyOS:相对布局(RelativeContainer)
深度学习·华为·harmonyos·鸿蒙
IT_陈寒9 小时前
JavaScript 性能优化实战:我通过这7个技巧将页面加载速度提升了65%
前端·人工智能·后端
骄傲的心别枯萎9 小时前
RV1126 NO.47:RV1126+OPENCV对视频流进行视频腐蚀操作
人工智能·opencv·计算机视觉·音视频·rv1126
骄傲的心别枯萎9 小时前
RV1126 NO.48:RV1126+OPENCV在视频中添加时间戳
人工智能·opencv·计算机视觉·音视频·视频编解码·rv1126
沉迷单车的追风少年9 小时前
Diffusion Models与视频超分(3): 解读当前最快和最强的开源模型FlashVSR
人工智能·深度学习·计算机视觉·aigc·音视频·视频生成·视频超分