【循环神经网络5】GRU模型实战,从零开始构建文本生成器

【循环神经网络3】门控循环单元GRU详解-CSDN博客https://blog.csdn.net/colus_SEU/article/details/152218119?spm=1001.2014.3001.5501在以上▲笔记中,我们详解了GRU模型,接下来我们通过实战来进一步理解。

1 项目概述

本项目旨在构建一个字符级的语言模型,使用PyTorch框架实现GRU(门控循环单元)网络,来学习莎士比亚戏剧文本的语言模式,并生成具有莎士比亚风格的新文本。

  • 核心任务:文本生成。

  • 数据集 :Kaggle上的莎士比亚戏剧数据集,包含超过11万行角色台词,下载地址:Shakespeare plays(直接点击download下载zip文件解压后将文件夹中的Shakespeare_data.csv文件放到项目目录指定位置即可)

  • 最终成果:一个能够生成语法基本正确、词汇风格贴近莎士比亚戏剧的文本的GRU模型。

2 项目目录

python 复制代码
 GRU_shakespeare_sonnets/
 ├── data/
 │   └── raw/
 │       └── Shakespeare_data.csv
 ├── models/
 │   └── gru_model.pth
 ├── src/
 │   ├── data_loader.py   # 数据下载、预处理和加载
 │   ├── model.py         # GRU模型定义
 │   ├── train.py         # 训练脚本
 │   └── generate.py      # 文本生成脚本
 ├── vocab.pkl            # 保存的词汇表
 └── requirements.txt     # 项目依赖

3 项目代码

python 复制代码
 # src/model.py
 import torch
 import torch.nn as nn
 ​
 class GRUModel(nn.Module):
     def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=2):
         '''
         :param vocab_size: 词汇表的大小,即模型中使用的不同单词的数量。
         :param embed_dim: 词嵌入的维度,默认为256。词嵌入是将单词转换为固定长度向量的过程,这些向量捕获单词的语义信息。
         :param hidden_dim: GRU隐藏层的维度,默认为512。这决定了模型的学习能力和复杂性。
         :param num_layers: GRU层的数量,默认为2。多层GRU可以捕获更复杂的序列模式。
         '''
         super(GRUModel, self).__init__()
         self.hidden_dim = hidden_dim
         self.num_layers = num_layers
 ​
         # 词嵌入层
         self.embedding = nn.Embedding(vocab_size, embed_dim)
         # 创建一个词嵌入层,使用nn.Embedding类。这个层将词汇表中的每个单词映射到一个embed_dim维的向量。
 ​
         # GRU层
         self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
         # 创建一个GRU层,使用nn.GRU类。
         # batch_first=True:表示输入数据的第一个维度是批量大小,这使得数据组织更加直观。
 ​
         # 全连接输出层
         self.fc = nn.Linear(hidden_dim, vocab_size)
         # 这个层将GRU的输出映射到词汇表的大小,从而预测下一个单词。hidden_dim是输入尺寸,vocab_size是输出尺寸。
 ​
 ​
     def forward(self, x, hidden):
         # x shape: (batch_size, seq_length)
         # hidden shape: (num_layers, batch_size, hidden_dim)
 ​
         # 嵌入层
         embedded = self.embedding(x)  # shape: (batch_size, seq_length, embed_dim)
 ​
         # GRU层
         # out shape: (batch_size, seq_length, hidden_dim)
         # hidden shape: (num_layers, batch_size, hidden_dim)
         out, hidden = self.gru(embedded, hidden)
 ​
         # 将GRU的输出传入全连接层
         # 我们需要将out重塑以便通过fc层
         out = out.contiguous().view(-1, self.hidden_dim)  # shape: (batch_size * seq_length, hidden_dim)
         out = self.fc(out)  # shape: (batch_size * seq_length, vocab_size)
 ​
         return out, hidden
 ​
     def init_hidden(self, batch_size, device):
         """初始化隐藏状态"""
         # 形状: (num_layers, batch_size, hidden_dim)
         weight = next(self.parameters()).data
         hidden = weight.new(self.num_layers, batch_size, self.hidden_dim).zero_().to(device)
         return hidden
复制代码
 ​
python 复制代码
 # src/train.py
 ​
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import os
 import pickle
 from tqdm import tqdm
 ​
 from data_loader import get_data_loaders, PROJECT_ROOT  # 导入项目根路径
 from model import GRUModel
 ​
 ​
 # --- 超参数配置 ---
 class Config:
     # 路径配置 (使用动态路径)
     MODEL_DIR = os.path.join(PROJECT_ROOT, "models")
     VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
     MODEL_SAVE_PATH = os.path.join(MODEL_DIR, "gru_model.pth")
 ​
     BATCH_SIZE = 128
     SEQUENCE_LENGTH = 100
     N_EPOCHS = 20
     LEARNING_RATE = 0.0005
     EMBED_DIM = 256
     HIDDEN_DIM = 512
     N_LAYERS = 2
 ​
     # 创建模型目录
     if not os.path.exists(MODEL_DIR):
         os.makedirs(MODEL_DIR)
 ​
 ​
 def train():
     """主训练函数"""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"使用设备: {device}")
 ​
     train_loader, val_loader, vocab_size = get_data_loaders(
         batch_size=Config.BATCH_SIZE,
         sequence_length=Config.SEQUENCE_LENGTH
     )
 ​
     model = GRUModel(
         vocab_size=vocab_size,
         embed_dim=Config.EMBED_DIM,
         hidden_dim=Config.HIDDEN_DIM,
         num_layers=Config.N_LAYERS
     ).to(device)
 ​
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
 ​
     print("模型训练开始...")
     for epoch in range(Config.N_EPOCHS):
         model.train()
         total_loss = 0
 ​
         progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{Config.N_EPOCHS}")
         for i, (inputs, targets) in enumerate(progress_bar):
             # 将数据移动到指定设备
             inputs, targets = inputs.to(device), targets.to(device)
 ​
             # --- 关键修改 ---
             # 在每个批次开始时,根据当前批次的大小重新初始化隐藏状态
             current_batch_size = inputs.size(0)
             hidden = model.init_hidden(current_batch_size, device)
 ​
             # --- 关键修改:添加梯度裁剪 ---
             # 防止梯度爆炸,将梯度的L2范数限制在5.0以内
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
 ​
             # 梯度清零
             optimizer.zero_grad()
 ​
             # 前向传播
             output, hidden = model(inputs, hidden)
 ​
             # 将 targets 展平,以便与 output 的维度匹配
             targets = targets.view(-1)
 ​
             # 计算损失
             loss = criterion(output, targets)
 ​
             # 反向传播和优化
             loss.backward()
             optimizer.step()
 ​
             total_loss += loss.item()
             progress_bar.set_postfix(loss=loss.item())
 ​
         avg_loss = total_loss / len(train_loader)
         print(f"Epoch {epoch + 1} 完成, 平均训练损失: {avg_loss:.4f}")
 ​
         torch.save({
             'epoch': epoch,
             'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'loss': avg_loss,
             'config': Config
         }, Config.MODEL_SAVE_PATH)
         print(f"模型已保存至 {Config.MODEL_SAVE_PATH}")
 ​
     print("模型训练完成!")
 ​
 ​
 if __name__ == '__main__':
     train()
复制代码
 ​ 
python 复制代码
# src/generate.py
 ​
 import torch
 import pickle
 import os
 import random
 ​
 from model import GRUModel
 from train import Config, PROJECT_ROOT  # 导入项目根路径
 ​
 ​
 def generate_text(start_str="shall i compare thee to a summer's day?\n", gen_length=500, temperature=0.8):
     """使用训练好的模型生成文本"""
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ​
     # 路径配置 (使用动态路径)
     VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
     MODEL_SAVE_PATH = Config.MODEL_SAVE_PATH
 ​
     # 1. 加载词汇表
     try:
         with open(VOCAB_FILE, 'rb') as f:
             char_to_ix, ix_to_char, vocab_size = pickle.load(f)
     except FileNotFoundError:
         print(f"错误: 词汇表文件 {VOCAB_FILE} 未找到。请先运行训练脚本。")
         return
 ​
     # 2. 初始化模型
     model = GRUModel(
         vocab_size=vocab_size,
         embed_dim=Config.EMBED_DIM,
         hidden_dim=Config.HIDDEN_DIM,
         num_layers=Config.N_LAYERS
     ).to(device)
 ​
     # 3. 加载模型权重
     try:
         checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
         model.load_state_dict(checkpoint['model_state_dict'])
         print(f"成功加载模型权重: {MODEL_SAVE_PATH}")
     except FileNotFoundError:
         print(f"错误: 模型文件 {MODEL_SAVE_PATH} 未找到。请先运行训练脚本。")
         return
 ​
     model.eval()
 ​
     input_seq = [char_to_ix[ch] for ch in start_str]
     input_tensor = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)
     hidden = model.init_hidden(1, device)
 ​
     generated_text = start_str
     with torch.no_grad():
         for _ in range(gen_length):
             output, hidden = model(input_tensor, hidden)
             output = output.squeeze(0).div(temperature).exp()
             top_i = torch.multinomial(output, 1)[0]
             predicted_char = ix_to_char[top_i.item()]
             generated_text += predicted_char
             input_tensor = torch.tensor([[top_i]], dtype=torch.long).to(device)
 ​
     print("\n--- 生成的文本 ---")
     print(generated_text)
     print("------------------\n")
 ​
 ​
 if __name__ == '__main__':
     generate_text(start_str="from fairest creatures we desire increase,\n", gen_length=400, temperature=0.8)
复制代码
 ​
python 复制代码
 # src/data_loader.py 
 import os
 import pandas as pd
 import torch
 from torch.utils.data import Dataset, DataLoader
 import pickle
 import sys
 ​
 # --- 路径配置 (本地优化) ---
 # 获取当前脚本所在的目录
 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 # 获取项目根目录
 PROJECT_ROOT = os.path.dirname(CURRENT_DIR)
 DATA_DIR = os.path.join(PROJECT_ROOT, "data")
 RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")
 # 修改为你实际的文件名
 CSV_FILE = os.path.join(RAW_DATA_DIR, "Shakespeare_data.csv")
 VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
 ​
 ​
 def check_data_exists():
     """检查数据文件是否存在,如果不存在则给出提示并退出"""
     if not os.path.exists(CSV_FILE):
         print("=" * 50)
         print("错误:数据文件未找到!")
         print(f"请确保数据文件已放置在以下位置:\n{CSV_FILE}")
         print("\n请按以下步骤操作:")
         print("1. 访问 https://www.kaggle.com/datasets/kingburrito666/shakespeare-sonnets")
         print("2. 点击 'Download' 按钮。")
         print("3. 解压下载的 zip 文件。")
         print("4. 将 'Shakespeare_data.csv' 文件复制到 'data/raw/' 目录下。")
         print("=" * 50)
         sys.exit(1)  # 退出程序
 ​
 ​
 def prepare_data(sequence_length=100):
     """加载、预处理数据并创建DataLoader"""
     # 1. 检查数据文件
     check_data_exists()
 ​
     # 2. 加载数据
     print(f"正在从 {CSV_FILE} 加载数据...")
     df = pd.read_csv(CSV_FILE)
 ​
     # 3. 数据清洗:过滤掉舞台说明,只保留角色台词
     player_lines_df = df.dropna(subset=['Player'])
     print(f"数据加载完成,共加载了 {len(player_lines_df)} 行台词。")
 ​
     # 4. 合并所有台词为一个长字符串
     text = player_lines_df['PlayerLine'].str.cat(sep='\n')
 ​
     # 5. 创建字符集
     chars = sorted(list(set(text)))
     vocab_size = len(chars)
 ​
     char_to_ix = {ch: i for i, ch in enumerate(chars)}
     ix_to_char = {i: ch for i, ch in enumerate(chars)}
 ​
     # 保存词汇表
     with open(VOCAB_FILE, 'wb') as f:
         pickle.dump((char_to_ix, ix_to_char, vocab_size), f)
     print(f"词汇表已保存至 {VOCAB_FILE},词汇量大小: {vocab_size}")
 ​
     # 6. 将整个文本转换为整数序列
     text_as_int = [char_to_ix[ch] for ch in text]
 ​
     # 7. 创建输入序列和目标序列 (关键修改!)
     input_sequences = []
     target_sequences = []
     for i in range(0, len(text_as_int) - sequence_length):
         # 输入是从 i 到 i+sequence_length-1
         input_sequences.append(text_as_int[i: i + sequence_length])
         # 目标是从 i+1 到 i+sequence_length (向左移动一位)
         target_sequences.append(text_as_int[i + 1: i + sequence_length + 1])
 ​
     print(f"总序列数: {len(input_sequences)}")
 ​
     # 8. 创建自定义Dataset (关键修改!)
     class ShakespeareDataset(Dataset):
         def __init__(self, sequences, targets):
             self.sequences = sequences
             self.targets = targets
 ​
         def __len__(self):
             return len(self.sequences)
 ​
         def __getitem__(self, idx):
             # 返回一个输入序列和对应的目标序列
             return torch.tensor(self.sequences[idx], dtype=torch.long), torch.tensor(self.targets[idx],
                                                                                      dtype=torch.long)
 ​
     dataset = ShakespeareDataset(input_sequences, target_sequences)
 ​
     return dataset, vocab_size
 ​
 ​
 def get_data_loaders(batch_size=64, sequence_length=100):
     """获取训练和验证DataLoader"""
     dataset, vocab_size = prepare_data(sequence_length)
 ​
     train_size = int(0.8 * len(dataset))
     val_size = len(dataset) - train_size
     train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
 ​
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
 ​
     return train_loader, val_loader, vocab_size
 ​

4 结果分析

模型训练

运行train.py,可以看到以下类似输出:

python 复制代码
 使用设备: cuda
 正在从 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/data/raw/Shakespeare_data.csv 加载数据...
 数据加载完成,共加载了 111389 行台词。
 词汇表已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/vocab.pkl,词汇量大小: 77
 总序列数: 4365922
 模型训练开始...
 Epoch 1/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.34it/s, loss=0.953]
 Epoch 1 完成, 平均训练损失: 1.1017
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 2/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.28it/s, loss=0.927]
 Epoch 2 完成, 平均训练损失: 0.9059
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 3/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.47it/s, loss=0.728]
 Epoch 3 完成, 平均训练损失: 0.8570
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 4/20: 100%|██████████| 27288/27288 [08:49<00:00, 51.58it/s, loss=0.992]
 Epoch 4 完成, 平均训练损失: 0.8321
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 5/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.34it/s, loss=0.912]
 Epoch 5 完成, 平均训练损失: 0.8161
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 6/20: 100%|██████████| 27288/27288 [08:49<00:00, 51.55it/s, loss=0.786]
 Epoch 6 完成, 平均训练损失: 0.8047
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 7/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.48it/s, loss=0.797]
 Epoch 7 完成, 平均训练损失: 0.7960
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 8/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.37it/s, loss=0.74]
 Epoch 8 完成, 平均训练损失: 0.7892
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 9/20: 100%|██████████| 27288/27288 [08:45<00:00, 51.97it/s, loss=0.792]
 Epoch 9 完成, 平均训练损失: 0.7837
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 10/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.37it/s, loss=0.865]
 Epoch 10 完成, 平均训练损失: 0.7794
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 11/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.42it/s, loss=0.681]
 Epoch 11 完成, 平均训练损失: 0.7753
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 12/20: 100%|██████████| 27288/27288 [08:40<00:00, 52.45it/s, loss=0.852]
 Epoch 12 完成, 平均训练损失: 0.7720
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 13/20: 100%|██████████| 27288/27288 [08:44<00:00, 52.01it/s, loss=0.819]
 Epoch 13 完成, 平均训练损失: 0.7694
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 14/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.37it/s, loss=0.678]
 Epoch 14 完成, 平均训练损失: 0.7670
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 15/20: 100%|██████████| 27288/27288 [08:33<00:00, 53.19it/s, loss=0.532]
 Epoch 15 完成, 平均训练损失: 0.7647
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 16/20: 100%|██████████| 27288/27288 [08:34<00:00, 53.05it/s, loss=0.696]
 Epoch 16 完成, 平均训练损失: 0.7631
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 17/20: 100%|██████████| 27288/27288 [08:46<00:00, 51.85it/s, loss=0.561]
 Epoch 17 完成, 平均训练损失: 0.7614
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 18/20: 100%|██████████| 27288/27288 [08:34<00:00, 53.05it/s, loss=0.715]
 Epoch 18 完成, 平均训练损失: 0.7598
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 19/20: 100%|██████████| 27288/27288 [08:33<00:00, 53.17it/s, loss=0.674]
 Epoch 19 完成, 平均训练损失: 0.7582
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 Epoch 20/20: 100%|██████████| 27288/27288 [08:53<00:00, 51.15it/s, loss=0.864]
 Epoch 20 完成, 平均训练损失: 0.7572
 模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
 模型训练完成!

模型成功完成了20个Epoch的训练。训练损失从初始的1.10稳定下降到最终的0.75,表明模型有效地学习了数据中的语言模式。训练过程稳定,速度高效。

模型测试

模型训练完成后,使用generate.py脚本生成文本

python 复制代码
 from fairest creatures we desire increase,
 that never shall become him.
 As you have that, you rascal, yet the saying is,
 Shall wait upon the bed of day to-morrow.
 So that, for mine I pray you:
 I make you these rocks to the elements,
 I'll never wear some strange enemy prove
 A full, and then I will be by her foot,
 Yet I am remember'd with the host:
 When such proceeding breedings, dogs she stands,
 A thousand of her many hours of good Clifford
  • 优点:

    • 词汇与拼写 :生成的单词拼写完全正确,并成功使用了fairest, creatures, rascal等符合风格的词汇。

    • 语法与格式 :基本遵循英文语法,标点符号使用合理。最令人惊喜的是,模型学会了剧本格式,在最后生成了一个角色名Clifford

    • 风格捕捉:文本整体带有一种戏剧化和古雅的腔调,成功模仿了莎士比亚文本的"形"。

  • 不足之处:

    • 语义不连贯:句子之间缺乏逻辑联系,整体内容是随机的、无意义的。这是字符级语言模型的普遍局限。

    • 逻辑混乱:模型不理解其生成内容的含义,只是在进行概率上的字符预测。

5 原理总结

【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中,我们对GRU的数学原理和结构有了非常扎实的理论了解。但是反观代码,似乎代码中并未体现相关内容。为什么呢?


从宏观到微观------PyTorch的"魔法"

首先需要明白的是,PyTorch这样的深度学习框架,其核心目标之一就是封装复杂的数学运算,让你能用更简洁、更高层的代码来表达模型。

我们在项目中使用的 nn.GRU 就是这样一个高度封装的"魔法盒子"。

python 复制代码
 # 项目中的代码
 self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)

这短短一行代码,背后就对应了【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中所有的公式:重置门、更新门、候选隐状态的计算等等。框架自动初始化了所有的 Wb 参数,并实现了整个前向传播流程。

我们的目标:现在,我们要打开这个"魔法盒子",亲手制作里面的每一个零件,看看它们是如何协同工作的。


从零开始,构建一个GRU单元

为了彻底理解,我们不直接用 nn.GRU,而是自己动手写一个 GRUCell。一个 GRUCell 就是GRU在单个时间步 t 所做的所有计算。

【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中列出了所有需要学习的参数矩阵和偏置向量。在我们的PyTorch代码中,它们通常以 nn.Linear 层的形式存在。

点拨nn.Linear(in_features, out_features) 本质上就是实现 Y = X @ W.T + b。注意这里有个转置 .T,所以 W 的形状是 (out_features, in_features)。这和【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中的 X_t * W_xr (其中 W_xrd x h) 是完全一致的。

python 复制代码
 import torch
 import torch.nn as nn
 ​
 ​
 class GRUCell_FromScratch(nn.Module):
     """
     从零实现的GRU单元,其计算逻辑与PyTorch官方的nn.GRUCell完全一致。
     """
     def __init__(self, input_size, hidden_size):
         super(GRUCell_FromScratch, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
 ​
         # --- 定义所有线性变换的权重矩阵 ---
         # 注意:所有nn.Linear层都设置bias=False,因为我们将手动创建和管理所有偏置项,
         # 这样做可以更精确地控制偏置项的运算方式,以匹配官方实现。
 ​
         # 重置门 的权重: W_xr, W_hr
         self.linear_xr = nn.Linear(input_size, hidden_size, bias=False)
         self.linear_hr = nn.Linear(hidden_size, hidden_size, bias=False)
 ​
         # 更新门 的权重: W_xz, W_hz
         self.linear_xz = nn.Linear(input_size, hidden_size, bias=False)
         self.linear_hz = nn.Linear(hidden_size, hidden_size, bias=False)
 ​
         # 候选隐状态 的权重: W_xh, W_hh
         self.linear_xh = nn.Linear(input_size, hidden_size, bias=False)
         self.linear_hh = nn.Linear(hidden_size, hidden_size, bias=False)
 ​
         # --- 手动创建所有偏置项 ---
         # 重置门和更新门的偏置 (官方实现中,输入偏置和隐藏状态偏置是相加的)
         self.bias_r = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_r = b_ir + b_hr
         self.bias_z = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_z = b_iz + b_hz
         
         # 候选隐状态的两个独立偏置项
         self.bias_in = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_in, 不被重置门门控
         self.bias_hn = nn.Parameter(torch.zeros(hidden_size))  # 对应 b_hn, 被重置门门控
 ​
     def forward(self, x_t, h_prev):
         """
         前向传播计算单个时间步。
         参数:
             x_t: 当前时间步的输入, shape (batch_size, input_size)
             h_prev: 前一时间步的隐状态, shape (batch_size, hidden_size)
         返回:
             h_t: 当前时间步计算出的新隐状态, shape (batch_size, hidden_size)
         """
         # --- 步骤 1: 计算重置门 ---
         # 公式: R_t = σ(X_t * W_xr + H_{t-1} * W_hr + b_r)
         r_t = torch.sigmoid(self.linear_xr(x_t) + self.linear_hr(h_prev) + self.bias_r)
 ​
         # --- 步骤 2: 计算更新门 ---
         # 公式: Z_t = σ(X_t * W_xz + H_{t-1} * W_hz + b_z)
         z_t = torch.sigmoid(self.linear_xz(x_t) + self.linear_hz(h_prev) + self.bias_z)
 ​
         # --- 步骤 3: 计算候选隐状态 ---
         # 公式: H_tilde = tanh(X_t * W_xh + R_t ⊙ (H_{t-1} * W_hh + b_hn) + b_in)
         # 这里的 * 运算符实现了哈达玛积 (⊙)
         h_tilde = torch.tanh(
             self.linear_xh(x_t) +
             r_t * (self.linear_hh(h_prev) + self.bias_hn) +
             self.bias_in
         )
 ​
         # --- 步骤 4: 计算最终的隐状态 ---
         # 公式: H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ H_tilde
         h_t = z_t * h_prev + (1 - z_t) * h_tilde
 ​
         return h_t

运行并验证我们的"手搓"GRU

理论讲完了,我们来做个小实验,验证一下我们写的 GRUCell_FromScratch 和PyTorch官方的 nn.GRUCell 是不是等价的。

python 复制代码
 # --- 实验验证部分 ---
 # 定义超参数
 batch_size = 4
 input_size = 10
 hidden_size = 20
 ​
 # 创建随机的输入数据
 x_t = torch.randn(batch_size, input_size)
 h_prev = torch.randn(batch_size, hidden_size)
 ​
 # 实例化我们实现的GRU单元和PyTorch官方的GRU单元
 gru_scratch = GRUCell_FromScratch(input_size, hidden_size)
 gru_official = nn.GRUCell(input_size, hidden_size)
 ​
 # --- 将官方模型的参数复制到我们实现的模型中 ---
 # PyTorch将所有输入权重和偏置堆叠在 weight_ih 和 bias_ih 中
 # 将所有隐藏状态权重和偏置堆叠在 weight_hh 和 bias_hh 中
 ​
 # 复制权重
 gru_scratch.linear_xr.weight.data.copy_(gru_official.weight_ih[:hidden_size, :])
 gru_scratch.linear_hr.weight.data.copy_(gru_official.weight_hh[:hidden_size, :])
 gru_scratch.linear_xz.weight.data.copy_(gru_official.weight_ih[hidden_size:hidden_size * 2, :])
 gru_scratch.linear_hz.weight.data.copy_(gru_official.weight_hh[hidden_size:hidden_size * 2, :])
 gru_scratch.linear_xh.weight.data.copy_(gru_official.weight_ih[hidden_size * 2:, :])
 gru_scratch.linear_hh.weight.data.copy_(gru_official.weight_hh[hidden_size * 2:, :])
 ​
 # 复制偏置
 gru_scratch.bias_r.data.copy_(gru_official.bias_ih[:hidden_size] + gru_official.bias_hh[:hidden_size])
 gru_scratch.bias_z.data.copy_(gru_official.bias_ih[hidden_size:hidden_size * 2] + gru_official.bias_hh[hidden_size:hidden_size * 2])
 gru_scratch.bias_in.data.copy_(gru_official.bias_ih[hidden_size * 2:])
 gru_scratch.bias_hn.data.copy_(gru_official.bias_hh[hidden_size * 2:])
 ​
 # 分别用两个模型进行前向传播
 h_t_scratch = gru_scratch(x_t, h_prev)
 h_t_official = gru_official(x_t, h_prev)
 ​
 # 比较结果
 print("我们手搓的GRU输出:", h_t_scratch)
 print("PyTorch官方GRU输出:", h_t_official)
 ​
 # 检查两个输出是否在数值上几乎相同
 print("\n两个输出是否几乎相同?", torch.allclose(h_t_scratch, h_t_official))

当你运行这段代码,如果最后打印出 True,那么恭喜你!你已经成功地用代码复现了GRU的核心数学原理。这证明了你对GRU的理解已经深入到了"像素级"。以下是我运行的结果:

python 复制代码
 我们手搓的GRU输出: tensor([[-0.8295, -0.2595,  0.9684,  0.2140, -0.5785, -0.1933,  0.5778,  0.3714,
           0.5419, -1.2080,  0.5973,  0.6886, -0.7297, -0.1810, -0.3550,  0.1375,
           0.3358, -0.2605, -0.4440,  0.6498],
         [ 0.3155,  0.0616, -0.5738, -0.1972, -0.0984,  0.0601,  0.3601,  0.1683,
          -0.4179,  0.4705,  0.4867, -0.5043,  1.2716,  0.0027, -0.4619, -0.3631,
          -0.4136, -0.6153, -0.3496, -0.8575],
         [-0.1301, -0.4527, -0.3129,  0.2685, -0.3576, -0.3155, -0.4003,  0.4550,
          -0.3802,  0.3482,  0.8009,  0.1505,  0.2446,  0.0780,  0.4634, -0.1107,
           0.2131,  0.3837, -0.4669,  0.0181],
         [-0.0648,  0.0902, -0.0132,  0.0585, -0.1076, -0.5664,  0.1125, -0.1067,
          -0.0702, -0.5483,  0.5603,  0.2239,  0.0498,  0.8238, -0.0751, -0.4099,
           0.1920,  0.5400, -0.1944,  0.4914]], grad_fn=<AddBackward0>)
 PyTorch官方GRU输出: tensor([[-0.8295, -0.2595,  0.9684,  0.2140, -0.5785, -0.1933,  0.5778,  0.3714,
           0.5419, -1.2080,  0.5973,  0.6886, -0.7297, -0.1810, -0.3550,  0.1375,
           0.3358, -0.2605, -0.4440,  0.6498],
         [ 0.3155,  0.0616, -0.5738, -0.1972, -0.0984,  0.0601,  0.3601,  0.1683,
          -0.4179,  0.4705,  0.4867, -0.5043,  1.2716,  0.0027, -0.4619, -0.3631,
          -0.4136, -0.6153, -0.3496, -0.8575],
         [-0.1301, -0.4527, -0.3129,  0.2685, -0.3576, -0.3155, -0.4003,  0.4550,
          -0.3802,  0.3482,  0.8009,  0.1505,  0.2446,  0.0780,  0.4634, -0.1107,
           0.2131,  0.3837, -0.4669,  0.0181],
         [-0.0648,  0.0902, -0.0132,  0.0585, -0.1076, -0.5664,  0.1125, -0.1067,
          -0.0702, -0.5483,  0.5603,  0.2239,  0.0498,  0.8238, -0.0751, -0.4099,
           0.1920,  0.5400, -0.1944,  0.4914]], grad_fn=<AddBackward0>)
 ​
 两个输出是否几乎相同? True

回归项目,融会贯通

现在,我们再回头看项目代码 src/model.py

python 复制代码
 # src/model.py
 class GRUModel(nn.Module):
     # ...
     def forward(self, x, hidden):
         embedded = self.embedding(x)
         out, hidden = self.gru(embedded, hidden) # <--- 就是这里!
         out = out.contiguous().view(-1, self.hidden_dim)
         out = self.fc(out)
         return out, hidden

这里的 self.gru 就是我们上面验证的 nn.GRUCell 的"循环版本"。nn.GRU 会在内部自动地遍历输入序列的每一个时间步,反复调用 GRUCell 的计算逻辑,并把每一步的隐状态传递给下一步。

  • self.embedding:将你的字符ID转换为密集向量。

  • self.gru:处理整个序列,输出每个时间步的隐状态 out 和最后一个时间步的隐状态 hidden

  • self.fc:这就是【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记 2.4 节 的最终输出层 O_t = H_t * W_hq + b_q,它将GRU的输出映射到词汇表大小,用于预测下一个字符的概率。

如果读者还想学习LSTM模型实战,可移步至⬇️:

【循环神经网络6】LSTM实战------基于LSTM的IMDb电影评论情感分析-CSDN博客https://blog.csdn.net/colus_SEU/article/details/152564800?spm=1001.2014.3001.5501

相关推荐
Penguin大阪2 小时前
GRU模型这波牛市应用股价预测
人工智能·深度学习·gru
格林威2 小时前
UV紫外相机在工业视觉检测中的应用
人工智能·深度学习·数码相机·算法·计算机视觉·视觉检测·uv
格林威3 小时前
工业视觉检测里的 “柔性” 是什么?
图像处理·人工智能·深度学习·yolo·计算机视觉·视觉检测
停走的风3 小时前
(CVPR2025)DEIM模型训练自己的数据集教程(基于Pycharm)
python·深度学习·pycharm·模型训练·deim
丁学文武3 小时前
大模型原理与实践:第三章-预训练语言模型详解_第3部分-Decoder-Only(GPT、LLama、GLM)
人工智能·gpt·语言模型·自然语言处理·大模型·llama·glm
说私域3 小时前
公域流量转化困境下开源AI智能名片与链动2+1模式的S2B2C商城小程序应用研究
人工智能·小程序·开源
每天一个java小知识3 小时前
Spring-AI 接入(本地大模型 deepseek + 阿里云百炼 + 硅基流动)
java·人工智能·spring
格林威3 小时前
近红外相机在机器视觉检测中的应用
人工智能·数码相机·opencv·计算机视觉·视觉检测
罗小罗同学3 小时前
覆盖9个癌种,基于11671张病理切片训练的模型登上Nature子刊,可精准“读出”分子标志物,突破传统分类局限
人工智能·深度学习·分类·数据挖掘·病理组学·医学人工智能·医工交叉