《动手学深度学习》-56GRN实现

复制代码
import torch
from torch import nn
from torch.nn import functional as F
import test_55RNNesay_realize
import d2l
import test_53LanguageModel
import test_55RNNdifficult_realize
batch_size,num_steps=32,35
train_iter,vocab=test_53LanguageModel.load_data_time_machine(batch_size,num_steps)
#初始化模型参数
def get_params(vocab_size,num_hiddens,device):
    num_inputs=num_outputs=vocab_size
    def normal(shape):
        return torch.randn(shape,device=device)*0.01
    def three():#参数值初始化
        return (normal((num_inputs,num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens,device=device))
    W_xz, W_hz, b_z=three()
    W_xr, W_hr, b_r = three()
    W_xh, W_hh, b_h = three()
    W_hq=normal((num_hiddens,num_outputs))
    b_q=torch.zeros(num_outputs,device=device)
    params=[W_xz,W_hz,b_z,W_xr, W_hr, b_r,W_xh, W_hh, b_h,W_hq,b_q]#参数有3种,更新门、重置门和模型参数
    for param in params:#令所有参数都可以求梯度
        param.requires_grad=True
    return params
def init_gru_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),)#单元素(x,)表示元组
def gru(inputs,state,params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q=params
    H, =state
    outputs=[]
    for X in inputs:
        Z=torch.sigmoid((X @ W_xz)+(H @ W_hz)+b_z)
        R=torch.sigmoid((X @ W_xr)+(H @ W_hr)+b_r)
        H_tilda=torch.tanh((X @ W_xh)+((R*H) @ W_hh)+b_h)
        H=Z*H+(1-Z)* H_tilda
        y=H @ W_hq+b_q
        outputs.append(y)
    return torch.cat(outputs,dim=0),(H,)
vocab_size,num_hiddens,device=len(vocab),256,d2l.try_gpu()
num_epochs,lr=500,1
model=test_55RNNdifficult_realize.RNNModelScratch(len(vocab),num_hiddens,d2l.try_gpu(),get_params,init_gru_state,gru)
test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)
#简单实现
# num_inputs=vocab_size
# gru_layer=nn.GRU(input_size=num_inputs,hidden_size=num_hiddens)
# model=test_55RNNesay_realize.RNNModel(gru_layer,len(vocab))
# model=model.to(device)
# test_55RNNdifficult_realize.train_ch8(model,train_iter,vocab,lr,num_epochs,device)

总结

1) 读取与编码数据

  1. 读原始文本 → 清洗/切分(字符级就是每个字符;词级就是分词)。

  2. 建词表 vocab:token↔id。

  3. 把整个语料变成一个长的 id 序列 corpus = [id0, id1, ...]

训练目标是:给定 X(一串 token),预测 Y(下一位 token),也就是 Y = X 向后错一位。


2) 小批量采样:顺序 vs 随机(两种都行,但 state 处理不同)

A. 随机采样(random sampling)

做法 :从 corpus 中随机抽取许多长度为 num_steps 的片段组成 batch。不同 batch 之间通常互不连续。

优点 :样本独立、训练稳定、实现简单。
state 处理 :因为片段之间不连续,每个 batch 开始都把 state 置零 (或 None 让框架自动置零)。
适用:大多数情况下都可以,尤其是你不想纠结"跨 batch 传递状态"。

B. 顺序采样(sequential / consecutive sampling)

做法 :把 corpus 切成 batch_size 条"连续子序列",然后按时间向前滑动取 num_steps 的窗口。这样 同一个 batch 的第 i 行 在下一次迭代里和上一次迭代是连续的。

优点 :能利用跨 batch 的长程连续信息(模拟真正的连续文本流)。
state 处理 :在同一个 epoch 内,state 可以从上一个 batch 直接传到下一个 batch (这就是你说的"转到下一次时间中",完全对)。

但要做关键一步:detach(切断计算图),否则反传会把整段历史都连起来,显存炸裂、梯度爆炸更严重。


3) 定义网络层(RNN / GRU 通用)

典型语言模型结构:

  • 输入:token id →(one-hot 或 embedding)→ RNN/GRU/LSTM → 线性层 → vocab_size → softmax

  • 损失:交叉熵(CrossEntropyLoss),直接喂 logits(不手动 softmax)

你当前写法是 one-hot,所以 input_size = vocab_size;如果换 embedding,就变成 input_size = embed_dim

RNN 和 GRU 的区别:内部递推公式不同 ,但在训练流程上几乎一样;PyTorch 接口也几乎一样(LSTM 例外:state 是 (H, C))。


4) 每个小批量怎么跑(最核心的训练循环)

对每个 batch,拿到 XY

  1. 初始化/携带 state

    • 随机采样:state = zeros(或 None)

    • 顺序采样:第一次迭代 state = zeros,之后 state = detach(state) 再继续用

  2. 前向

    • 输入形状要一致:通常 (num_steps, batch_size, input_size)(你现在就是这样)

    • 得到 logits 形状 (num_steps*batch_size, vocab_size)

  3. 计算损失

    • Y reshape 成一维:(num_steps*batch_size,)

    • loss = CrossEntropyLoss(logits, Y_flat)

  4. 反向与更新

    • optimizer.zero_grad()

    • loss.backward()

    • 梯度裁剪

5)循环

相关推荐
执风挽^1 小时前
Python基础编程题2
开发语言·python·算法·visual studio code
纤纡.2 小时前
PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战
人工智能·pytorch·python
kjkdd2 小时前
6.1 核心组件(Agent)
python·ai·语言模型·langchain·ai编程
小镇敲码人2 小时前
剖析CANN框架中Samples仓库:从示例到实战的AI开发指南
c++·人工智能·python·华为·acl·cann
萧鼎2 小时前
Python 包管理的“超音速”革命:全面上手 uv 工具链
开发语言·python·uv
程序员清洒2 小时前
CANN模型安全:从对抗防御到隐私保护的全栈安全实战
人工智能·深度学习·安全
User_芊芊君子2 小时前
CANN_PTO_ISA虚拟指令集全解析打造跨平台高性能计算的抽象层
人工智能·深度学习·神经网络
alvin_20053 小时前
python之OpenGL应用(二)Hello Triangle
python·opengl
铁蛋AI编程实战3 小时前
通义千问 3.5 Turbo GGUF 量化版本地部署教程:4G 显存即可运行,数据永不泄露
java·人工智能·python
HyperAI超神经3 小时前
在线教程|DeepSeek-OCR 2公式/表格解析同步改善,以低视觉token成本实现近4%的性能跃迁
开发语言·人工智能·深度学习·神经网络·机器学习·ocr·创业创新