CNN记录】pytorch中flatten函数

pytorch原型

python 复制代码
torch.flatten(input, start_dim=0, end_dim=- 1)

作用:将连续的维度范围展平维张量,一般写再某个nn后用于对输出处理,

参数:

start_dim:开始的维度

end_dim:终止的维度,-1为最后一个轴

默认值时展平为1维

例子

1、默认参数

python 复制代码
input = torch.randn(2, 3, 4, 5)
output = torch.flatten(input)
输出维:torch.Size([120])

2、设置参数

python 复制代码
input = torch.randn(2, 3, 4, 5)

output = torch.flatten(input,1)
输出shape为:torch.Size([2, 60])

output = torch.flatten(input,1,2)
输出shape为:torch.Size([2, 12, 5])
相关推荐
studytosky26 分钟前
深度学习理论与实战:Pytorch基础入门
人工智能·pytorch·python·深度学习·机器学习
好多渔鱼好多1 小时前
【AI大模型】PyTorch 介绍
pytorch
袁气满满~_~1 小时前
Ubuntu下配置PyTorch
linux·pytorch·ubuntu
ins_lizhiming2 小时前
在华为910B GPU服务器上运行DeepSeek-R1-0528模型
人工智能·pytorch·python·华为
吃个糖糖5 小时前
pytorch 卷积操作
人工智能·pytorch·python
Dr.Kun8 小时前
【鲲码园Python】基于pytorch的蘑菇分类系统(9类)
pytorch·python·分类
Yongqiang Cheng8 小时前
Gradient Accumulation (梯度累积 / 梯度累加) in PyTorch
pytorch·梯度累积·gradient·accumulation·梯度累加
老鱼说AI8 小时前
PyTorch 深度强化学习实战:从零手写 PPO 算法训练你的月球着陆器智能体
人工智能·pytorch·深度学习·机器学习·计算机视觉·分类·回归
西猫雷婶9 小时前
CNN全连接层
人工智能·pytorch·python·深度学习·神经网络·机器学习·cnn
vvoennvv10 小时前
【Python TensorFlow】CNN-LSTM时序预测 卷积神经网络-长短期记忆神经网络组合模型时序预测算法(附代码)
python·神经网络·cnn·tensorflow·lstm