RNN interface parameters in PyTorch

A detailed walkthrough of the torch.nn.RNN module's interface parameters

rnn = torch.nn.RNN(
    input_size: int,
    hidden_size: int,
    num_layers: int = 1,
    nonlinearity: str = 'tanh',
    bias: bool = True,
    batch_first: bool = False,
    dropout: float = 0.0,
    bidirectional: bool = False,
)

input_size (int): Feature dimension of each time step in the input sequence.

hidden_size (int): Dimension of the hidden state (the network's "memory").

num_layers (int, default 1): Number of stacked RNN layers.

nonlinearity (str, default 'tanh'): Activation function, either 'tanh' or 'relu'. A standard RNN normally uses 'tanh'.

bias (bool, default True): Whether to include the bias terms in the computation.

batch_first (bool, default False): If True, the first dimension of the input and output tensors is the batch size rather than the sequence length, i.e. the layout is (batch_size, seq_len, input_size) instead of (seq_len, batch_size, input_size).

dropout (float, default 0.0): Dropout probability applied between stacked recurrent layers (on the outputs of every layer except the last), used as regularization against overfitting. It only takes effect when num_layers > 1.

bidirectional (bool, default False): If True, creates a bidirectional RNN, so the model can use both past and future context.

Note: PyTorch has no torch.nn.Bidirectional wrapper (that is a Keras concept); torch.nn.RNN supports bidirectionality directly through the bidirectional=True argument listed above.
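To make the parameters above concrete, here is a minimal sketch (the sizes are arbitrary, chosen only for illustration) that stacks two layers, enables inter-layer dropout, and turns on bidirectionality:

import torch
import torch.nn as nn

# 2-layer bidirectional RNN with dropout applied between the stacked layers
rnn = nn.RNN(input_size=5, hidden_size=3, num_layers=2,
             nonlinearity="tanh", dropout=0.2, bidirectional=True, batch_first=True)

x = torch.randn(4, 7, 5)  # (batch_size, seq_len, input_size)
output, hn = rnn(x)       # h0 omitted, so it defaults to zeros

print(output.shape)  # torch.Size([4, 7, 6])  -> (batch, seq_len, num_directions * hidden_size)
print(hn.shape)      # torch.Size([4, 4, 3])  -> (num_layers * num_directions, batch, hidden_size)

If dropout is non-zero while num_layers == 1, recent PyTorch versions warn that the setting has no effect.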

outputs, hn = rnn(...)

outputs: Tensor. If batch_first=True, the shape is (batch_size, seq_len, num_directions * hidden_size); otherwise it is (seq_len, batch_size, num_directions * hidden_size).

The RNN's output at every time step of the input sequence, taken from the last layer. For a bidirectional RNN, num_directions is 2 and the output is the concatenation of the forward and backward hidden states along the feature dimension.

hn: Tensor (h_n, the final hidden state): shape (num_layers * num_directions, batch_size, hidden_size).

The hidden state at the last time step (in the bidirectional case, the final forward and backward hidden states). For a single-layer, unidirectional RNN with batch_first=True this is equivalent to output[:, -1, :]; the sketch below shows how the pieces line up in the bidirectional case.
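In a bidirectional RNN, the forward and backward final states sit at opposite ends of the sequence. A small sketch (shapes chosen arbitrarily) of how h_n lines up with output:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=5, hidden_size=3, num_layers=1, bidirectional=True, batch_first=True)
x = torch.randn(2, 7, 5)
output, hn = rnn(x)  # output: (2, 7, 6), hn: (2, 2, 3)

# Forward direction: its final state is the last time step of the first hidden_size channels.
print(torch.allclose(output[:, -1, :3], hn[0]))  # True
# Backward direction: it reads the sequence in reverse, so its final state sits at t = 0.
print(torch.allclose(output[:, 0, 3:], hn[1]))   # True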

Instantiating a unidirectional RNN

import torch.nn as nn
import torch

batch_size = 2
seq_len = 7
input_size = 5
hidden_size = 3
num_layers = 1

rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity="tanh", batch_first=True)

# (batch_size, seq_len, input_size): a batch of 2 samples, each with 7 time steps, each step a 5-dimensional feature vector
input = torch.randn(batch_size, seq_len, input_size)

# (num_layers, batch_size, hidden_size); if h0 is omitted, PyTorch defaults it to all zeros, which is usually recommended
h0 = torch.randn(num_layers, batch_size, hidden_size)

# output: (batch_size, seq_len, hidden_size)
# hn: (num_layers, batch_size, hidden_size)
# hn holds each sample's last-step hidden state, equivalent here to output[:, -1, :]
# !! Note: the shapes above are worked out for this specific example, not general formulas !!
output, hn = rnn(input, h0)
# print(output)
print(output.shape)  # torch.Size([2, 7, 3])
# print(hn)
print(hn.shape)  # torch.Size([1, 2, 3])
# print(output[:, -1, :])
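As a quick check of the comment above, the final hidden state can be compared against the last time step of output (this holds here because the RNN is single-layer and unidirectional):

print(torch.allclose(hn[0], output[:, -1, :]))  # True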

For ease of review, the parameter documentation from the PyTorch source is attached here:

class RNN(RNNBase):
r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
input sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::
    h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
previous layer at time `t-1` or the initial hidden state at time `0`.
If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

Args:
    input_size: The number of expected features in the input `x`
    hidden_size: The number of features in the hidden state `h`
    num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
        would mean stacking two RNNs together to form a `stacked RNN`,
        with the second RNN taking in outputs of the first RNN and
        computing the final results. Default: 1
    nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
        Default: ``True``
    batch_first: If ``True``, then the input and output tensors are provided
        as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
        Note that this does not apply to hidden or cell states. See the
        Inputs/Outputs sections below for details.  Default: ``False``
    dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
        RNN layer except the last layer, with dropout probability equal to
        :attr:`dropout`. Default: 0
    bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

Inputs: input, h_0
    * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
      :math:`(L, N, H_{in})` when ``batch_first=False`` or
      :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
      the input sequence.  The input can also be a packed variable length sequence.
      See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
      :func:`torch.nn.utils.rnn.pack_sequence` for details.
    * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
      :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
      state for the input sequence batch. Defaults to zeros if not provided.

    where:

    .. math::
        \begin{aligned}
            N ={} & \text{batch size} \\
            L ={} & \text{sequence length} \\
            D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
            H_{in} ={} & \text{input\_size} \\
            H_{out} ={} & \text{hidden\_size}
        \end{aligned}

Outputs: output, h_n
    * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
      :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
      :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
      `(h_t)` from the last layer of the RNN, for each `t`. If a
      :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
      will also be a packed sequence.
    * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
      :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
      for each element in the batch.

See the torch.nn.RNN source code documentation for reference.
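The recurrence in the docstring can be checked by hand. Below is a minimal sketch (sizes arbitrary) that unrolls h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh) using the module's own weight_ih_l0 / weight_hh_l0 / bias_ih_l0 / bias_hh_l0 parameters and compares the result to what nn.RNN returns:

import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 5, 3
rnn = nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity="tanh", batch_first=True)

x = torch.randn(1, 4, input_size)   # (batch, seq_len, input_size)
output, hn = rnn(x)                 # h0 defaults to zeros

# Manually unroll the recurrence over every time step, starting from a zero hidden state.
h = torch.zeros(1, hidden_size)
for t in range(x.shape[1]):
    h = torch.tanh(x[:, t, :] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                   + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)

print(torch.allclose(h, hn[0], atol=1e-6))              # True: manual last step matches h_n
print(torch.allclose(h, output[:, -1, :], atol=1e-6))   # True for this single-layer setup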
