torch张量的降维与升维

文章目录


一、降维和升维

squeeze和unsqueeze是torch张量常用的降维与升维的一种方式,但这种方式只能增添或减少大小为1的维度,如下:

python 复制代码
x1 = torch.randn(1, 8, 256, 256)
x1 = torch.squeeze(x1,dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 1, 256, 256)
x2 = torch.squeeze(x2,dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x1 = torch.randn(8, 256, 256)
x1 = torch.unsqueeze(x1,dim=0)
print(x1.shape)  # torch.Size([1, 8, 256, 256])

x2 = torch.randn(8, 256, 256)
x2 = torch.unsqueeze(x2,dim=1)
print(x2.shape)  # torch.Size([8, 1, 256, 256])

但如果维度大小不为1,squeeze就无效了。
降维:可以使用torch.mean()函数来对维度X进行求平均值,相当于将维度X的所有通道合并为一个单一的通道。
升维:可以使用expand()函数对需要的尺寸进行扩展(其他维度传递-1作为参数,表示在那个维度不进行扩展)

python 复制代码
x1 = torch.randn(2, 8, 256, 256)
x1 = torch.mean(x1, dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 3, 256, 256)
x2 = torch.mean(x2, dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x3 = torch.randn(8, 256, 256)
x3 = x3.unsqueeze(0).expand(4,-1,-1,-1)
print(x3.shape)  # torch.Size([4, 8, 256, 256])

x4 = torch.randn(16, 256, 256)
x4 = x4.unsqueeze(1).expand(-1, 8, -1, -1)
print(x4.shape) # torch.Size([16, 8, 256, 256])

未完待续...

相关推荐
视觉AI11 分钟前
研究下适合部署在jeston上的深度学习类单目标跟踪算法
深度学习·算法·目标跟踪
Tttian62223 分钟前
Python办公自动化(4)对PPT&邮箱的操作
开发语言·python
AndrewHZ28 分钟前
【图像处理基石】什么是AWB?
图像处理·深度学习·isp算法·awb·ai awb·isp芯片
pk_xz12345638 分钟前
python加载训练好的模型并进行叶片实例分割预测
开发语言·python
独好紫罗兰39 分钟前
洛谷题单3-P1075 [NOIP 2012 普及组] 质因数分解-python-流程图重构
开发语言·python·算法
胖哥真不错2 小时前
Python实现NOA星雀优化算法优化随机森林回归模型项目实战
python·机器学习·项目实战·随机森林回归模型·noa星雀优化算法
编程咕咕gu-2 小时前
从零开始玩python--python版植物大战僵尸来袭
开发语言·python·python基础·pygame·python教程
代码的乐趣3 小时前
支持selenium的chrome driver更新到135.0.7049.42
chrome·python·selenium
SsummerC6 小时前
【leetcode100】数组中的第K个最大元素
python·算法·leetcode
伊玛目的门徒6 小时前
解决backtrader框架下日志ValueError: I/O operation on closed file.报错(jupyternotebook)
python·backtrader·量化·日志管理·回测