【人工智能基础】RNN实验

一、RNN特性

权重共享

wordi · weight + bais

持久记忆单元

wordi · weightword + baisword + hi · weighth + baish

二、公式化表达

ht = f(h~t - 1~, xt)

ht = tanh(Whhh~t - 1~ + Wxhxt)

yt = Whyht

三、RNN网络正弦波波形预测

环境准备

python 复制代码
import numpy as np
import torch
from torch import nn,optim
from matplotlib import pyplot as plt

# 时间轴采样数
num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.01

RNN类

python 复制代码
class Net(nn.Module):
    def __init__(self,):
        super(Net, self).__init__()
        self.rnn = nn.RNN(
            input_size = input_size, 
            hidden_size = hidden_size, 
            num_layers = 1,
            # 格式为[batch, seq, feature]
            batch_first = True
        )
        for p in self.rnn.parameters():
            nn.init.normal_(p,mean=0.0, std=0.001)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):
        out, hidden_prev = self.rnn(x, hidden_prev)
        # [1, seq, h] => [seq, h]
        out = out.view(-1,hidden_size)
        # [seq, h] => [seq, 1]
        out = self.linear(out)
        # [seq, 1] => [1, seq, 1], 需要和y做均方差
        out = out.unsqueeze(dim=0)
        return out, hidden_prev.clone()

正弦数据构建函数

python 复制代码
def create_image():
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps, 1)
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
    return time_steps,x, y

训练模型

python 复制代码
model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)

hidden_prev = torch.zeros(1,1, hidden_size)
for iter in range(6000):
    time_steps,x, y = create_image()
    output, hidden_prev = model(x, hidden_prev)
    hidden_prev = hidden_prev.detach()

    loss = criterion(output,y)
    model.zero_grad()
    loss.backward()
    for p in model.parameters():
        torch.nn.utils.clip_grad_norm_(p,10)
    optimizer.step()

    if iter % 1000 == 0:
        plt.plot(time_steps[:-1], x.ravel(), c = 'b')
        plt.plot(time_steps[:-1], y.ravel(), c= 'r')
        plt.plot(time_steps[:-1], output.detach().numpy().ravel(), c= 'g')
        plt.show()
        print('Iteration:{} loss {}'.format(iter, loss.item()))

可以看到第二次绘制图像的时候,输出曲线基本拟合了目标曲线

图像预测

python 复制代码
time_steps,x, y = create_image()

predictions = []
# input = x[:, 0, :]
for i in range(x.shape[1]):
    input = x[:, i, :].view(1, 1, 1)
    (pred, hiden_prev) = model(input, hidden_prev)
    input = pred
    predictions.append(pred.detach().numpy().ravel()[0])

x = x.data.numpy().ravel()

y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())

plt.scatter(time_steps[1:],predictions)
plt.show()
    

输出的预测曲线基本与目标曲线相同


p.s. 最后的实验应该是输入一个点,通过这个点来预测出整个正弦曲线,但是我尝试了很多次都失败了,只能修改成根据正弦函数的上一个点来预测下一个点

相关推荐
**梯度已爆炸**几秒前
自然语言处理入门
人工智能·自然语言处理
ctrlworks15 分钟前
楼宇自控核心功能:实时监控设备运行,快速诊断故障,赋能设备寿命延长
人工智能·ba系统厂商·楼宇自控系统厂家·ibms系统厂家·建筑管理系统厂家·能耗监测系统厂家
BFT白芙堂1 小时前
睿尔曼系列机器人——以创新驱动未来,重塑智能协作新生态(上)
人工智能·机器学习·机器人·协作机器人·复合机器人·睿尔曼机器人
aneasystone本尊1 小时前
使用 MCP 让 Claude Code 集成外部工具
人工智能
静心问道1 小时前
SEW:无监督预训练在语音识别中的性能-效率权衡
人工智能·语音识别
羊小猪~~1 小时前
【NLP入门系列五】中文文本分类案例
人工智能·深度学习·考研·机器学习·自然语言处理·分类·数据挖掘
xwz小王子1 小时前
从LLM到WM:大语言模型如何进化成具身世界模型?
人工智能·语言模型·自然语言处理
我爱一条柴ya1 小时前
【AI大模型】深入理解 Transformer 架构:自然语言处理的革命引擎
人工智能·ai·ai作画·ai编程·ai写作
静心问道1 小时前
FLAN-T5:规模化指令微调的语言模型
人工智能·语言模型·自然语言处理
李师兄说大模型1 小时前
KDD 2025 | 地理定位中的群体智能:一个多智能体大型视觉语言模型协同框架
人工智能·深度学习·机器学习·语言模型·自然语言处理·大模型·deepseek