rnn 和lstm源码学习笔记

目录

rnn学习笔记

lstm学习笔记


rnn学习笔记

python 复制代码
import torch

def rnn(inputs, state, params):
    # inputs的形状: (时间步数量, 批次大小, 词表大小)
    W_xh, W_hh, b_h, W_hq, b_q = params
    H = state
    outputs = []
    # 遍历每个时间步
    for X in inputs:
        # 计算隐藏状态 H
        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        # 计算输出 Y
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    # 返回输出和新的隐藏状态
    return torch.cat(outputs, dim=0), (H,)

# 参数示例初始化(根据实际情况调整)
input_size = 10  # 词表大小
hidden_size = 20  # 隐藏层大小
output_size = 5  # 输出大小

# 初始化参数
W_xh = torch.randn(input_size, hidden_size)
W_hh = torch.randn(hidden_size, hidden_size)
b_h = torch.randn(hidden_size)
W_hq = torch.randn(hidden_size, output_size)
b_q = torch.randn(output_size)

params = (W_xh, W_hh, b_h, W_hq, b_q)
state = (torch.zeros(4,hidden_size))

# 输入示例
time_steps = 3
batch_size = 4
inputs = torch.randn(time_steps, batch_size, input_size)

# 调用RNN函数
outputs, new_state = rnn(inputs, state, params)
print(outputs)
print(new_state)

lstm学习笔记

python 复制代码
import torch
import torch.nn as nn

def lstm(inputs, state, params):
    # inputs的形状: (时间步数量, 批次大小, 词表大小)
    W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
    (H, C) = state
    outputs = []
    # 遍历每个时间步
    for X in inputs:
        I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)
        F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)
        O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)
        C_tilda = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)

        C = F * C + I * C_tilda
        H = O * torch.tanh(C)

        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

# 参数示例初始化(根据实际情况调整)
input_size = 10  # 词表大小
hidden_size = 20  # 隐藏层大小
output_size = 5  # 输出大小

# 初始化参数
W_xi = torch.randn(input_size, hidden_size)
W_hi = torch.randn(hidden_size, hidden_size)
b_i = torch.zeros(hidden_size)
W_xf = torch.randn(input_size, hidden_size)
W_hf = torch.randn(hidden_size, hidden_size)
b_f = torch.zeros(hidden_size)
W_xo = torch.randn(input_size, hidden_size)
W_ho = torch.randn(hidden_size, hidden_size)
b_o = torch.zeros(hidden_size)
W_xc = torch.randn(input_size, hidden_size)
W_hc = torch.randn(hidden_size, hidden_size)
b_c = torch.zeros(hidden_size)
W_hq = torch.randn(hidden_size, output_size)
b_q = torch.zeros(output_size)



# 输入示例
time_steps = 3
batch_size = 4
inputs = torch.randn(time_steps, batch_size, input_size)

params = (W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q)
state = (torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size))  # 初始隐藏状态和单元状态

# 调用LSTM函数
outputs, new_state = lstm(inputs, state, params)
print(outputs)
print(new_state)
相关推荐
炽烈小老头10 小时前
【每天学习一点算法 2026/03/08】相交链表
学习·算法·链表
red_redemption13 小时前
自由学习记录(130)
学习·soa·aos·ecs已成核心包·shading!=ps
双叶83613 小时前
(Python)Python爬虫入门教程:从零开始学习网页抓取(爬虫教学)(Python教学)
后端·爬虫·python·学习
天外来鹿13 小时前
Map/Set/WeakMap/WeakSet学习笔记
前端·javascript·笔记·学习
峥嵘life14 小时前
Android16 【GTS】 GtsDevicePolicyTestCases 测试存在Failed项
android·linux·学习
leixj02514 小时前
SVN学习笔记
笔记·学习·svn
毕设源码_廖学姐14 小时前
计算机毕业设计springboot古诗词学习App 基于SpringBoot的中华经典诗文数字化研习平台 SpringBoot框架下的传统诗词文化移动学习系统
spring boot·学习·课程设计
盐焗西兰花16 小时前
鸿蒙学习实战之路-Share Kit系列(7/17)-自定义分享面板操作区
linux·学习·harmonyos
香水5只用六神16 小时前
【RTOS快速入门】07_同步互斥与通信概述
单片机·嵌入式硬件·学习·操作系统·freertos·rtos·嵌入式软件
庭前云落17 小时前
从零开始的Hardhat学习 1| Hardhat 的基本使用、部署智能合约
学习·智能合约