《动手学深度学习》-56GRN实现

复制代码
import torch
from torch import nn
from torch.nn import functional as F
import test_55RNNesay_realize
import d2l
import test_53LanguageModel
import test_55RNNdifficult_realize
batch_size,num_steps=32,35
train_iter,vocab=test_53LanguageModel.load_data_time_machine(batch_size,num_steps)
#初始化模型参数
def get_params(vocab_size,num_hiddens,device):
    num_inputs=num_outputs=vocab_size
    def normal(shape):
        return torch.randn(shape,device=device)*0.01
    def three():#参数值初始化
        return (normal((num_inputs,num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens,device=device))
    W_xz, W_hz, b_z=three()
    W_xr, W_hr, b_r = three()
    W_xh, W_hh, b_h = three()
    W_hq=normal((num_hiddens,num_outputs))
    b_q=torch.zeros(num_outputs,device=device)
    params=[W_xz,W_hz,b_z,W_xr, W_hr, b_r,W_xh, W_hh, b_h,W_hq,b_q]#参数有3种,更新门、重置门和模型参数
    for param in params:#令所有参数都可以求梯度
        param.requires_grad=True
    return params
def init_gru_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),)#单元素(x,)表示元组
def gru(inputs,state,params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q=params
    H, =state
    outputs=[]
    for X in inputs:
        Z=torch.sigmoid((X @ W_xz)+(H @ W_hz)+b_z)
        R=torch.sigmoid((X @ W_xr)+(H @ W_hr)+b_r)
        H_tilda=torch.tanh((X @ W_xh)+((R*H) @ W_hh)+b_h)
        H=Z*H+(1-Z)* H_tilda
        y=H @ W_hq+b_q
        outputs.append(y)
    return torch.cat(outputs,dim=0),(H,)
vocab_size,num_hiddens,device=len(vocab),256,d2l.try_gpu()
num_epochs,lr=500,1
model=test_55RNNdifficult_realize.RNNModelScratch(len(vocab),num_hiddens,d2l.try_gpu(),get_params,init_gru_state,gru)
test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)
#简单实现
# num_inputs=vocab_size
# gru_layer=nn.GRU(input_size=num_inputs,hidden_size=num_hiddens)
# model=test_55RNNesay_realize.RNNModel(gru_layer,len(vocab))
# model=model.to(device)
# test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)

总结

1) 读取与编码数据

  1. 读原始文本 → 清洗/切分(字符级就是每个字符;词级就是分词)。

  2. 建词表 vocab:token↔id。

  3. 把整个语料变成一个长的 id 序列 corpus = [id0, id1, ...]

训练目标是:给定 X(一串 token),预测 Y(下一位 token),也就是 Y = X 向后错一位。


2) 小批量采样:顺序 vs 随机(两种都行,但 state 处理不同)

A. 随机采样(random sampling)

做法 :从 corpus 中随机抽取许多长度为 num_steps 的片段组成 batch。不同 batch 之间通常互不连续。

优点 :样本独立、训练稳定、实现简单。
state 处理 :因为片段之间不连续,每个 batch 开始都把 state 置零 (或 None 让框架自动置零)。
适用:大多数情况下都可以,尤其是你不想纠结"跨 batch 传递状态"。

B. 顺序采样(sequential / consecutive sampling)

做法 :把 corpus 切成 batch_size 条"连续子序列",然后按时间向前滑动取 num_steps 的窗口。这样 同一个 batch 的第 i 行 在下一次迭代里和上一次迭代是连续的。

优点 :能利用跨 batch 的长程连续信息(模拟真正的连续文本流)。
state 处理 :在同一个 epoch 内,state 可以从上一个 batch 直接传到下一个 batch (这就是你说的"转到下一次时间中",完全对)。

但要做关键一步:detach(切断计算图),否则反传会把整段历史都连起来,显存炸裂、梯度爆炸更严重。


3) 定义网络层(RNN / GRU 通用)

典型语言模型结构:

  • 输入:token id →(one-hot 或 embedding)→ RNN/GRU/LSTM → 线性层 → vocab_size → softmax

  • 损失:交叉熵(CrossEntropyLoss),直接喂 logits(不手动 softmax)

你当前写法是 one-hot,所以 input_size = vocab_size;如果换 embedding,就变成 input_size = embed_dim

RNN 和 GRU 的区别:内部递推公式不同 ,但在训练流程上几乎一样;PyTorch 接口也几乎一样(LSTM 例外:state 是 (H, C))。


4) 每个小批量怎么跑(最核心的训练循环)

对每个 batch,拿到 XY

  1. 初始化/携带 state

    • 随机采样:state = zeros(或 None)

    • 顺序采样:第一次迭代 state = zeros,之后 state = detach(state) 再继续用

  2. 前向

    • 输入形状要一致:通常 (num_steps, batch_size, input_size)(你现在就是这样)

    • 得到 logits 形状 (num_steps*batch_size, vocab_size)

  3. 计算损失

    • Y reshape 成一维:(num_steps*batch_size,)

    • loss = CrossEntropyLoss(logits, Y_flat)

  4. 反向与更新

    • optimizer.zero_grad()

    • loss.backward()

    • 梯度裁剪

5)循环

相关推荐
h64648564h4 分钟前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切4 分钟前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
八零后琐话23 分钟前
干货:程序员必备性能分析工具——Arthas火焰图
开发语言·python
青春不朽5122 小时前
Scrapy框架入门指南
python·scrapy
MZ_ZXD0012 小时前
springboot旅游信息管理系统-计算机毕业设计源码21675
java·c++·vue.js·spring boot·python·django·php
学电子她就能回来吗2 小时前
深度学习速成:损失函数与反向传播
人工智能·深度学习·学习·计算机视觉·github
Coder_Boy_3 小时前
TensorFlow小白科普
人工智能·深度学习·tensorflow·neo4j
全栈老石3 小时前
Python 异步生存手册:给被 JS async/await 宠坏的全栈工程师
后端·python
大模型玩家七七3 小时前
梯度累积真的省显存吗?它换走的是什么成本
java·javascript·数据库·人工智能·深度学习