CNN记录】pytorch中flatten函数

pytorch原型

python 复制代码
torch.flatten(input, start_dim=0, end_dim=- 1)

作用:将连续的维度范围展平维张量,一般写再某个nn后用于对输出处理,

参数:

start_dim:开始的维度

end_dim:终止的维度,-1为最后一个轴

默认值时展平为1维

例子

1、默认参数

python 复制代码
input = torch.randn(2, 3, 4, 5)
output = torch.flatten(input)
输出维:torch.Size([120])

2、设置参数

python 复制代码
input = torch.randn(2, 3, 4, 5)

output = torch.flatten(input,1)
输出shape为:torch.Size([2, 60])

output = torch.flatten(input,1,2)
输出shape为:torch.Size([2, 12, 5])
相关推荐
兔子不吃草~3 小时前
Transformer学习记录与CNN思考
学习·cnn·transformer
蒋星熠5 小时前
如何在Anaconda中配置你的CUDA & Pytorch & cuNN环境(2025最新教程)
开发语言·人工智能·pytorch·python·深度学习·机器学习·ai
weiwei228446 小时前
Torch核心数据结构Tensor(张量)
pytorch·tensor
wL魔法师13 小时前
【LLM】大模型训练中的稳定性问题
人工智能·pytorch·深度学习·llm
技术小黑19 小时前
Transformer系列 | Pytorch复现Transformer
pytorch·深度学习·transformer
DogDaoDao20 小时前
神经网络稀疏化设计构架方法和原理深度解析
人工智能·pytorch·深度学习·神经网络·大模型·剪枝·网络稀疏
西猫雷婶21 小时前
pytorch基本运算-Python控制流梯度运算
人工智能·pytorch·python·深度学习·神经网络·机器学习
似乎很简单1 天前
卷积神经网络(CNN)
深度学习·神经网络·cnn
ACEEE12221 天前
Stanford CS336 | Assignment 2 - FlashAttention-v2 Pytorch & Triotn实现
人工智能·pytorch·python·深度学习·机器学习·nlp·transformer
深耕AI2 天前
【PyTorch训练】准确率计算(代码片段拆解)
人工智能·pytorch·python