探索flatten的其他参数用法及对报错异常进行修正

1 问题

  1. 对flatten的其他参数的用法进行进一步了解

  2. 探索torch.flatten()与torch.nn.flatten()的区别

  3. 在代码调式时,遇到报错问题,进行修正并了解原因。

2 方法

  1. flatten的中文含义为"扁平化",具体怎么理解呢?假设你的数据为1维数据,那么这个数据天然就已经扁平化了,如果是2维数据,那么扁平化就是将2维数据变为1维数据,如果是3维数据,那么就要根据你自己所选择的"扁平化程度"来进行操作,假设需要全部扁平化,那么就直接将3维数据变为1维数据,如果只需要部分扁平化,那么有一维的数据不会进行扁平操作。

    torch.flatten()方法有三个参数,分别:

    input tensor:该方法的输入

    start_dim:开始flatten的维度

    end_dim:结束flatten的维度

    torch.flatten(t, start_dim=0, end_dim=-1)

    t表示的时要展平的tensor,start_dim是开始展平的维度,end_dim是结束展平的维度,它的作用就是将输入tensor的第start_dim维到end_dim维之间的数据"拉平"成一维tensor.

  2. (1) 默认的dim不同,torch.flatten()默认的dim=0,而nn.Flatten()默认的dim=1,例如输入数据的尺寸是[3,1,4,4],经过torch.flatten()展开后的尺寸变为[48],而经过nn.Flatten()后得到的结果是[3, 16];

    (2) nn.Flatten是一个类,而torch.flatten()则是一个函数。

    对于torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。

  3. 发生异常: AttributeError

cannot assign module before Module.init() call

问题原因: 属性错误,模块不能在初始化之前赋值, 继承了父类(nn.Module)后子类重写了__init__,但是没有调用super初始化父类的构造函数。在类的初始化里面没有加上父类的初始化。

解决方法: super(XXX, self).init()

|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| class MyNet(nn.Module): ''' 定义神经网络有哪些层,不含输入层 ''' def init(self): ''' 发生异常: AttributeError cannot assign module before Module.init() call ''' super(MyNet,self).init() self.fc1 = nn.Linear( in_features=784, #前一层的神经元个数 out_features=512 ,#当前层的神经元个数 ) self.fc2 = nn.Linear( in_features=512, #前一层的神经元个数 out_features=10,#当前层的神经元个数 ) ''' 定义数据的前向传播方式 x:前面的train_dataloader的一个batch数据 ''' def forward(self,x): #x的shape是[B,C,H,W],例[32,1,28,28] #32不拉伸 #问题:探索flatten的其他参数用法 #拉伸展开成一维向量 x = torch.flatten(x,start_dim=1)#索引为1的开始拉伸 x = self.fc1(x) out = self.fc2(x) return out #人为造一些数据来测试网络的正确性 x = torch.rand(size = (32,1,28,28)) net = MyNet() #实例化对象 print(net) #问题:此处net(x)为什么可以? out = net(x) print(out.shape) #torch.Size([32,10]) device = ( 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' ) print(device) #cpu |

3 结语

针对问题一,对flatten的参数用法进一步了解,在对使用torch.flatten()与torch.nn.flatten()时,需注意俩者的差别,相对而言,对二维数据进行扁平化类nn.flatten()更简便。

针对问题二,对报错问题进行搜索,即可获取正确答案。也要对其含义及遇到类似问题可以解决。

相关推荐
ada7_16 分钟前
LeetCode(python)——543.二叉树的直径
数据结构·python·算法·leetcode·职场和发展
小白学大数据21 分钟前
Python 多线程爬取社交媒体品牌反馈数据
开发语言·python·媒体
东方不败之鸭梨的测试笔记22 分钟前
测试工程师如何利用AI大模型?
人工智能
智能化咨询27 分钟前
(68页PPT)埃森哲XX集团用户主数据治理项目汇报方案(附下载方式)
大数据·人工智能
HAPPY酷33 分钟前
压缩文件格式实战速查表 (纯文本版)
python
说私域34 分钟前
分享经济应用:以“开源链动2+1模式AI智能名片S2B2C商城小程序”为例
人工智能·小程序·开源
工业机器视觉设计和实现34 分钟前
我的第三个cudnn程序(cifar10改cifar100)
人工智能·深度学习·机器学习
熊猫钓鱼>_>38 分钟前
PyTorch深度学习框架入门浅析
人工智能·pytorch·深度学习·cnn·nlp·动态规划·微分
Altair澳汰尔1 小时前
成功案例丨仿真+AI技术为快消包装行业赋能提速:基于 AI 的轻量化设计节省数十亿美元
人工智能·ai·仿真·cae·消费品·hyperworks·轻量化设计
祝余Eleanor1 小时前
Day 31 类的定义和方法
开发语言·人工智能·python·机器学习