CNN记录】pytorch中flatten函数

pytorch原型

python 复制代码
torch.flatten(input, start_dim=0, end_dim=- 1)

作用:将连续的维度范围展平维张量,一般写再某个nn后用于对输出处理,

参数:

start_dim:开始的维度

end_dim:终止的维度,-1为最后一个轴

默认值时展平为1维

例子

1、默认参数

python 复制代码
input = torch.randn(2, 3, 4, 5)
output = torch.flatten(input)
输出维:torch.Size([120])

2、设置参数

python 复制代码
input = torch.randn(2, 3, 4, 5)

output = torch.flatten(input,1)
输出shape为:torch.Size([2, 60])

output = torch.flatten(input,1,2)
输出shape为:torch.Size([2, 12, 5])
相关推荐
shangjian0072 小时前
AI大模型-深度学习-卷积神经网络CNN
人工智能·神经网络·cnn
Pyeako3 小时前
深度学习--PyTorch框架&优化器&激活函数
人工智能·pytorch·python·深度学习·优化器·激活函数·梯度爆炸与消失
Caesar Zou4 小时前
torchcodec is not available问题
人工智能·pytorch·深度学习·神经网络
翱翔的苍鹰4 小时前
循环神经网络-RNN和简单的例子
人工智能·pytorch·rnn·深度学习·神经网络·transformer·word2vec
爱吃肉的鹏4 小时前
树莓派4B安装pytorch
人工智能·pytorch·python
技术小黑4 小时前
TensorFlow学习系列03 | 实现天气识别
人工智能·cnn·tensorflow
MistaCloud4 小时前
Pytorch进阶训练技巧(二)之梯度层面的优化策略
人工智能·pytorch·python·深度学习
啊阿狸不会拉杆5 小时前
《机器学习》第 8 章 - 常用深度网络模型
网络·人工智能·深度学习·机器学习·ai·cnn·ml
毕不了业的硏䆒僧5 小时前
NVIDIA DGX Spark | Ubuntu cuda13.0安装Pytorch GPU版本
pytorch·ubuntu·spark
shangjian0076 小时前
AI大模型-深度学习-卷积神经网络-残差网络
人工智能·深度学习·cnn