【Datawhale学习笔记】动手学RNN及LSTM

从零实现一个 RNN

RNN 公式简化

为了与后续的代码实现保持一致,此处采用一个不含偏置项(bias)的简化版 RNN,核心计算公式如下:

h t = tanh ⁡ ( U x t + W h t − 1 ) h_t = \tanh(U x_t + W h_{t-1}) ht=tanh(Uxt+Wht−1)

其中, h t h_t ht 是当前时刻的隐藏状态, x t x_t xt 是当前输入, h t − 1 h_{t-1} ht−1 是上一时刻的隐藏状态, U U U 和 W W W 是共享的权重矩阵。

数据准备

在实现 RNN 的计算过程之前,首先需要准备输入数据。我们可以先定义一个简单的词表,并为句子"播放周杰伦的《稻香》"中的每个词生成一个随机的词向量,将它们组合成形状为 (1, 4, 128) 的张量,作为 RNN 模型的输入;同时也设置了一些基本参数(例如将隐藏节点数设为 3,即 H=3,以便和前文的 RNN 结构图对应,实际应用中一般会远大于 3),并通过 prepare_inputs 函数将这一数据准备过程封装起来。具体代码如下:

python 复制代码
import numpy as np

# (B, T, E, H) 分别表示 批次/序列长度/输入维度/隐藏维度
B, E, H = 1, 128, 3

def prepare_inputs():
    """
    使用 NumPy 准备输入数据
    使用示例句子: "播放 周杰伦 的 《稻香》"
    构造最小词表和随机(可复现)词向量, 生成形状为 (B, T, E) 的输入张量。
    """
    np.random.seed(42)
    vocab = {"播放": 0, "周杰伦": 1, "的": 2, "《稻香》": 3}
    tokens = ["播放", "周杰伦", "的", "《稻香》"]
    ids = [vocab[t] for t in tokens]

    # 词向量表: (V, E)
    V = len(vocab)
    emb_table = np.random.randn(V, E).astype(np.float32)

    # 取出序列词向量并加上 batch 维度: (B, T, E)
    x_np = emb_table[ids][None]
    return tokens, x_np

基于 NumPy 实现 RNN

python 复制代码
def manual_rnn_numpy(x_np, U_np, W_np):
    B_local, T_local, _ = x_np.shape
    # 初始化 h_0 为零向量
    h_prev = np.zeros((B_local, H), dtype=np.float32)
    
    steps = []
    # 按时间步循环
    for t in range(T_local):
        x_t = x_np[:, t, :]
        # 核心公式实现
        h_t = np.tanh(x_t @ U_np + h_prev @ W_np)
        steps.append(h_t)
        h_prev = h_t # 更新状态
        
    return np.stack(steps, axis=1), h_prev

PyTorch 的 nn.RNN 实现

python 复制代码
def pytorch_rnn_forward(x, U, W):
    rnn = nn.RNN(
        input_size=E,
        hidden_size=H,
        num_layers=1,
        nonlinearity='tanh',
        bias=False,
        batch_first=True,
        bidirectional=False,
    )
    with torch.no_grad():
        # PyTorch 内部存放的是转置后的权重
        rnn.weight_ih_l0.copy_(U.T)
        rnn.weight_hh_l0.copy_(W.T)
    y, h_n = rnn(x)
    return y, h_n.squeeze(0)

参数解析

  • input_size( E E E): 输入特征 x t x_t xt 的维度。在 NLP 中,这通常是词嵌入的维度 embedding_dim。
  • hidden_size( H H H): 隐藏状态 h t h_t ht 的维度。这代表了 RNN "记忆"的容量,也是其隐藏层的节点数。
  • num_layers: RNN 的层数。默认是1。如果大于1,会构成一个"堆叠 RNN",即前一层RNN在所有时间步的输出,会作为后一层 RNN 的输入。
  • bias: 是否使用偏置项。默认为 True。如果为真,则公式会变为 h t = tanh ⁡ ( U x t + b i h + W h t − 1 + b h h ) h_t = \tanh(U x_t + b_{ih} + W h_{t-1} + b_{hh}) ht=tanh(Uxt+bih+Wht−1+bhh)。在示例中设为 False 以便与手写版本对齐。
  • batch_first: 一个非常重要的维度顺序参数。默认为 False,此时输入张量的形状应为 (T, B, E)。在代码中设为 True,使得输入形状为更符合直觉的 (B, T, E),其中 B是批次大小,T是序列长度。
  • bidirectional: 是否构建一个双向RNN。默认为 False。双向RNN能同时考虑过去和未来的上下文。

数值对齐验证

python 复制代码
# 将NumPy结果转回PyTorch张量
out_manual = torch.from_numpy(out_manual_np)

# 使用 allclose 进行浮点数精度下的严格比较
print("逐步输出一致:", torch.allclose(out_manual, out_torch, atol=1e-6))
# 输出: True

完整代码

python 复制代码
import torch
import torch.nn as nn
import numpy as np

# 约定: (B, T, E, H) 分别表示 批次/序列长度/输入维度/隐藏维度
B, E, H = 1, 128, 3


def prepare_inputs():
    """
    使用 NumPy 准备输入数据
    使用示例句子: "播放 周杰伦 的 《稻香》"
    构造最小词表和随机(可复现)词向量, 生成形状为 (B, T, E) 的输入张量。
    """
    np.random.seed(42)
    vocab = {"播放": 0, "周杰伦": 1, "的": 2, "《稻香》": 3}
    tokens = ["播放", "周杰伦", "的", "《稻香》"]
    ids = [vocab[t] for t in tokens]

    # 词向量表: (V, E)
    V = len(vocab)
    emb_table = np.random.randn(V, E).astype(np.float32)

    # 取出序列词向量并加上 batch 维度: (B, T, E)
    x_np = emb_table[ids][None]
    return tokens, x_np


def manual_rnn_numpy(x_np, U_np, W_np):
    """
    使用 NumPy 手动实现 RNN(无偏置): h_t = tanh(U x_t + W h_{t-1})
    
    Args:
        x_np: (B, T, E)
        U_np: (E, H)
        W_np: (H, H)
    Returns:
        outputs: (B, T, H)
        final_h: (B, H)
    """
    B_local, T_local, _ = x_np.shape
    h_prev = np.zeros((B_local, H), dtype=np.float32)
    steps = []
    for t in range(T_local):
        x_t = x_np[:, t, :]
        h_t = np.tanh(x_t @ U_np + h_prev @ W_np)
        steps.append(h_t)
        h_prev = h_t
    outputs = np.stack(steps, axis=1)
    return outputs, h_prev


def pytorch_rnn_forward(x, U, W):
    """
    使用api nn.RNN (tanh, bias=False)。
    Returns:
        outputs: (B, T, H)
        final_h: (B, H)
    """
    rnn = nn.RNN(
        input_size=E,
        hidden_size=H,
        num_layers=1,
        nonlinearity='tanh',
        bias=False,
        batch_first=True,
        bidirectional=False,
    )
    with torch.no_grad():
        # PyTorch 内部存放的是转置后的权重
        rnn.weight_ih_l0.copy_(U.T)
        rnn.weight_hh_l0.copy_(W.T)
    y, h_n = rnn(x)
    return y, h_n.squeeze(0)


def main():
    _, x_np = prepare_inputs()

    # PyTorch 张量,用于 nn.RNN 模块
    x = torch.from_numpy(x_np).float()
    
    # 使用可学习参数 U, W(无偏置)
    torch.manual_seed(7)
    U = torch.randn(E, H)
    W = torch.randn(H, H)

    # --- 手写 RNN (使用 NumPy) ---
    U_np = U.detach().numpy()
    W_np = W.detach().numpy()

    print("--- 手写 RNN (NumPy) ---")
    out_manual_np, hT_manual_np = manual_rnn_numpy(x_np, U_np, W_np)
    print("输入形状:", x_np.shape)
    print("手写输出形状:", out_manual_np.shape)
    print("手写最终隐藏形状:", hT_manual_np.shape)

    print("\n--- PyTorch nn.RNN ---")
    out_torch, hT_torch = pytorch_rnn_forward(x, U, W)
    print("模块输出形状:", out_torch.shape)
    print("模块最终隐藏形状:", hT_torch.shape)

    print("\n--- 对齐验证 ---")
    # 将 NumPy 结果转回 PyTorch 张量以进行比较
    out_manual = torch.from_numpy(out_manual_np)
    hT_manual = torch.from_numpy(hT_manual_np)

    print("逐步输出一致:", torch.allclose(out_manual, out_torch, atol=1e-6))
    print("最终隐藏一致:", torch.allclose(hT_manual, hT_torch, atol=1e-6))
    print("最后一步输出等于最终隐藏:", torch.allclose(out_torch[:, -1, :], hT_torch, atol=1e-6))


if __name__ == "__main__":
    main()

从零实现一个 LSTM

公式回顾

公式回顾

这里我们同样实现一个不含偏置项的简化版 LSTM,其计算公式如下:

遗忘门: f t = σ ( U f x t + W f h t − 1 ) f_t = \sigma(U_f x_t + W_f h_{t-1}) ft=σ(Ufxt+Wfht−1)

输入门: i t = σ ( U i x t + W i h t − 1 ) i_t = \sigma(U_i x_t + W_i h_{t-1}) it=σ(Uixt+Wiht−1)

候选记忆: c ~ t = tanh ⁡ ( U c x t + W c h t − 1 ) \tilde{c}t = \tanh(U_c x_t + W_c h{t-1}) c~t=tanh(Ucxt+Wcht−1)

细胞状态更新: c t = f t ⊙ c t − 1 + i t ⊙ c ~ t c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t ct=ft⊙ct−1+it⊙c~t

输出门: o t = σ ( U o x t + W o h t − 1 ) o_t = \sigma(U_o x_t + W_o h_{t-1}) ot=σ(Uoxt+Woht−1)

隐藏状态更新: h t = o t ⊙ tanh ⁡ ( c t ) h_t = o_t \odot \tanh(c_t) ht=ot⊙tanh(ct)

基于 NumPy 实现 LSTM

python 复制代码
def manual_lstm_numpy(x_np, weights):
    U_f, W_f, U_i, W_i, U_c, W_c, U_o, W_o = weights
    B_local, T_local, _ = x_np.shape
    h_prev = np.zeros((B_local, H), dtype=np.float32)
    c_prev = np.zeros((B_local, H), dtype=np.float32)
    
    steps = []
    # 按时间步循环
    for t in range(T_local):
        x_t = x_np[:, t, :]
        
        # 1. 遗忘门
        f_t = sigmoid(x_t @ U_f + h_prev @ W_f)
        
        # 2. 输入门与候选记忆
        i_t = sigmoid(x_t @ U_i + h_prev @ W_i)
        c_tilde_t = np.tanh(x_t @ U_c + h_prev @ W_c)
        
        # 3. 更新细胞状态
        c_t = f_t * c_prev + i_t * c_tilde_t
        
        # 4. 输出门与隐藏状态
        o_t = sigmoid(x_t @ U_o + h_prev @ W_o)
        h_t = o_t * np.tanh(c_t)
        
        steps.append(h_t)
        h_prev, c_prev = h_t, c_t
        
    outputs = np.stack(steps, axis=1)
    return outputs, h_prev, c_prev

LSTM 的工作流程:

(1)初始化: h_prev 和 c_prev 分别被初始化为零向量,作为处理序列开始前的"短期记忆"和"长期记忆"。

(2)逐帧处理: for 循环遍历序列中的每一个时间步。

(3)核心计算: 循环内部的计算严格遵循了 LSTM 的四个步骤:

  • 计算遗忘门 f_t,决定要从 c_prev 中忘记多少信息。
  • 接着计算输入门 i_t 和候选记忆 c_tilde_t,准备要写入的新信息。
  • 然后,通过 c_t = f_t * c_prev + i_t * c_tilde_t 更新细胞状态,实现了信息的遗忘和记忆。
  • 最后计算输出门 o_t 并生成新的隐藏状态 h_t。

(4)状态更新: h_prev, c_prev = h_t, c_t 将当前计算出的状态传递给下一个时间步,完成"循环"过程。

通过这个实现,可以直观地看到 LSTM 是如何通过门控机制,在每个时间步对信息流进行控制的。

完整代码

python 复制代码
import numpy as np

B, E, H = 1, 128, 3

def prepare_inputs():
    """
    使用 NumPy 准备输入数据
    使用示例句子: "播放 周杰伦 的 《稻香》"
    构造最小词表和随机(可复现)词向量, 生成形状为 (B, T, E) 的输入张量。
    """
    np.random.seed(42)
    vocab = {"播放": 0, "周杰伦": 1, "的": 2, "《稻香》": 3}
    tokens = ["播放", "周杰伦", "的", "《稻香》"]
    ids = [vocab[t] for t in tokens]

    # 词向量表: (V, E)
    V = len(vocab)
    emb_table = np.random.randn(V, E).astype(np.float32)

    # 取出序列词向量并加上 batch 维度: (B, T, E)
    x_np = emb_table[ids][None]
    return tokens, x_np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def manual_lstm_numpy(x_np, weights):
    """
    使用 NumPy 手动实现 LSTM (无偏置)
    
    Args:
        x_np: (B, T, E)
        weights: 包含 U_f, W_f, U_i, W_i, U_c, W_c, U_o, W_o 的元组
    Returns:
        outputs: (B, T, H)
        final_h: (B, H)
        final_c: (B, H)
    """
    U_f, W_f, U_i, W_i, U_c, W_c, U_o, W_o = weights
    B_local, T_local, _ = x_np.shape
    h_prev = np.zeros((B_local, H), dtype=np.float32)
    c_prev = np.zeros((B_local, H), dtype=np.float32)
    
    steps = []
    # 按时间步循环
    for t in range(T_local):
        x_t = x_np[:, t, :]
        
        # 1. 遗忘门
        f_t = sigmoid(x_t @ U_f + h_prev @ W_f)
        
        # 2. 输入门与候选记忆
        i_t = sigmoid(x_t @ U_i + h_prev @ W_i)
        c_tilde_t = np.tanh(x_t @ U_c + h_prev @ W_c)
        
        # 3. 更新细胞状态
        c_t = f_t * c_prev + i_t * c_tilde_t
        
        # 4. 输出门与隐藏状态
        o_t = sigmoid(x_t @ U_o + h_prev @ W_o)
        h_t = o_t * np.tanh(c_t)
        
        steps.append(h_t)
        h_prev, c_prev = h_t, c_t
        
    outputs = np.stack(steps, axis=1)
    return outputs, h_prev, c_prev


def main():
    _, x_np = prepare_inputs()
    
    # 初始化8个权重矩阵
    np.random.seed(7)
    U_f, W_f = np.random.randn(E, H).astype(np.float32), np.random.randn(H, H).astype(np.float32)
    U_i, W_i = np.random.randn(E, H).astype(np.float32), np.random.randn(H, H).astype(np.float32)
    U_c, W_c = np.random.randn(E, H).astype(np.float32), np.random.randn(H, H).astype(np.float32)
    U_o, W_o = np.random.randn(E, H).astype(np.float32), np.random.randn(H, H).astype(np.float32)
    
    weights_np = (U_f, W_f, U_i, W_i, U_c, W_c, U_o, W_o)

    # --- 手写 LSTM (使用 NumPy) ---
    print("--- 手写 LSTM (NumPy) ---")
    out_manual_np, hT_manual_np, cT_manual_np = manual_lstm_numpy(x_np, weights_np)
    print("输入形状:", x_np.shape)
    print("手写输出形状:", out_manual_np.shape)
    print("手写最终隐藏形状:", hT_manual_np.shape)
    print("手写最终细胞形状:", cT_manual_np.shape)


if __name__ == "__main__":
    main()

参考资料

https://github.com/datawhalechina/base-llm/blob/main/docs/chapter3/08_RNN.md
https://github.com/datawhalechina/base-llm/blob/main/docs/chapter3/09_LSTM%26GRU.md

相关推荐
风之子npu2 小时前
CPU基础知识(1)
笔记
JeffDingAI2 小时前
【Datawhale学习笔记】预训练模型实战
笔记·学习
来生硬件工程师2 小时前
【PCB设计笔记】PCB布局时,如何快速互换器件位置?(Altium Designer 25)
笔记
ljt27249606612 小时前
Flutter笔记--ValueNotifier
笔记·flutter
GISer_Jing2 小时前
AI Coding学习——dw|ali(持续更新)
人工智能·学习·prompt·aigc
振华说技能3 小时前
MasterCAM车铣复合都学哪些内容!
学习
世人万千丶3 小时前
鸿蒙跨端框架 Flutter 学习 Day 4:程序生存法则——异常捕获与异步错误处理的熔断艺术
学习·flutter·华为·harmonyos·鸿蒙
阿豪只会阿巴3 小时前
项目心得——发布者和订阅者问题解决思路
linux·开发语言·笔记·python·ubuntu·ros2
军军君013 小时前
Three.js基础功能学习十二:常量与核心
前端·javascript·学习·3d·threejs·three·三维