torch张量的降维与升维

文章目录


一、降维和升维

squeeze和unsqueeze是torch张量常用的降维与升维的一种方式,但这种方式只能增添或减少大小为1的维度,如下:

python 复制代码
x1 = torch.randn(1, 8, 256, 256)
x1 = torch.squeeze(x1,dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 1, 256, 256)
x2 = torch.squeeze(x2,dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x1 = torch.randn(8, 256, 256)
x1 = torch.unsqueeze(x1,dim=0)
print(x1.shape)  # torch.Size([1, 8, 256, 256])

x2 = torch.randn(8, 256, 256)
x2 = torch.unsqueeze(x2,dim=1)
print(x2.shape)  # torch.Size([8, 1, 256, 256])

但如果维度大小不为1,squeeze就无效了。
降维:可以使用torch.mean()函数来对维度X进行求平均值,相当于将维度X的所有通道合并为一个单一的通道。
升维:可以使用expand()函数对需要的尺寸进行扩展(其他维度传递-1作为参数,表示在那个维度不进行扩展)

python 复制代码
x1 = torch.randn(2, 8, 256, 256)
x1 = torch.mean(x1, dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 3, 256, 256)
x2 = torch.mean(x2, dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x3 = torch.randn(8, 256, 256)
x3 = x3.unsqueeze(0).expand(4,-1,-1,-1)
print(x3.shape)  # torch.Size([4, 8, 256, 256])

x4 = torch.randn(16, 256, 256)
x4 = x4.unsqueeze(1).expand(-1, 8, -1, -1)
print(x4.shape) # torch.Size([16, 8, 256, 256])

未完待续...

相关推荐
StarPrayers.8 分钟前
用 PyTorch 搭建 CIFAR10 线性分类器:从数据加载到模型推理全流程解析
人工智能·pytorch·python
程序员杰哥10 分钟前
UI自动化测试实战:从入门到精通
自动化测试·软件测试·python·selenium·测试工具·ui·职场和发展
SunnyRivers12 分钟前
通俗易懂理解python yield
python
mortimer13 分钟前
Python 进阶:彻底理解类属性、类方法与静态方法
后端·python
Francek Chen24 分钟前
【深度学习计算机视觉】13:实战Kaggle比赛:图像分类 (CIFAR-10)
深度学习·计算机视觉·分类
Ro Jace36 分钟前
模式识别与机器学习课程笔记(11):深度学习
笔记·深度学习·机器学习
渡我白衣1 小时前
深度学习进阶(六)——世界模型与具身智能:AI的下一次跃迁
人工智能·深度学习
人工智能技术咨询.1 小时前
【无标题】
人工智能·深度学习·transformer
Sherry Wangs2 小时前
显卡算力过高导致PyTorch不兼容的救赎指南
人工智能·pytorch·显卡
小叮当⇔2 小时前
PYcharm——获取天气
ide·python·pycharm