CNN记录】pytorch中flatten函数

pytorch原型

python 复制代码
torch.flatten(input, start_dim=0, end_dim=- 1)

作用:将连续的维度范围展平维张量,一般写再某个nn后用于对输出处理,

参数:

start_dim:开始的维度

end_dim:终止的维度,-1为最后一个轴

默认值时展平为1维

例子

1、默认参数

python 复制代码
input = torch.randn(2, 3, 4, 5)
output = torch.flatten(input)
输出维:torch.Size([120])

2、设置参数

python 复制代码
input = torch.randn(2, 3, 4, 5)

output = torch.flatten(input,1)
输出shape为:torch.Size([2, 60])

output = torch.flatten(input,1,2)
输出shape为:torch.Size([2, 12, 5])
相关推荐
老毛肚4 小时前
卷积神经网络CNN
人工智能·深度学习·cnn
一个小猴子`7 小时前
Pytorch快速复习
人工智能·pytorch·python
zh路西法12 小时前
【Qwen2.5本地部署】超简单pytorch-gpu部署教程
人工智能·pytorch·python
凯瑟琳.奥古斯特13 小时前
PyTorch动态计算图详解
人工智能·pytorch·python·深度学习
dfsj6601113 小时前
第五章:卷积神经网络
人工智能·神经网络·cnn
AI街潜水的八角15 小时前
PyTorch框架——基于深度学习SRN-DeblurNet神经网络AI去模糊图像增强系统
人工智能·pytorch·深度学习
lsjweiyi16 小时前
WSL2 + ROCm + PyTorch 深度学习环境配置全记录
人工智能·pytorch·深度学习
吾辈亦有感16 小时前
【动手学大语言模型】神经网络启蒙:PyTorch 入门实战
人工智能·pytorch·大语言模型
张二娃同学16 小时前
第03篇_CNN图像识别入门
人工智能·python·神经网络·cnn
凯瑟琳.奥古斯特16 小时前
深度学习入门:用PyTorch实现MNIST手写数字识别
pytorch·python·深度学习