PyTorch-----torch.flatten()函数

torch.flatten() 是 PyTorch 中的一个函数,用于将输入张量展平为一维张量。它的语法如下:

python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)
  • input:要展平的输入张量。
  • start_dim(可选):指定从哪个维度开始展平。默认为 0。
  • end_dim(可选):指定从哪个维度结束展平。默认为 -1,表示最后一个维度。

torch.flatten() 函数会将输入张量的指定维度范围内的所有元素展平到一个一维张量中。展平后的张量保持与原始张量相同的数据顺序。例如,如果输入张量是一个 3x4x5 的三维张量,然后你使用 torch.flatten() 函数将它展平,那么结果将是一个包含 60 个元素的一维张量,其中包含原始张量中所有的元素。

以下是一个示例:

python 复制代码
import torch

# 创建一个3x4x5的张量
input_tensor = torch.randn(3, 4, 5)

# 使用torch.flatten()将其展平为一维张量
output_tensor = torch.flatten(input_tensor)

print(output_tensor.size())  # 输出 torch.Size([60])

在此示例中,input_tensor 是一个形状为 (3, 4, 5) 的三维张量,使用 torch.flatten() 函数将其展平为一个一维张量,并打印出了结果张量的大小。

示例:

python 复制代码
import torch

# 创建一个2×3x5x5的张量
input_tensor = torch.randn(2, 3, 5, 5)
print(f"原张量的尺寸为:{input_tensor.size()}") # torch.Size([2, 3, 5, 5])

# 使用torch.flatten()从第一个维度开始展平,从第二个维度结束展平
output_tensor = torch.flatten(input_tensor, start_dim=1, end_dim=2)
print(f"经过展平后的张量的尺寸为:{output_tensor.size()}")  # torch.Size([2, 15, 5])
相关推荐
tankeven3 分钟前
HJ176 【模板】滑动窗口
c++·算法
慧知AI21 分钟前
Kimi 2.6 技术深度解析:5秒响应背后的架构突破
人工智能
卷卷说风控31 分钟前
单独一个工具再强,不如一套工具链协同|卷卷养虾记 · 十二篇
人工智能
黑金IT37 分钟前
vLLM本地缓存实战,重复提交直接复用不浪费算力
人工智能·缓存
七七powerful39 分钟前
运维养龙虾--Tmux 终端复用器完全指南:从入门到 AI Agent 远程操控
运维·服务器·人工智能
网域小星球43 分钟前
C 语言从 0 入门(十二)|指针与数组:数组名本质、指针遍历数组
c语言·算法·指针·数组·指针遍历数组
七夜zippoe43 分钟前
OpenClaw 飞书深度集成:文档操作
人工智能·飞书·集成·文档·openclaw
databook44 分钟前
从写代码到问问题:2026年,AI如何重构数据科学工作流
人工智能·后端·数据分析
深山技术宅1 小时前
OpenClaw 系统架构深度解析
人工智能·ai·系统架构·openclaw
skilllite作者1 小时前
AI 自进化系统架构详解 (一):重新定义 L1-L3 等级,揭秘 OpenClaw 背后的安全边界
人工智能·安全·系统架构