pytorch nn.RNN demo

之前已经讲过关于RNNCell的实现了.

这里用LLM写了一个简单的nn.RNN demo:

python 复制代码
import torch
import torch.nn as nn

# 设置随机种子以便结果可复现
torch.manual_seed(42)

# 定义模型参数
input_size = 4      # 输入特征维度
hidden_size = 8     # 隐藏层维度
num_layers = 2      # RNN 层数(修改为2层)
seq_len = 10        # 序列长度
batch_size = 3      # 批量大小

# 创建2层RNN模型
model = nn.RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=False  # 输入输出格式: [seq_len, batch_size, feature_size]
)

# 生成随机输入数据 [seq_len, batch_size, input_size]
x = torch.randn(seq_len, batch_size, input_size)
print(f"输入 x 的形状: {x.shape}  # [seq_len, batch_size, input_size]")

# 初始化隐藏状态 (可选)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
print(f"初始隐藏状态 h0 的形状: {h0.shape}  # [num_layers, batch_size, hidden_size]")

# 前向传播
output, h_n = model(x, h0)
# output: 所有时间步的最后一层隐藏状态
# h_n: 所有层的最后一个时间步的隐藏状态

print(f"\n输出结果:")
print(f"output (所有时间步的最后一层隐藏状态) 的形状: {output.shape}  # [seq_len, batch_size, hidden_size]")
print(f"h_n (所有层的最后时间步隐藏状态) 的形状: {h_n.shape}  # [num_layers, batch_size, hidden_size]")

# 验证 h_n 与 output 的关系(修正后的逻辑)
print(f"\n验证 h_n 与 output 的关系:")
# 最后一层的最后状态应等于 output 的最后时间步
assert torch.allclose(h_n[-1], output[-1]), "最后一层的最后状态应等于output的最后时间步"
print(" 最后一层的最后状态与 output 的最后时间步相等")

# 打印第一层和第二层的最后隐藏状态
print(f"\n第一层的最后隐藏状态:")
print(h_n[0, 0, :5])  # 打印第一个样本的前5个元素
print(f"\n第二层的最后隐藏状态:")
print(h_n[1, 0, :5])  # 打印第一个样本的前5个元素

可以看到,nn.RNN默认会输出两个张量:一个是最后一个时间步的所有层,一个是最后一层的所有时间步。它是不会输出"所有时间步的所有层"的。

最后再给出与RNNCell部分类似的,一个完整的训练+测试的demo:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 配置
input_size = 4
hidden_size = 16
seq_len = 6
batch_size = 8
num_classes = 2
epochs = 30

# 模型定义
class RNNClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=False)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: [seq_len, batch_size, input_size]
        output, h_n = self.rnn(x)  # h_n: [num_layers=1, batch_size, hidden_size]
        out = self.fc(h_n.squeeze(0))  # 使用最后一层的隐藏状态
        return out

# 数据生成逻辑不变
def generate_batch(batch_size, seq_len, input_size):
    x = torch.randn(seq_len, batch_size, input_size)
    last_step = x[-1]
    labels = (last_step[:, 0] > 0).long()
    return x, labels

# 初始化模型与训练配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RNNClassifier(input_size, hidden_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练
for epoch in range(epochs):
    model.train()
    x_batch, y_batch = generate_batch(batch_size, seq_len, input_size)
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 5 == 0 or epoch == 0:
        pred = logits.argmax(dim=1)
        acc = (pred == y_batch).float().mean().item()
        print(f"[Epoch {epoch+1}] Loss: {loss.item():.4f}, Acc: {acc:.2f}")

# 测试
model.eval()
with torch.no_grad():
    x_test, y_test = generate_batch(1, seq_len, input_size)
    x_test, y_test = x_test.to(device), y_test.to(device)
    pred = model(x_test).argmax(dim=1)
    print("\nTest sample:")
    print("Target label:", y_test.item())
    print("Predicted   :", pred.item())
相关推荐
不会学习的小白O^O5 小时前
神经网络----卷积层(Conv2D)
人工智能·深度学习·神经网络
cosX+sinY6 小时前
10 卷积神经网络
python·深度学习·cnn
CodeShare7 小时前
多模态统一框架:基于下一帧预测的视频化方法
深度学习·计算机视觉·多模态学习
时序之心9 小时前
ICML 2025 | 深度剖析时序 Transformer:为何有效,瓶颈何在?
人工智能·深度学习·transformer
图灵学术计算机论文辅导10 小时前
提示+掩膜+注意力=Mamba三连击,跨模态任务全面超越
论文阅读·人工智能·经验分享·科技·深度学习·考研·计算机视觉
计算机科研圈11 小时前
不靠海量数据,精准喂养大模型!上交Data Whisperer:免训练数据选择法,10%数据逼近全量效果
人工智能·深度学习·机器学习·llm·ai编程
大千AI助手11 小时前
FEVER数据集:事实验证任务的大规模基准与评估框架
人工智能·深度学习·数据集·fever·事实验证·事实抽取·虚假信息
格林威12 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现道路汽车的检测识别(C#代码,UI界面版)
人工智能·深度学习·数码相机·yolo·视觉检测
8Qi813 小时前
深度学习(鱼书)day08--误差反向传播(后三节)
人工智能·python·深度学习·神经网络
wow_DG13 小时前
【PyTorch✨】01 初识PyTorch
人工智能·pytorch·python