《动手学深度学习》-55-2RNN的简单实现

复制代码
import torch
from torch import nn
from torch.nn import functional as F
import d2l
import test_53LanguageModel
import test_55RNNdifficult_realize
batch_size,num_steps=32,35
train_iter,vocab=test_53LanguageModel.load_data_time_machine(batch_size,num_steps)
#定义模型
num_hiddens=512
rnn_layer=nn.RNN(len(vocab),num_hiddens)
#使用张量初始化隐藏状态,形状(隐藏数,批量大小,隐藏单元)
state=torch.zeros((1,batch_size,num_hiddens))
# print(state.shape)
# X=torch.randn(size=(num_steps,batch_size,len(vocab)))
# Y,state_new=rnn_layer(X,state)
# print(Y.shape,state_new.shape)
class RNNModel(nn.Module):
    def __init__(self,rnn_layer,vocab_size,**kwargs):
        super(RNNModel,self).__init__(**kwargs)
        self.rnn=rnn_layer
        self.vocab_size=vocab_size
        self.num_hiddens=self.rnn.hidden_size
        if not self.rnn.bidirectional:
            self.num_derections=1
            self.linear=nn.Linear(self.num_hiddens,self.vocab_size)
        else:
            self.num_derections=2
            self.linear=nn.Linear(self.num_hiddens*2,self.vocab_size)
    def forward(self,inputs,state):
        X=F.one_hot(inputs.T.long(),self.vocab_size)#将输入变成(num_steps, batch_size, vocab_size)
        X=X.to(torch.float32)
        y,state=self.rnn(X,state)
        output=self.linear(y.reshape((-1,y.shape[-1])))#将y转成二维(num_steps*batch_size,vocab_size)
        return output,state
    def begin_state(self,device,batch_size=1):
        if not isinstance(self.rnn,nn.LSTM):
            return torch.zeros((self.num_derections*self.rnn.num_layers,batch_size,self.num_hiddens),device=device)#(层数*方向数,batch,隐藏数量)
        else:
            return (torch.zeros((self.num_derections*self.rnn.num_layers,batch_size,self.num_hiddens),device=device),
                    torch.zeros((self.num_derections * self.rnn.num_layers, batch_size, self.num_hiddens),
                                device=device)
            )
device=d2l.try_gpu()
net=RNNModel(rnn_layer,vocab_size=len(vocab))
net=net.to(device=device)
print(test_55RNNdifficult_realize.predict_ch8('time traveller',10,net,vocab,device))
复制代码
num_epochs=500
lr=1
test_55RNNdifficult_realize.train_ch8(net,train_iter,vocab,lr,num_epochs,device)
相关推荐
小陈phd1 分钟前
多模态大模型学习笔记(二十一)—— 基于 Scaling Law方法 的大模型训练算力估算与 GPU 资源配置
笔记·深度学习·学习·自然语言处理·transformer
zm-v-159304339863 分钟前
Python 气象数据处理从入门到精通:机器学习订正 + 深度学习预测完整教程
python·深度学习·机器学习
F_U_N_3 分钟前
轻量化开源知识库落地路径研究:AI赋能、多端集成及合规管理指引
人工智能·开源
丝斯20114 分钟前
AI学习笔记整理(75)——Python学习4
人工智能·笔记·学习
TImCheng06094 分钟前
科学的兴趣评估模型:如何通过低成本试错与深度体验,确定 AI 是否为长期志业?
人工智能
物联网软硬件开发-轨物科技5 分钟前
【轨物洞见】从“人工时代”迈向“视觉语音时代”:轨物科技多模态智能感知与一键顺控专家系统全解析
大数据·人工智能·科技
FindAI发现力量6 分钟前
智能耳机:AI销售场景中的数据采集新范式
人工智能
大傻^8 分钟前
Spring AI Alibaba 向量数据库集成:Milvus与Elasticsearch配置详解
数据库·人工智能·spring·elasticsearch·milvus·springai·springaialibaba
大傻^13 分钟前
Spring AI Alibaba ChatClient实战:流式输出与多轮对话管理
java·人工智能·后端·spring·springai·springaialibaba
1941s13 分钟前
Google Agent Development Kit (ADK) 指南 第四章:Agent 开发与编排
人工智能·python·langchain·agent·adk