PyTorch-----torch.flatten()函数

torch.flatten() 是 PyTorch 中的一个函数,用于将输入张量展平为一维张量。它的语法如下:

python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)
  • input:要展平的输入张量。
  • start_dim(可选):指定从哪个维度开始展平。默认为 0。
  • end_dim(可选):指定从哪个维度结束展平。默认为 -1,表示最后一个维度。

torch.flatten() 函数会将输入张量的指定维度范围内的所有元素展平到一个一维张量中。展平后的张量保持与原始张量相同的数据顺序。例如,如果输入张量是一个 3x4x5 的三维张量,然后你使用 torch.flatten() 函数将它展平,那么结果将是一个包含 60 个元素的一维张量,其中包含原始张量中所有的元素。

以下是一个示例:

python 复制代码
import torch

# 创建一个3x4x5的张量
input_tensor = torch.randn(3, 4, 5)

# 使用torch.flatten()将其展平为一维张量
output_tensor = torch.flatten(input_tensor)

print(output_tensor.size())  # 输出 torch.Size([60])

在此示例中,input_tensor 是一个形状为 (3, 4, 5) 的三维张量,使用 torch.flatten() 函数将其展平为一个一维张量,并打印出了结果张量的大小。

示例:

python 复制代码
import torch

# 创建一个2×3x5x5的张量
input_tensor = torch.randn(2, 3, 5, 5)
print(f"原张量的尺寸为:{input_tensor.size()}") # torch.Size([2, 3, 5, 5])

# 使用torch.flatten()从第一个维度开始展平,从第二个维度结束展平
output_tensor = torch.flatten(input_tensor, start_dim=1, end_dim=2)
print(f"经过展平后的张量的尺寸为:{output_tensor.size()}")  # torch.Size([2, 15, 5])
相关推荐
梁辰兴40 分钟前
数据结构:排序
数据结构·算法·排序算法·c·插入排序·排序·交换排序
野犬寒鸦1 小时前
力扣hot100:搜索二维矩阵 II(常见误区与高效解法详解)(240)
java·数据结构·算法·leetcode·面试
hundaxxx1 小时前
自演化大语言模型的技术背景
人工智能
菜鸟得菜1 小时前
leecode kadane算法 解决数组中子数组的最大和,以及环形数组连续子数组的最大和问题
数据结构·算法·leetcode
数智顾问1 小时前
【73页PPT】美的简单高效的管理逻辑(附下载方式)
大数据·人工智能·产品运营
love530love1 小时前
【保姆级教程】阿里 Wan2.1-T2V-14B 模型本地部署全流程:从环境配置到视频生成(附避坑指南)
人工智能·windows·python·开源·大模型·github·音视频
木头左1 小时前
结合机器学习的Backtrader跨市场交易策略研究
人工智能·机器学习·kotlin
Coovally AI模型快速验证2 小时前
3D目标跟踪重磅突破!TrackAny3D实现「类别无关」统一建模,多项SOTA达成!
人工智能·yolo·机器学习·3d·目标跟踪·无人机·cocos2d
研梦非凡2 小时前
CVPR 2025|基于粗略边界框监督的3D实例分割
人工智能·计算机网络·计算机视觉·3d
MiaoChuAI2 小时前
秒出PPT vs 豆包AI PPT:实测哪款更好用?
人工智能·powerpoint