【人工智能基础】RNN实验

一、RNN特性

权重共享

word~i~ · weight + bais

持久记忆单元

word~i~ · weight~word~ + bais~word~ + h~i~ · weight~h~ + bais~h~

二、公式化表达

h~t~ = f(h~t - 1~, x~t~)

h~t~ = tanh(W~hh~h~t - 1~ + W~xh~x~t~)

y~t~ = W~hy~h~t~

三、RNN网络正弦波波形预测

环境准备

python 复制代码
import numpy as np
import torch
from torch import nn,optim
from matplotlib import pyplot as plt

# 时间轴采样数
num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.01

RNN类

python 复制代码
class Net(nn.Module):
    def __init__(self,):
        super(Net, self).__init__()
        self.rnn = nn.RNN(
            input_size = input_size, 
            hidden_size = hidden_size, 
            num_layers = 1,
            # 格式为[batch, seq, feature]
            batch_first = True
        )
        for p in self.rnn.parameters():
            nn.init.normal_(p,mean=0.0, std=0.001)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):
        out, hidden_prev = self.rnn(x, hidden_prev)
        # [1, seq, h] => [seq, h]
        out = out.view(-1,hidden_size)
        # [seq, h] => [seq, 1]
        out = self.linear(out)
        # [seq, 1] => [1, seq, 1], 需要和y做均方差
        out = out.unsqueeze(dim=0)
        return out, hidden_prev.clone()

正弦数据构建函数

python 复制代码
def create_image():
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps, 1)
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
    return time_steps,x, y

训练模型

python 复制代码
model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)

hidden_prev = torch.zeros(1,1, hidden_size)
for iter in range(6000):
    time_steps,x, y = create_image()
    output, hidden_prev = model(x, hidden_prev)
    hidden_prev = hidden_prev.detach()

    loss = criterion(output,y)
    model.zero_grad()
    loss.backward()
    for p in model.parameters():
        torch.nn.utils.clip_grad_norm_(p,10)
    optimizer.step()

    if iter % 1000 == 0:
        plt.plot(time_steps[:-1], x.ravel(), c = 'b')
        plt.plot(time_steps[:-1], y.ravel(), c= 'r')
        plt.plot(time_steps[:-1], output.detach().numpy().ravel(), c= 'g')
        plt.show()
        print('Iteration:{} loss {}'.format(iter, loss.item()))

可以看到第二次绘制图像的时候,输出曲线基本拟合了目标曲线

图像预测

python 复制代码
time_steps,x, y = create_image()

predictions = []
# input = x[:, 0, :]
for i in range(x.shape[1]):
    input = x[:, i, :].view(1, 1, 1)
    (pred, hiden_prev) = model(input, hidden_prev)
    input = pred
    predictions.append(pred.detach().numpy().ravel()[0])

x = x.data.numpy().ravel()

y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())

plt.scatter(time_steps[1:],predictions)
plt.show()
    

输出的预测曲线基本与目标曲线相同


p.s. 最后的实验应该是输入一个点,通过这个点来预测出整个正弦曲线,但是我尝试了很多次都失败了,只能修改成根据正弦函数的上一个点来预测下一个点

相关推荐
AI极客菌1 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭1 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^1 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
Power20246662 小时前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k2 小时前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘
好奇龙猫2 小时前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法
沉下心来学鲁班2 小时前
复现LLM:带你从零认识语言模型
人工智能·语言模型
数据猎手小k2 小时前
AndroidLab:一个系统化的Android代理框架,包含操作环境和可复现的基准测试,支持大型语言模型和多模态模型。
android·人工智能·机器学习·语言模型
YRr YRr2 小时前
深度学习:循环神经网络(RNN)详解
人工智能·rnn·深度学习
sp_fyf_20242 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-11-01
人工智能·深度学习·神经网络·算法·机器学习·语言模型·数据挖掘