torch张量的降维与升维

文章目录


一、降维和升维

squeeze和unsqueeze是torch张量常用的降维与升维的一种方式,但这种方式只能增添或减少大小为1的维度,如下:

python 复制代码
x1 = torch.randn(1, 8, 256, 256)
x1 = torch.squeeze(x1,dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 1, 256, 256)
x2 = torch.squeeze(x2,dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x1 = torch.randn(8, 256, 256)
x1 = torch.unsqueeze(x1,dim=0)
print(x1.shape)  # torch.Size([1, 8, 256, 256])

x2 = torch.randn(8, 256, 256)
x2 = torch.unsqueeze(x2,dim=1)
print(x2.shape)  # torch.Size([8, 1, 256, 256])

但如果维度大小不为1,squeeze就无效了。
降维:可以使用torch.mean()函数来对维度X进行求平均值,相当于将维度X的所有通道合并为一个单一的通道。
升维:可以使用expand()函数对需要的尺寸进行扩展(其他维度传递-1作为参数,表示在那个维度不进行扩展)

python 复制代码
x1 = torch.randn(2, 8, 256, 256)
x1 = torch.mean(x1, dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 3, 256, 256)
x2 = torch.mean(x2, dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x3 = torch.randn(8, 256, 256)
x3 = x3.unsqueeze(0).expand(4,-1,-1,-1)
print(x3.shape)  # torch.Size([4, 8, 256, 256])

x4 = torch.randn(16, 256, 256)
x4 = x4.unsqueeze(1).expand(-1, 8, -1, -1)
print(x4.shape) # torch.Size([16, 8, 256, 256])

未完待续...

相关推荐
qq_5278878712 分钟前
ImportError: cannot import name ‘PfeifferConfig‘ from ‘transformers‘【已解决】
linux·开发语言·python
StackOverthink16 分钟前
PyTorch:让深度学习像搭积木一样简单!!!
人工智能·pytorch·深度学习·其他
知舟不叙19 分钟前
使用OpenCV和Python进行图像掩膜与直方图分析
人工智能·python·opencv·图像掩膜
woniuhuihui26 分钟前
案例8 模型量化
python
tt卡丁车1 小时前
6.11打卡
python
武乐乐~1 小时前
强化学习入门:交叉熵方法实现CartPole智能体
人工智能·深度学习·机器学习
蓝婷儿1 小时前
6个月Python学习计划 Day 21 - Python 学习前三周回顾总结
python·学习
飞翔的佩奇1 小时前
【完整源码+数据集+部署教程】安检爆炸物检测系统源码和数据集:改进yolo11-REPVGGOREPA
python·yolo·计算机视觉·毕业设计·数据集·yolo11·安检爆炸物检测
摆渡搜不到你1 小时前
PyCharm Python IDE
ide·python·pycharm
pitepa1 小时前
初学者运行Pycharm程序可能会出现的问题,及解决办法
ide·python·pycharm