PyTorch-----torch.flatten()函数

torch.flatten() 是 PyTorch 中的一个函数,用于将输入张量展平为一维张量。它的语法如下:

python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)
  • input:要展平的输入张量。
  • start_dim(可选):指定从哪个维度开始展平。默认为 0。
  • end_dim(可选):指定从哪个维度结束展平。默认为 -1,表示最后一个维度。

torch.flatten() 函数会将输入张量的指定维度范围内的所有元素展平到一个一维张量中。展平后的张量保持与原始张量相同的数据顺序。例如,如果输入张量是一个 3x4x5 的三维张量,然后你使用 torch.flatten() 函数将它展平,那么结果将是一个包含 60 个元素的一维张量,其中包含原始张量中所有的元素。

以下是一个示例:

python 复制代码
import torch

# 创建一个3x4x5的张量
input_tensor = torch.randn(3, 4, 5)

# 使用torch.flatten()将其展平为一维张量
output_tensor = torch.flatten(input_tensor)

print(output_tensor.size())  # 输出 torch.Size([60])

在此示例中,input_tensor 是一个形状为 (3, 4, 5) 的三维张量,使用 torch.flatten() 函数将其展平为一个一维张量,并打印出了结果张量的大小。

示例:

python 复制代码
import torch

# 创建一个2×3x5x5的张量
input_tensor = torch.randn(2, 3, 5, 5)
print(f"原张量的尺寸为:{input_tensor.size()}") # torch.Size([2, 3, 5, 5])

# 使用torch.flatten()从第一个维度开始展平,从第二个维度结束展平
output_tensor = torch.flatten(input_tensor, start_dim=1, end_dim=2)
print(f"经过展平后的张量的尺寸为:{output_tensor.size()}")  # torch.Size([2, 15, 5])
相关推荐
多巴胺与内啡肽.24 分钟前
OpenCV进阶操作:人脸检测、微笑检测
人工智能·opencv·计算机视觉
小羊在奋斗25 分钟前
【LeetCode 热题 100】反转链表 / 回文链表 / 有序链表转换二叉搜索树 / LRU 缓存
算法·leetcode·链表
Wnq1007227 分钟前
基于 NanoDet 的工厂巡检机器人目标识别系统研究与实现
人工智能·机器学习·计算机视觉·目标跟踪·机器人·巡检机器人
一年春又来34 分钟前
AI-02a5a6.神经网络-与学习相关的技巧-批量归一化
人工智能·神经网络·学习
爱上彩虹c35 分钟前
LeetCode Hot100 (1/100)
算法·leetcode·职场和发展
kovlistudio39 分钟前
机器学习第十讲:异常值检测 → 发现身高填3米的不合理数据
人工智能·机器学习
小陈的进阶之路42 分钟前
计算机大类专业数据结构下半期实验练习题
数据结构·算法·深度优先
瑞雪兆丰年兮43 分钟前
数学实验(Matlab符号运算)
开发语言·算法·matlab·数学实验
马拉AI43 分钟前
解锁Nature发文小Tips:LSTM、CNN与Attention的创新融合之路
人工智能·cnn·lstm
sufu106544 分钟前
SpringAI更新:废弃tools方法、正式支持DeepSeek!
人工智能·后端