【循环神经网络3】门控循环单元GRU详解-CSDN博客https://blog.csdn.net/colus_SEU/article/details/152218119?spm=1001.2014.3001.5501在以上▲笔记中,我们详解了GRU模型,接下来我们通过实战来进一步理解。
1 项目概述
本项目旨在构建一个字符级的语言模型,使用PyTorch框架实现GRU(门控循环单元)网络,来学习莎士比亚戏剧文本的语言模式,并生成具有莎士比亚风格的新文本。
-
核心任务:文本生成。
-
数据集 :Kaggle上的莎士比亚戏剧数据集,包含超过11万行角色台词,下载地址:Shakespeare plays(直接点击download下载zip文件解压后将文件夹中的Shakespeare_data.csv文件放到项目目录指定位置即可)
-
最终成果:一个能够生成语法基本正确、词汇风格贴近莎士比亚戏剧的文本的GRU模型。
2 项目目录
python
GRU_shakespeare_sonnets/
├── data/
│ └── raw/
│ └── Shakespeare_data.csv
├── models/
│ └── gru_model.pth
├── src/
│ ├── data_loader.py # 数据下载、预处理和加载
│ ├── model.py # GRU模型定义
│ ├── train.py # 训练脚本
│ └── generate.py # 文本生成脚本
├── vocab.pkl # 保存的词汇表
└── requirements.txt # 项目依赖
3 项目代码
python
# src/model.py
import torch
import torch.nn as nn
class GRUModel(nn.Module):
def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=2):
'''
:param vocab_size: 词汇表的大小,即模型中使用的不同单词的数量。
:param embed_dim: 词嵌入的维度,默认为256。词嵌入是将单词转换为固定长度向量的过程,这些向量捕获单词的语义信息。
:param hidden_dim: GRU隐藏层的维度,默认为512。这决定了模型的学习能力和复杂性。
:param num_layers: GRU层的数量,默认为2。多层GRU可以捕获更复杂的序列模式。
'''
super(GRUModel, self).__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
# 词嵌入层
self.embedding = nn.Embedding(vocab_size, embed_dim)
# 创建一个词嵌入层,使用nn.Embedding类。这个层将词汇表中的每个单词映射到一个embed_dim维的向量。
# GRU层
self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
# 创建一个GRU层,使用nn.GRU类。
# batch_first=True:表示输入数据的第一个维度是批量大小,这使得数据组织更加直观。
# 全连接输出层
self.fc = nn.Linear(hidden_dim, vocab_size)
# 这个层将GRU的输出映射到词汇表的大小,从而预测下一个单词。hidden_dim是输入尺寸,vocab_size是输出尺寸。
def forward(self, x, hidden):
# x shape: (batch_size, seq_length)
# hidden shape: (num_layers, batch_size, hidden_dim)
# 嵌入层
embedded = self.embedding(x) # shape: (batch_size, seq_length, embed_dim)
# GRU层
# out shape: (batch_size, seq_length, hidden_dim)
# hidden shape: (num_layers, batch_size, hidden_dim)
out, hidden = self.gru(embedded, hidden)
# 将GRU的输出传入全连接层
# 我们需要将out重塑以便通过fc层
out = out.contiguous().view(-1, self.hidden_dim) # shape: (batch_size * seq_length, hidden_dim)
out = self.fc(out) # shape: (batch_size * seq_length, vocab_size)
return out, hidden
def init_hidden(self, batch_size, device):
"""初始化隐藏状态"""
# 形状: (num_layers, batch_size, hidden_dim)
weight = next(self.parameters()).data
hidden = weight.new(self.num_layers, batch_size, self.hidden_dim).zero_().to(device)
return hidden
python
# src/train.py
import torch
import torch.nn as nn
import torch.optim as optim
import os
import pickle
from tqdm import tqdm
from data_loader import get_data_loaders, PROJECT_ROOT # 导入项目根路径
from model import GRUModel
# --- 超参数配置 ---
class Config:
# 路径配置 (使用动态路径)
MODEL_DIR = os.path.join(PROJECT_ROOT, "models")
VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
MODEL_SAVE_PATH = os.path.join(MODEL_DIR, "gru_model.pth")
BATCH_SIZE = 128
SEQUENCE_LENGTH = 100
N_EPOCHS = 20
LEARNING_RATE = 0.0005
EMBED_DIM = 256
HIDDEN_DIM = 512
N_LAYERS = 2
# 创建模型目录
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
def train():
"""主训练函数"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
train_loader, val_loader, vocab_size = get_data_loaders(
batch_size=Config.BATCH_SIZE,
sequence_length=Config.SEQUENCE_LENGTH
)
model = GRUModel(
vocab_size=vocab_size,
embed_dim=Config.EMBED_DIM,
hidden_dim=Config.HIDDEN_DIM,
num_layers=Config.N_LAYERS
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
print("模型训练开始...")
for epoch in range(Config.N_EPOCHS):
model.train()
total_loss = 0
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{Config.N_EPOCHS}")
for i, (inputs, targets) in enumerate(progress_bar):
# 将数据移动到指定设备
inputs, targets = inputs.to(device), targets.to(device)
# --- 关键修改 ---
# 在每个批次开始时,根据当前批次的大小重新初始化隐藏状态
current_batch_size = inputs.size(0)
hidden = model.init_hidden(current_batch_size, device)
# --- 关键修改:添加梯度裁剪 ---
# 防止梯度爆炸,将梯度的L2范数限制在5.0以内
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
# 梯度清零
optimizer.zero_grad()
# 前向传播
output, hidden = model(inputs, hidden)
# 将 targets 展平,以便与 output 的维度匹配
targets = targets.view(-1)
# 计算损失
loss = criterion(output, targets)
# 反向传播和优化
loss.backward()
optimizer.step()
total_loss += loss.item()
progress_bar.set_postfix(loss=loss.item())
avg_loss = total_loss / len(train_loader)
print(f"Epoch {epoch + 1} 完成, 平均训练损失: {avg_loss:.4f}")
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': avg_loss,
'config': Config
}, Config.MODEL_SAVE_PATH)
print(f"模型已保存至 {Config.MODEL_SAVE_PATH}")
print("模型训练完成!")
if __name__ == '__main__':
train()
python
# src/generate.py
import torch
import pickle
import os
import random
from model import GRUModel
from train import Config, PROJECT_ROOT # 导入项目根路径
def generate_text(start_str="shall i compare thee to a summer's day?\n", gen_length=500, temperature=0.8):
"""使用训练好的模型生成文本"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 路径配置 (使用动态路径)
VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
MODEL_SAVE_PATH = Config.MODEL_SAVE_PATH
# 1. 加载词汇表
try:
with open(VOCAB_FILE, 'rb') as f:
char_to_ix, ix_to_char, vocab_size = pickle.load(f)
except FileNotFoundError:
print(f"错误: 词汇表文件 {VOCAB_FILE} 未找到。请先运行训练脚本。")
return
# 2. 初始化模型
model = GRUModel(
vocab_size=vocab_size,
embed_dim=Config.EMBED_DIM,
hidden_dim=Config.HIDDEN_DIM,
num_layers=Config.N_LAYERS
).to(device)
# 3. 加载模型权重
try:
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"成功加载模型权重: {MODEL_SAVE_PATH}")
except FileNotFoundError:
print(f"错误: 模型文件 {MODEL_SAVE_PATH} 未找到。请先运行训练脚本。")
return
model.eval()
input_seq = [char_to_ix[ch] for ch in start_str]
input_tensor = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)
hidden = model.init_hidden(1, device)
generated_text = start_str
with torch.no_grad():
for _ in range(gen_length):
output, hidden = model(input_tensor, hidden)
output = output.squeeze(0).div(temperature).exp()
top_i = torch.multinomial(output, 1)[0]
predicted_char = ix_to_char[top_i.item()]
generated_text += predicted_char
input_tensor = torch.tensor([[top_i]], dtype=torch.long).to(device)
print("\n--- 生成的文本 ---")
print(generated_text)
print("------------------\n")
if __name__ == '__main__':
generate_text(start_str="from fairest creatures we desire increase,\n", gen_length=400, temperature=0.8)
python
# src/data_loader.py
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import sys
# --- 路径配置 (本地优化) ---
# 获取当前脚本所在的目录
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# 获取项目根目录
PROJECT_ROOT = os.path.dirname(CURRENT_DIR)
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
RAW_DATA_DIR = os.path.join(DATA_DIR, "raw")
# 修改为你实际的文件名
CSV_FILE = os.path.join(RAW_DATA_DIR, "Shakespeare_data.csv")
VOCAB_FILE = os.path.join(PROJECT_ROOT, "vocab.pkl")
def check_data_exists():
"""检查数据文件是否存在,如果不存在则给出提示并退出"""
if not os.path.exists(CSV_FILE):
print("=" * 50)
print("错误:数据文件未找到!")
print(f"请确保数据文件已放置在以下位置:\n{CSV_FILE}")
print("\n请按以下步骤操作:")
print("1. 访问 https://www.kaggle.com/datasets/kingburrito666/shakespeare-sonnets")
print("2. 点击 'Download' 按钮。")
print("3. 解压下载的 zip 文件。")
print("4. 将 'Shakespeare_data.csv' 文件复制到 'data/raw/' 目录下。")
print("=" * 50)
sys.exit(1) # 退出程序
def prepare_data(sequence_length=100):
"""加载、预处理数据并创建DataLoader"""
# 1. 检查数据文件
check_data_exists()
# 2. 加载数据
print(f"正在从 {CSV_FILE} 加载数据...")
df = pd.read_csv(CSV_FILE)
# 3. 数据清洗:过滤掉舞台说明,只保留角色台词
player_lines_df = df.dropna(subset=['Player'])
print(f"数据加载完成,共加载了 {len(player_lines_df)} 行台词。")
# 4. 合并所有台词为一个长字符串
text = player_lines_df['PlayerLine'].str.cat(sep='\n')
# 5. 创建字符集
chars = sorted(list(set(text)))
vocab_size = len(chars)
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}
# 保存词汇表
with open(VOCAB_FILE, 'wb') as f:
pickle.dump((char_to_ix, ix_to_char, vocab_size), f)
print(f"词汇表已保存至 {VOCAB_FILE},词汇量大小: {vocab_size}")
# 6. 将整个文本转换为整数序列
text_as_int = [char_to_ix[ch] for ch in text]
# 7. 创建输入序列和目标序列 (关键修改!)
input_sequences = []
target_sequences = []
for i in range(0, len(text_as_int) - sequence_length):
# 输入是从 i 到 i+sequence_length-1
input_sequences.append(text_as_int[i: i + sequence_length])
# 目标是从 i+1 到 i+sequence_length (向左移动一位)
target_sequences.append(text_as_int[i + 1: i + sequence_length + 1])
print(f"总序列数: {len(input_sequences)}")
# 8. 创建自定义Dataset (关键修改!)
class ShakespeareDataset(Dataset):
def __init__(self, sequences, targets):
self.sequences = sequences
self.targets = targets
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
# 返回一个输入序列和对应的目标序列
return torch.tensor(self.sequences[idx], dtype=torch.long), torch.tensor(self.targets[idx],
dtype=torch.long)
dataset = ShakespeareDataset(input_sequences, target_sequences)
return dataset, vocab_size
def get_data_loaders(batch_size=64, sequence_length=100):
"""获取训练和验证DataLoader"""
dataset, vocab_size = prepare_data(sequence_length)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
return train_loader, val_loader, vocab_size
4 结果分析
模型训练
运行train.py,可以看到以下类似输出:
python
使用设备: cuda
正在从 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/data/raw/Shakespeare_data.csv 加载数据...
数据加载完成,共加载了 111389 行台词。
词汇表已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/vocab.pkl,词汇量大小: 77
总序列数: 4365922
模型训练开始...
Epoch 1/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.34it/s, loss=0.953]
Epoch 1 完成, 平均训练损失: 1.1017
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 2/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.28it/s, loss=0.927]
Epoch 2 完成, 平均训练损失: 0.9059
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 3/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.47it/s, loss=0.728]
Epoch 3 完成, 平均训练损失: 0.8570
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 4/20: 100%|██████████| 27288/27288 [08:49<00:00, 51.58it/s, loss=0.992]
Epoch 4 完成, 平均训练损失: 0.8321
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 5/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.34it/s, loss=0.912]
Epoch 5 完成, 平均训练损失: 0.8161
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 6/20: 100%|██████████| 27288/27288 [08:49<00:00, 51.55it/s, loss=0.786]
Epoch 6 完成, 平均训练损失: 0.8047
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 7/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.48it/s, loss=0.797]
Epoch 7 完成, 平均训练损失: 0.7960
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 8/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.37it/s, loss=0.74]
Epoch 8 完成, 平均训练损失: 0.7892
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 9/20: 100%|██████████| 27288/27288 [08:45<00:00, 51.97it/s, loss=0.792]
Epoch 9 完成, 平均训练损失: 0.7837
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 10/20: 100%|██████████| 27288/27288 [08:51<00:00, 51.37it/s, loss=0.865]
Epoch 10 完成, 平均训练损失: 0.7794
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 11/20: 100%|██████████| 27288/27288 [08:50<00:00, 51.42it/s, loss=0.681]
Epoch 11 完成, 平均训练损失: 0.7753
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 12/20: 100%|██████████| 27288/27288 [08:40<00:00, 52.45it/s, loss=0.852]
Epoch 12 完成, 平均训练损失: 0.7720
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 13/20: 100%|██████████| 27288/27288 [08:44<00:00, 52.01it/s, loss=0.819]
Epoch 13 完成, 平均训练损失: 0.7694
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 14/20: 100%|██████████| 27288/27288 [08:41<00:00, 52.37it/s, loss=0.678]
Epoch 14 完成, 平均训练损失: 0.7670
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 15/20: 100%|██████████| 27288/27288 [08:33<00:00, 53.19it/s, loss=0.532]
Epoch 15 完成, 平均训练损失: 0.7647
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 16/20: 100%|██████████| 27288/27288 [08:34<00:00, 53.05it/s, loss=0.696]
Epoch 16 完成, 平均训练损失: 0.7631
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 17/20: 100%|██████████| 27288/27288 [08:46<00:00, 51.85it/s, loss=0.561]
Epoch 17 完成, 平均训练损失: 0.7614
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 18/20: 100%|██████████| 27288/27288 [08:34<00:00, 53.05it/s, loss=0.715]
Epoch 18 完成, 平均训练损失: 0.7598
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 19/20: 100%|██████████| 27288/27288 [08:33<00:00, 53.17it/s, loss=0.674]
Epoch 19 完成, 平均训练损失: 0.7582
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
Epoch 20/20: 100%|██████████| 27288/27288 [08:53<00:00, 51.15it/s, loss=0.864]
Epoch 20 完成, 平均训练损失: 0.7572
模型已保存至 /hdd/ug_share/LiuLu/RNN/GRU_shakespeare_sonnets/models/gru_model.pth
模型训练完成!
模型成功完成了20个Epoch的训练。训练损失从初始的1.10
稳定下降到最终的0.75
,表明模型有效地学习了数据中的语言模式。训练过程稳定,速度高效。
模型测试
模型训练完成后,使用generate.py脚本生成文本
python
from fairest creatures we desire increase,
that never shall become him.
As you have that, you rascal, yet the saying is,
Shall wait upon the bed of day to-morrow.
So that, for mine I pray you:
I make you these rocks to the elements,
I'll never wear some strange enemy prove
A full, and then I will be by her foot,
Yet I am remember'd with the host:
When such proceeding breedings, dogs she stands,
A thousand of her many hours of good Clifford
-
优点:
-
词汇与拼写 :生成的单词拼写完全正确,并成功使用了
fairest
,creatures
,rascal
等符合风格的词汇。 -
语法与格式 :基本遵循英文语法,标点符号使用合理。最令人惊喜的是,模型学会了剧本格式,在最后生成了一个角色名
Clifford
。 -
风格捕捉:文本整体带有一种戏剧化和古雅的腔调,成功模仿了莎士比亚文本的"形"。
-
-
不足之处:
-
语义不连贯:句子之间缺乏逻辑联系,整体内容是随机的、无意义的。这是字符级语言模型的普遍局限。
-
逻辑混乱:模型不理解其生成内容的含义,只是在进行概率上的字符预测。
-
5 原理总结
在【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中,我们对GRU的数学原理和结构有了非常扎实的理论了解。但是反观代码,似乎代码中并未体现相关内容。为什么呢?
从宏观到微观------PyTorch的"魔法"
首先需要明白的是,PyTorch这样的深度学习框架,其核心目标之一就是封装复杂的数学运算,让你能用更简洁、更高层的代码来表达模型。
我们在项目中使用的 nn.GRU
就是这样一个高度封装的"魔法盒子"。
python
# 项目中的代码
self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
这短短一行代码,背后就对应了【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中所有的公式:重置门、更新门、候选隐状态的计算等等。框架自动初始化了所有的 W
和 b
参数,并实现了整个前向传播流程。
我们的目标:现在,我们要打开这个"魔法盒子",亲手制作里面的每一个零件,看看它们是如何协同工作的。
从零开始,构建一个GRU单元
为了彻底理解,我们不直接用 nn.GRU
,而是自己动手写一个 GRUCell
。一个 GRUCell
就是GRU在单个时间步 t
所做的所有计算。
【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中列出了所有需要学习的参数矩阵和偏置向量。在我们的PyTorch代码中,它们通常以 nn.Linear
层的形式存在。
点拨 :
nn.Linear(in_features, out_features)
本质上就是实现Y = X @ W.T + b
。注意这里有个转置.T
,所以W
的形状是(out_features, in_features)
。这和【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记中的X_t * W_xr
(其中W_xr
是d x h
) 是完全一致的。
python
import torch
import torch.nn as nn
class GRUCell_FromScratch(nn.Module):
"""
从零实现的GRU单元,其计算逻辑与PyTorch官方的nn.GRUCell完全一致。
"""
def __init__(self, input_size, hidden_size):
super(GRUCell_FromScratch, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# --- 定义所有线性变换的权重矩阵 ---
# 注意:所有nn.Linear层都设置bias=False,因为我们将手动创建和管理所有偏置项,
# 这样做可以更精确地控制偏置项的运算方式,以匹配官方实现。
# 重置门 的权重: W_xr, W_hr
self.linear_xr = nn.Linear(input_size, hidden_size, bias=False)
self.linear_hr = nn.Linear(hidden_size, hidden_size, bias=False)
# 更新门 的权重: W_xz, W_hz
self.linear_xz = nn.Linear(input_size, hidden_size, bias=False)
self.linear_hz = nn.Linear(hidden_size, hidden_size, bias=False)
# 候选隐状态 的权重: W_xh, W_hh
self.linear_xh = nn.Linear(input_size, hidden_size, bias=False)
self.linear_hh = nn.Linear(hidden_size, hidden_size, bias=False)
# --- 手动创建所有偏置项 ---
# 重置门和更新门的偏置 (官方实现中,输入偏置和隐藏状态偏置是相加的)
self.bias_r = nn.Parameter(torch.zeros(hidden_size)) # 对应 b_r = b_ir + b_hr
self.bias_z = nn.Parameter(torch.zeros(hidden_size)) # 对应 b_z = b_iz + b_hz
# 候选隐状态的两个独立偏置项
self.bias_in = nn.Parameter(torch.zeros(hidden_size)) # 对应 b_in, 不被重置门门控
self.bias_hn = nn.Parameter(torch.zeros(hidden_size)) # 对应 b_hn, 被重置门门控
def forward(self, x_t, h_prev):
"""
前向传播计算单个时间步。
参数:
x_t: 当前时间步的输入, shape (batch_size, input_size)
h_prev: 前一时间步的隐状态, shape (batch_size, hidden_size)
返回:
h_t: 当前时间步计算出的新隐状态, shape (batch_size, hidden_size)
"""
# --- 步骤 1: 计算重置门 ---
# 公式: R_t = σ(X_t * W_xr + H_{t-1} * W_hr + b_r)
r_t = torch.sigmoid(self.linear_xr(x_t) + self.linear_hr(h_prev) + self.bias_r)
# --- 步骤 2: 计算更新门 ---
# 公式: Z_t = σ(X_t * W_xz + H_{t-1} * W_hz + b_z)
z_t = torch.sigmoid(self.linear_xz(x_t) + self.linear_hz(h_prev) + self.bias_z)
# --- 步骤 3: 计算候选隐状态 ---
# 公式: H_tilde = tanh(X_t * W_xh + R_t ⊙ (H_{t-1} * W_hh + b_hn) + b_in)
# 这里的 * 运算符实现了哈达玛积 (⊙)
h_tilde = torch.tanh(
self.linear_xh(x_t) +
r_t * (self.linear_hh(h_prev) + self.bias_hn) +
self.bias_in
)
# --- 步骤 4: 计算最终的隐状态 ---
# 公式: H_t = Z_t ⊙ H_{t-1} + (1 - Z_t) ⊙ H_tilde
h_t = z_t * h_prev + (1 - z_t) * h_tilde
return h_t
运行并验证我们的"手搓"GRU
理论讲完了,我们来做个小实验,验证一下我们写的 GRUCell_FromScratch
和PyTorch官方的 nn.GRUCell
是不是等价的。
python
# --- 实验验证部分 ---
# 定义超参数
batch_size = 4
input_size = 10
hidden_size = 20
# 创建随机的输入数据
x_t = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)
# 实例化我们实现的GRU单元和PyTorch官方的GRU单元
gru_scratch = GRUCell_FromScratch(input_size, hidden_size)
gru_official = nn.GRUCell(input_size, hidden_size)
# --- 将官方模型的参数复制到我们实现的模型中 ---
# PyTorch将所有输入权重和偏置堆叠在 weight_ih 和 bias_ih 中
# 将所有隐藏状态权重和偏置堆叠在 weight_hh 和 bias_hh 中
# 复制权重
gru_scratch.linear_xr.weight.data.copy_(gru_official.weight_ih[:hidden_size, :])
gru_scratch.linear_hr.weight.data.copy_(gru_official.weight_hh[:hidden_size, :])
gru_scratch.linear_xz.weight.data.copy_(gru_official.weight_ih[hidden_size:hidden_size * 2, :])
gru_scratch.linear_hz.weight.data.copy_(gru_official.weight_hh[hidden_size:hidden_size * 2, :])
gru_scratch.linear_xh.weight.data.copy_(gru_official.weight_ih[hidden_size * 2:, :])
gru_scratch.linear_hh.weight.data.copy_(gru_official.weight_hh[hidden_size * 2:, :])
# 复制偏置
gru_scratch.bias_r.data.copy_(gru_official.bias_ih[:hidden_size] + gru_official.bias_hh[:hidden_size])
gru_scratch.bias_z.data.copy_(gru_official.bias_ih[hidden_size:hidden_size * 2] + gru_official.bias_hh[hidden_size:hidden_size * 2])
gru_scratch.bias_in.data.copy_(gru_official.bias_ih[hidden_size * 2:])
gru_scratch.bias_hn.data.copy_(gru_official.bias_hh[hidden_size * 2:])
# 分别用两个模型进行前向传播
h_t_scratch = gru_scratch(x_t, h_prev)
h_t_official = gru_official(x_t, h_prev)
# 比较结果
print("我们手搓的GRU输出:", h_t_scratch)
print("PyTorch官方GRU输出:", h_t_official)
# 检查两个输出是否在数值上几乎相同
print("\n两个输出是否几乎相同?", torch.allclose(h_t_scratch, h_t_official))
当你运行这段代码,如果最后打印出 True
,那么恭喜你!你已经成功地用代码复现了GRU的核心数学原理。这证明了你对GRU的理解已经深入到了"像素级"。以下是我运行的结果:
python
我们手搓的GRU输出: tensor([[-0.8295, -0.2595, 0.9684, 0.2140, -0.5785, -0.1933, 0.5778, 0.3714,
0.5419, -1.2080, 0.5973, 0.6886, -0.7297, -0.1810, -0.3550, 0.1375,
0.3358, -0.2605, -0.4440, 0.6498],
[ 0.3155, 0.0616, -0.5738, -0.1972, -0.0984, 0.0601, 0.3601, 0.1683,
-0.4179, 0.4705, 0.4867, -0.5043, 1.2716, 0.0027, -0.4619, -0.3631,
-0.4136, -0.6153, -0.3496, -0.8575],
[-0.1301, -0.4527, -0.3129, 0.2685, -0.3576, -0.3155, -0.4003, 0.4550,
-0.3802, 0.3482, 0.8009, 0.1505, 0.2446, 0.0780, 0.4634, -0.1107,
0.2131, 0.3837, -0.4669, 0.0181],
[-0.0648, 0.0902, -0.0132, 0.0585, -0.1076, -0.5664, 0.1125, -0.1067,
-0.0702, -0.5483, 0.5603, 0.2239, 0.0498, 0.8238, -0.0751, -0.4099,
0.1920, 0.5400, -0.1944, 0.4914]], grad_fn=<AddBackward0>)
PyTorch官方GRU输出: tensor([[-0.8295, -0.2595, 0.9684, 0.2140, -0.5785, -0.1933, 0.5778, 0.3714,
0.5419, -1.2080, 0.5973, 0.6886, -0.7297, -0.1810, -0.3550, 0.1375,
0.3358, -0.2605, -0.4440, 0.6498],
[ 0.3155, 0.0616, -0.5738, -0.1972, -0.0984, 0.0601, 0.3601, 0.1683,
-0.4179, 0.4705, 0.4867, -0.5043, 1.2716, 0.0027, -0.4619, -0.3631,
-0.4136, -0.6153, -0.3496, -0.8575],
[-0.1301, -0.4527, -0.3129, 0.2685, -0.3576, -0.3155, -0.4003, 0.4550,
-0.3802, 0.3482, 0.8009, 0.1505, 0.2446, 0.0780, 0.4634, -0.1107,
0.2131, 0.3837, -0.4669, 0.0181],
[-0.0648, 0.0902, -0.0132, 0.0585, -0.1076, -0.5664, 0.1125, -0.1067,
-0.0702, -0.5483, 0.5603, 0.2239, 0.0498, 0.8238, -0.0751, -0.4099,
0.1920, 0.5400, -0.1944, 0.4914]], grad_fn=<AddBackward0>)
两个输出是否几乎相同? True
回归项目,融会贯通
现在,我们再回头看项目代码 src/model.py
:
python
# src/model.py
class GRUModel(nn.Module):
# ...
def forward(self, x, hidden):
embedded = self.embedding(x)
out, hidden = self.gru(embedded, hidden) # <--- 就是这里!
out = out.contiguous().view(-1, self.hidden_dim)
out = self.fc(out)
return out, hidden
这里的 self.gru
就是我们上面验证的 nn.GRUCell
的"循环版本"。nn.GRU
会在内部自动地遍历输入序列的每一个时间步,反复调用 GRUCell
的计算逻辑,并把每一步的隐状态传递给下一步。
-
self.embedding
:将你的字符ID转换为密集向量。 -
self.gru
:处理整个序列,输出每个时间步的隐状态out
和最后一个时间步的隐状态hidden
。 -
self.fc
:这就是【循环神经网络3】门控循环单元GRU详解-CSDN博客笔记 2.4 节 的最终输出层O_t = H_t * W_hq + b_q
,它将GRU的输出映射到词汇表大小,用于预测下一个字符的概率。
如果读者还想学习LSTM模型实战,可移步至⬇️: