《动手学深度学习》-56GRN实现

复制代码
import torch
from torch import nn
from torch.nn import functional as F
import test_55RNNesay_realize
import d2l
import test_53LanguageModel
import test_55RNNdifficult_realize
batch_size,num_steps=32,35
train_iter,vocab=test_53LanguageModel.load_data_time_machine(batch_size,num_steps)
#初始化模型参数
def get_params(vocab_size,num_hiddens,device):
    num_inputs=num_outputs=vocab_size
    def normal(shape):
        return torch.randn(shape,device=device)*0.01
    def three():#参数值初始化
        return (normal((num_inputs,num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens,device=device))
    W_xz, W_hz, b_z=three()
    W_xr, W_hr, b_r = three()
    W_xh, W_hh, b_h = three()
    W_hq=normal((num_hiddens,num_outputs))
    b_q=torch.zeros(num_outputs,device=device)
    params=[W_xz,W_hz,b_z,W_xr, W_hr, b_r,W_xh, W_hh, b_h,W_hq,b_q]#参数有3种,更新门、重置门和模型参数
    for param in params:#令所有参数都可以求梯度
        param.requires_grad=True
    return params
def init_gru_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),)#单元素(x,)表示元组
def gru(inputs,state,params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q=params
    H, =state
    outputs=[]
    for X in inputs:
        Z=torch.sigmoid((X @ W_xz)+(H @ W_hz)+b_z)
        R=torch.sigmoid((X @ W_xr)+(H @ W_hr)+b_r)
        H_tilda=torch.tanh((X @ W_xh)+((R*H) @ W_hh)+b_h)
        H=Z*H+(1-Z)* H_tilda
        y=H @ W_hq+b_q
        outputs.append(y)
    return torch.cat(outputs,dim=0),(H,)
vocab_size,num_hiddens,device=len(vocab),256,d2l.try_gpu()
num_epochs,lr=500,1
model=test_55RNNdifficult_realize.RNNModelScratch(len(vocab),num_hiddens,d2l.try_gpu(),get_params,init_gru_state,gru)
test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)
#简单实现
# num_inputs=vocab_size
# gru_layer=nn.GRU(input_size=num_inputs,hidden_size=num_hiddens)
# model=test_55RNNesay_realize.RNNModel(gru_layer,len(vocab))
# model=model.to(device)
# test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)

总结

1) 读取与编码数据

  1. 读原始文本 → 清洗/切分(字符级就是每个字符;词级就是分词)。

  2. 建词表 vocab:token↔id。

  3. 把整个语料变成一个长的 id 序列 corpus = [id0, id1, ...]

训练目标是:给定 X(一串 token),预测 Y(下一位 token),也就是 Y = X 向后错一位。


2) 小批量采样:顺序 vs 随机(两种都行,但 state 处理不同)

A. 随机采样(random sampling)

做法 :从 corpus 中随机抽取许多长度为 num_steps 的片段组成 batch。不同 batch 之间通常互不连续。

优点 :样本独立、训练稳定、实现简单。
state 处理 :因为片段之间不连续,每个 batch 开始都把 state 置零 (或 None 让框架自动置零)。
适用:大多数情况下都可以,尤其是你不想纠结"跨 batch 传递状态"。

B. 顺序采样(sequential / consecutive sampling)

做法 :把 corpus 切成 batch_size 条"连续子序列",然后按时间向前滑动取 num_steps 的窗口。这样 同一个 batch 的第 i 行 在下一次迭代里和上一次迭代是连续的。

优点 :能利用跨 batch 的长程连续信息(模拟真正的连续文本流)。
state 处理 :在同一个 epoch 内,state 可以从上一个 batch 直接传到下一个 batch (这就是你说的"转到下一次时间中",完全对)。

但要做关键一步:detach(切断计算图),否则反传会把整段历史都连起来,显存炸裂、梯度爆炸更严重。


3) 定义网络层(RNN / GRU 通用)

典型语言模型结构:

  • 输入:token id →(one-hot 或 embedding)→ RNN/GRU/LSTM → 线性层 → vocab_size → softmax

  • 损失:交叉熵(CrossEntropyLoss),直接喂 logits(不手动 softmax)

你当前写法是 one-hot,所以 input_size = vocab_size;如果换 embedding,就变成 input_size = embed_dim

RNN 和 GRU 的区别:内部递推公式不同 ,但在训练流程上几乎一样;PyTorch 接口也几乎一样(LSTM 例外:state 是 (H, C))。


4) 每个小批量怎么跑(最核心的训练循环)

对每个 batch,拿到 XY

  1. 初始化/携带 state

    • 随机采样:state = zeros(或 None)

    • 顺序采样:第一次迭代 state = zeros,之后 state = detach(state) 再继续用

  2. 前向

    • 输入形状要一致:通常 (num_steps, batch_size, input_size)(你现在就是这样)

    • 得到 logits 形状 (num_steps*batch_size, vocab_size)

  3. 计算损失

    • Y reshape 成一维:(num_steps*batch_size,)

    • loss = CrossEntropyLoss(logits, Y_flat)

  4. 反向与更新

    • optimizer.zero_grad()

    • loss.backward()

    • 梯度裁剪

5)循环

相关推荐
naruto_lnq2 小时前
用户认证与授权:使用JWT保护你的API
jvm·数据库·python
m0_581124192 小时前
Python数据库操作:SQLAlchemy ORM指南
jvm·数据库·python
2401_841495642 小时前
【LeetCode刷题】二叉树的中序遍历
数据结构·python·算法·leetcode··递归·遍历
u0109272712 小时前
机器学习模型部署:将模型转化为Web API
jvm·数据库·python
盼小辉丶2 小时前
PyTorch实战(26)——PyTorch分布式训练
pytorch·分布式·深度学习·分布式训练
2401_838472512 小时前
使用Pandas进行数据分析:从数据清洗到可视化
jvm·数据库·python
郝学胜-神的一滴2 小时前
特征选择利器:深入理解SelectKBest与单变量特征选择
人工智能·python·程序人生·机器学习·数据分析·scikit-learn·sklearn
鹿衔`2 小时前
Apache Spark 任务资源配置与优先级指南
python·spark
Allen_LVyingbo2 小时前
医疗大模型预训练:从硬件选型到合规落地实战(2025总结版)
开发语言·git·python·github·知识图谱·健康医疗