Pytorch torch.nn.utils.rnn.pad_sequence 介绍

torch.nn.utils.rnn.pad_sequence 是 PyTorch 中一个用于填充序列的实用函数,它主要用于处理长度不一的序列数据,将这些序列填充到相同的长度,以便能将它们组合成一个批量(batch)输入到神经网络中。以下是详细介绍:

函数定义

复制代码
torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0)

参数解释

  • sequences :这是一个必需的参数,是一个由 torch.Tensor 组成的列表,列表中的每个 Tensor 代表一个序列。这些序列的长度可以不同,但其他维度的大小必须一致。
  • batch_first :这是一个布尔类型的可选参数,默认值为 False。当 batch_firstFalse 时,输出的 Tensor 的形状为 (max_seq_length, batch_size, ...);当 batch_firstTrue 时,输出的 Tensor 的形状为 (batch_size, max_seq_length, ...)
  • padding_value :这是一个可选参数,默认值为 0.0。它指定了用于填充序列的数值。

返回值

返回一个填充后的 torch.Tensor,其形状根据 batch_first 参数的值而定。

使用场景

在自然语言处理(NLP)、语音识别等领域,输入的序列数据(如句子、语音片段)长度通常是不一致的。在将这些数据输入到神经网络之前,需要将它们填充到相同的长度,以便进行批量处理。torch.nn.utils.rnn.pad_sequence 就是为解决这个问题而设计的。

示例代码

复制代码
import torch
from torch.nn.utils.rnn import pad_sequence

# 创建长度不同的序列
seq1 = torch.tensor([1, 2, 3])
seq2 = torch.tensor([4, 5])
seq3 = torch.tensor([6])

# 将序列放入列表中
sequences = [seq1, seq2, seq3]

# 填充序列,batch_first 为 False
padded_seq_false = pad_sequence(sequences, batch_first=False, padding_value=0)
print("batch_first=False 时的填充结果:")
print(padded_seq_false)
print("形状:", padded_seq_false.shape)

# 填充序列,batch_first 为 True
padded_seq_true = pad_sequence(sequences, batch_first=True, padding_value=0)
print("batch_first=True 时的填充结果:")
print(padded_seq_true)
print("形状:", padded_seq_true.shape)

在这个示例中,我们创建了三个长度不同的序列,然后使用 pad_sequence 函数将它们填充到相同的长度。通过设置 batch_first 参数为 FalseTrue,我们可以看到输出的 Tensor 形状的变化。

通过使用 torch.nn.utils.rnn.pad_sequence 函数,你可以方便地处理长度不一致的序列数据,将它们填充到相同的长度,以便进行批量处理。

相关推荐
我的xiaodoujiao13 小时前
使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 38--Allure 测试报告
python·学习·测试工具·pytest
小鸡吃米…18 小时前
机器学习 - K - 中心聚类
人工智能·机器学习·聚类
好奇龙猫19 小时前
【AI学习-comfyUI学习-第三十节-第三十一节-FLUX-SD放大工作流+FLUX图生图工作流-各个部分学习】
人工智能·学习
沈浩(种子思维作者)19 小时前
真的能精准医疗吗?癌症能提前发现吗?
人工智能·python·网络安全·健康医疗·量子计算
minhuan19 小时前
大模型应用:大模型越大越好?模型参数量与效果的边际效益分析.51
人工智能·大模型参数评估·边际效益分析·大模型参数选择
Cherry的跨界思维19 小时前
28、AI测试环境搭建与全栈工具实战:从本地到云平台的完整指南
java·人工智能·vue3·ai测试·ai全栈·测试全栈·ai测试全栈
MM_MS19 小时前
Halcon变量控制类型、数据类型转换、字符串格式化、元组操作
开发语言·人工智能·深度学习·算法·目标检测·计算机视觉·视觉检测
ASF1231415sd20 小时前
【基于YOLOv10n-CSP-PTB的大豆花朵检测与识别系统详解】
人工智能·yolo·目标跟踪
njsgcs20 小时前
ue python二次开发启动教程+ 导入fbx到指定文件夹
开发语言·python·unreal engine·ue
io_T_T20 小时前
迭代器 iteration、iter 与 多线程 concurrent 交叉实践(详细)
python