PyTorch-----torch.flatten()函数

torch.flatten() 是 PyTorch 中的一个函数,用于将输入张量展平为一维张量。它的语法如下:

python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)
  • input:要展平的输入张量。
  • start_dim(可选):指定从哪个维度开始展平。默认为 0。
  • end_dim(可选):指定从哪个维度结束展平。默认为 -1,表示最后一个维度。

torch.flatten() 函数会将输入张量的指定维度范围内的所有元素展平到一个一维张量中。展平后的张量保持与原始张量相同的数据顺序。例如,如果输入张量是一个 3x4x5 的三维张量,然后你使用 torch.flatten() 函数将它展平,那么结果将是一个包含 60 个元素的一维张量,其中包含原始张量中所有的元素。

以下是一个示例:

python 复制代码
import torch

# 创建一个3x4x5的张量
input_tensor = torch.randn(3, 4, 5)

# 使用torch.flatten()将其展平为一维张量
output_tensor = torch.flatten(input_tensor)

print(output_tensor.size())  # 输出 torch.Size([60])

在此示例中,input_tensor 是一个形状为 (3, 4, 5) 的三维张量,使用 torch.flatten() 函数将其展平为一个一维张量,并打印出了结果张量的大小。

示例:

python 复制代码
import torch

# 创建一个2×3x5x5的张量
input_tensor = torch.randn(2, 3, 5, 5)
print(f"原张量的尺寸为:{input_tensor.size()}") # torch.Size([2, 3, 5, 5])

# 使用torch.flatten()从第一个维度开始展平,从第二个维度结束展平
output_tensor = torch.flatten(input_tensor, start_dim=1, end_dim=2)
print(f"经过展平后的张量的尺寸为:{output_tensor.size()}")  # torch.Size([2, 15, 5])
相关推荐
deephub1 分钟前
torch.compile 加速原理:kernel 融合与缓冲区复用
人工智能·pytorch·深度学习·神经网络
ydl11281 分钟前
解码AI大模型:从神经网络到落地应用的全景探索
人工智能·深度学习·神经网络
小程故事多_802 分钟前
Elasticsearch ES 分词与关键词匹配技术方案解析
大数据·人工智能·elasticsearch·搜索引擎·aigc
yuanyuan2o23 分钟前
【深度学习】ResNet
人工智能·深度学习
HyperAI超神经4 分钟前
覆盖天体物理/地球科学/流变学/声学等19种场景,Polymathic AI构建1.3B模型实现精确连续介质仿真
人工智能·深度学习·学习·算法·机器学习·ai编程·vllm
小陈phd11 分钟前
系统测试与落地优化:问题案例、性能调优与扩展方向
人工智能·自然语言处理
模型时代13 分钟前
伯明翰Oracle项目遭遇数据清洗难题和资源短缺困境
人工智能
大黄说说13 分钟前
TensorRTSharp 实战指南:用 C# 驱动 GPU,实现毫秒级 AI 推理
开发语言·人工智能·c#
执着25914 分钟前
力扣hot100 - 144、二叉树的前序遍历
数据结构·算法·leetcode
范纹杉想快点毕业17 分钟前
嵌入式系统架构之道:告别“意大利面条”,拥抱状态机与事件驱动
java·开发语言·c++·嵌入式硬件·算法·架构·mfc