torch张量的降维与升维

文章目录


一、降维和升维

squeeze和unsqueeze是torch张量常用的降维与升维的一种方式,但这种方式只能增添或减少大小为1的维度,如下:

python 复制代码
x1 = torch.randn(1, 8, 256, 256)
x1 = torch.squeeze(x1,dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 1, 256, 256)
x2 = torch.squeeze(x2,dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x1 = torch.randn(8, 256, 256)
x1 = torch.unsqueeze(x1,dim=0)
print(x1.shape)  # torch.Size([1, 8, 256, 256])

x2 = torch.randn(8, 256, 256)
x2 = torch.unsqueeze(x2,dim=1)
print(x2.shape)  # torch.Size([8, 1, 256, 256])

但如果维度大小不为1,squeeze就无效了。
降维:可以使用torch.mean()函数来对维度X进行求平均值,相当于将维度X的所有通道合并为一个单一的通道。
升维:可以使用expand()函数对需要的尺寸进行扩展(其他维度传递-1作为参数,表示在那个维度不进行扩展)

python 复制代码
x1 = torch.randn(2, 8, 256, 256)
x1 = torch.mean(x1, dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 3, 256, 256)
x2 = torch.mean(x2, dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x3 = torch.randn(8, 256, 256)
x3 = x3.unsqueeze(0).expand(4,-1,-1,-1)
print(x3.shape)  # torch.Size([4, 8, 256, 256])

x4 = torch.randn(16, 256, 256)
x4 = x4.unsqueeze(1).expand(-1, 8, -1, -1)
print(x4.shape) # torch.Size([16, 8, 256, 256])

未完待续...

相关推荐
TF男孩7 小时前
ARQ:一款低成本的消息队列,实现每秒万级吞吐
后端·python·消息队列
该用户已不存在12 小时前
Mojo vs Python vs Rust: 2025年搞AI,该学哪个?
后端·python·rust
站大爷IP14 小时前
Java调用Python的5种实用方案:从简单到进阶的全场景解析
python
用户83562907805119 小时前
从手动编辑到代码生成:Python 助你高效创建 Word 文档
后端·python
c8i19 小时前
python中类的基本结构、特殊属性于MRO理解
python
隐语SecretFlow19 小时前
国人自研开源隐私计算框架SecretFlow,深度拆解框架及使用【开发者必看】
深度学习
liwulin050620 小时前
【ESP32-CAM】HELLO WORLD
python
Doris_202320 小时前
Python条件判断语句 if、elif 、else
前端·后端·python
Doris_202320 小时前
Python 模式匹配match case
前端·后端·python
Billy_Zuo20 小时前
人工智能深度学习——卷积神经网络(CNN)
人工智能·深度学习·cnn