传统RNN模型笔记:输入数据长度变化的结构解析

一、案例背景

本案例通过PyTorch的nn.RNN构建单隐藏层RNN模型,重点展示RNN对变长序列数据的处理能力(序列长度从1变为20),帮助理解RNN的输入输出逻辑。

二、核心代码与结构拆解

python 复制代码
def dm_rnn_for_sequencelen():
    # 1. 定义RNN模型
    rnn = nn.RNN(5, 6, 1)  # input_size=5, hidden_size=6, num_layers=1
    
    # 2. 准备输入数据
    input = torch.randn(20, 3, 5)  # 序列长度=20,批次大小=3,输入维度=5
    
    # 3. 初始化隐状态
    h0 = torch.randn(1, 3, 6)  # 层数×方向=1,批次大小=3,隐藏层维度=6
    
    # 4. 前向传播
    output, hn = rnn(input, h0)
    
    # 输出结果
    print('output形状--->', output.shape)  # torch.Size([20, 3, 6])
    print('hn形状--->', hn.shape)          # torch.Size([1, 3, 6])
    print('模型结构--->', rnn)             # RNN(5, 6)

三、关键参数详解

1. 模型定义参数(nn.RNN

参数 含义 本案例取值 说明
input_size 输入特征维度 5 每个时间步的输入向量维度(如单词的 embedding 维度)
hidden_size 隐藏层输出维度 6 每个时间步的隐状态向量维度
num_layers 隐藏层层数 1 单隐藏层结构,简化计算

2. 输入数据格式(input

  • 形状:[sequence_length, batch_size, input_size]
  • 本案例:[20, 3, 5]
    • 20序列长度(sequence_length),每个样本包含20个时间步(如一句话有20个单词);
    • 3批次大小(batch_size),一次并行处理3个样本;
    • 5输入特征维度 ,与模型定义的input_size一致。

3. 初始隐状态(h0

  • 形状:[num_layers × num_directions, batch_size, hidden_size]
  • 本案例:[1, 3, 6]
    • 1num_layers × num_directions(1层+单向RNN);
    • 3:与输入的batch_size一致,每个样本对应一个初始隐状态;
    • 6:与模型定义的hidden_size一致,初始隐状态的维度。

四、输出结果解析

1. output(所有时间步的隐藏层输出)

  • 形状:[sequence_length, batch_size, hidden_size]
  • 本案例:[20, 3, 6]
    • 包含每个时间步、每个样本的隐藏层输出(20个时间步×3个样本×6维向量);
    • 体现RNN对序列的"逐步处理"特性,保留所有中间结果。

2. hn(最后一个时间步的隐状态)

  • 形状:[num_layers × num_directions, batch_size, hidden_size]
  • 本案例:[1, 3, 6]
    • 仅包含最后一个时间步(第20步)、每个样本的隐状态;
    • 因单隐藏层,hnoutput的最后一个时间步结果完全一致。

五、核心结论:RNN对变长序列的适应性

  • 序列长度可灵活变化 :只要输入特征维度(input_size)和批次大小(batch_size)不变,RNN可处理任意长度的序列(如示例1中长度=1,本案例中长度=20)。
  • 输出形状随序列长度调整output的第一个维度始终等于输入序列长度,体现RNN对时序数据的动态处理能力。

六、类比理解

将RNN比作"逐字阅读的处理器":

  • 输入:3篇文章(batch_size=3),每篇20个单词(sequence_length=20),每个单词用5维向量表示(input_size=5);
  • 处理过程:每读一个单词(时间步),结合上一步的记忆(隐状态),更新当前记忆(6维向量,hidden_size=6);
  • 输出:output是每读一个单词时的记忆记录,hn是读完最后一个单词的最终记忆。
相关推荐
墨北小七25 分钟前
小说大模型---全连接神经网络-大模型中真正的“守门人”
深度学习·神经网络
CheerWWW33 分钟前
C++学习笔记——栈内存与堆内存、宏、auto、std::array
c++·笔记·学习
SLAM必须dunk41 分钟前
四足强化入门3---Robot Lab重点机器人配置,训练和调参
人工智能·深度学习·机器学习·机器人
shy^-^cky42 分钟前
[特殊字符] Roberts、Sobel、Prewitt 边缘检测算子全对比
深度学习·图像分割·边缘检测·sobel·roberts·边缘检测算子·prewitt
AI医影跨模态组学42 分钟前
ESMO Open 中国医学科学院肿瘤医院:整合影像组学、病理组学和活检适应性免疫评分预测局部晚期直肠癌远处转移
人工智能·深度学习·机器学习·论文·医学·医学影像
-许平安-1 小时前
MCP项目笔记十(客户端 MCPClient)
c++·笔记·ai·raii·mcp·pluginapi·plugin system
一只旭宝1 小时前
【C++ 入门精讲2】函数重载、默认参数、函数指针、volatile | 手写笔记(附完整代码)
c++·笔记
jay神1 小时前
大米杂质检测数据集(YOLO格式)
人工智能·深度学习·yolo·目标检测·毕业设计
John.Lewis1 小时前
C++进阶(8)智能指针
开发语言·c++·笔记
薛定e的猫咪2 小时前
【Neural Networks 2025】TDAG 论文解读:多智能体不是重点,动态任务分解才是关键
人工智能·深度学习·计算机视觉