torch张量的降维与升维

文章目录


一、降维和升维

squeeze和unsqueeze是torch张量常用的降维与升维的一种方式,但这种方式只能增添或减少大小为1的维度,如下:

python 复制代码
x1 = torch.randn(1, 8, 256, 256)
x1 = torch.squeeze(x1,dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 1, 256, 256)
x2 = torch.squeeze(x2,dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x1 = torch.randn(8, 256, 256)
x1 = torch.unsqueeze(x1,dim=0)
print(x1.shape)  # torch.Size([1, 8, 256, 256])

x2 = torch.randn(8, 256, 256)
x2 = torch.unsqueeze(x2,dim=1)
print(x2.shape)  # torch.Size([8, 1, 256, 256])

但如果维度大小不为1,squeeze就无效了。
降维:可以使用torch.mean()函数来对维度X进行求平均值,相当于将维度X的所有通道合并为一个单一的通道。
升维:可以使用expand()函数对需要的尺寸进行扩展(其他维度传递-1作为参数,表示在那个维度不进行扩展)

python 复制代码
x1 = torch.randn(2, 8, 256, 256)
x1 = torch.mean(x1, dim=0)
print(x1.shape) # torch.Size([8, 256, 256])

x2 = torch.randn(8, 3, 256, 256)
x2 = torch.mean(x2, dim=1)
print(x2.shape) # torch.Size([8, 256, 256])

x3 = torch.randn(8, 256, 256)
x3 = x3.unsqueeze(0).expand(4,-1,-1,-1)
print(x3.shape)  # torch.Size([4, 8, 256, 256])

x4 = torch.randn(16, 256, 256)
x4 = x4.unsqueeze(1).expand(-1, 8, -1, -1)
print(x4.shape) # torch.Size([16, 8, 256, 256])

未完待续...

相关推荐
Rorsion16 分钟前
PyTorch实现二分类(单特征输出+单层神经网络)
人工智能·pytorch·分类
HAPPY酷25 分钟前
C++ 和 Python 的“容器”对决:从万金油到核武器
开发语言·c++·python
gpfyyds6661 小时前
Python代码练习
开发语言·python
aiguangyuan2 小时前
使用LSTM进行情感分类:原理与实现剖析
人工智能·python·nlp
小小张说故事2 小时前
BeautifulSoup:Python网页解析的优雅利器
后端·爬虫·python
Yeats_Liao2 小时前
评估体系构建:基于自动化指标与人工打分的双重验证
运维·人工智能·深度学习·算法·机器学习·自动化
luoluoal2 小时前
基于python的医疗领域用户问答的意图识别算法研究(源码+文档)
python
Shi_haoliu3 小时前
python安装操作流程-FastAPI + PostgreSQL简单流程
python·postgresql·fastapi
ZH15455891313 小时前
Flutter for OpenHarmony Python学习助手实战:API接口开发的实现
python·学习·flutter
小宋10213 小时前
Java 项目结构 vs Python 项目结构:如何快速搭一个可跑项目
java·开发语言·python