torch中维度操作总结(repeat,squeeze,unsqueeze,flatten,transpose)

文章目录

repeat 函数

1.repeat参数个数与tensor向量维数一致

bash 复制代码
a = torch.tensor([[1, 2, 3],
                  [1, 2, 3]])
b = a.repeat(2, 2)
print(b.shape)

结果为:

c 复制代码
torch.Size([4,6])

即repeat的参数是对应维度的复制个数,上段代码为0维复制两次,1维复制两次,则得到以上运行结果。其余扩展情况依此类推。

2.repeat参数个数与tensor向量维数不一致

在参数个数大于原tensor维度个数时,总是先在第0维扩展一个维数为1的维度,然后按照参数指定的复制次数进行复制。计算输出的形状时,可以按照 对应参数*对应维度维数 得到结果

bash 复制代码
# a形状(2,3)
a = torch.tensor([[1, 2, 3],
                  [1, 2, 3]])
# repeat参数比维度多
# 首先在第0维扩展一个维度,维数为1,然后按照参数指定的次数进行复制
# 在扩展前先将a的形状扩展为(1,2,3)然后复制
b = a.repeat(1, 2, 1)
print(b.shape)  # 得到结果torch.Size([1, 4, 3])
bash 复制代码
# a形状(2,3)
a = torch.tensor([[1, 2, 3],
                  [1, 2, 3]])
# repeat参数比维度多,在扩展前先将a的形状扩展为(1,2,3)然后复制
b = a.repeat(1, 1, 2)
print(b.shape)  # 得到结果torch.Size([1, 2, 6])
bash 复制代码
# a形状(2,3)
a = torch.tensor([[1, 2, 3],
                  [1, 2, 3]])
# repeat参数比维度多,在扩展前先将a的形状扩展为(1,2,3)然后复制
b = a.repeat(2, 1, 1)
print(b.shape)  # 得到结果torch.Size([2, 2, 3])

squeeze 函数

bash 复制代码
torch.squeeze(A,N)

torch.unsqueeze()函数:减少数组A指定位置N的维度。

如果不指定位置参数N,如果数组A的维度为(1,1,3)。

如果指定位置参数,执行 torch.squeeze(A,1) ,A的维度变为 (1,3),中间的维度被删除。

注:

  1. 如果指定的维度大于1,那么将操作无效
  2. 如果不指定维度N,那么将删除所有维度为1的维度
bash 复制代码
a=torch.randn(1,1,3)
print(a.shape) # torch.Size([1, 1, 3])
b=torch.squeeze(a)
print(b.shape)	# torch.Size([3])
c=torch.squeeze(a,0)
print(c.shape)  # torch.Size([1, 3])
d=torch.squeeze(a,1)
print(d.shape)	# torch.Size([1, 3])
e=torch.squeeze(a,2)#如果去掉第三维,则数不够放了,所以直接保留
print(e.shape)	# torch.Size([1, 1, 3])

unsqueeze 函数

bash 复制代码
torch.unsqueeze(A,N)

torch.unsqueeze()函数:增加数组A指定位置N的维度。

两行三列的数组A维度为(2,3),那么这个数组就有三个位置可以增加维度,分别是

bash 复制代码
([位置0], 2,[位置1], 3, [位置2]) 
或者
( [位置-3] ,2,[位置-2], 3 ,[位置-1] )

如果执行 torch.unsqueeze(A,1),数据的维度就变为了 (2,1,3)

bash 复制代码
a=torch.randn(1,3)
print(a.shape)	# torch.Size([1, 3])
b=torch.unsqueeze(a,0)
print(b.shape)	# torch.Size([1, 1, 3])
c=torch.unsqueeze(a,1)
print(c.shape)	# torch.Size([1, 1, 3])
d=torch.unsqueeze(a,2)
print(d.shape)	# torch.Size([1, 3, 1])

flatten 函数

flatten() 是对多维数据的降维函数。

flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。

python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。

bash 复制代码
import torch
a = torch.rand(2,3,4)
print(a.shape) # torch.Size([2, 3, 4])
b = a.flatten()
print(b.shape)  # torch.Size([24])
c = a.flatten(0)
print(c.shape)  # torch.Size([24])
d = a.flatten(1)
print(d.shape)  # torch.Size([2, 12])
e = a.flatten(2)
print(e.shape)	 # torch.Size([2, 3, 4])

transpose函数

二维数组

python 复制代码
import numpy as np
X=np.arange(6).reshape((2,3))
print(X)
#[[0 1 2]
# [3 4 5]]

print(X.transpose())
#[[0 3]
# [1 4]
# [2 5]]

print(X.T)
#[[0 3]
# [1 4]
# [2 5]]

多维数组

python 复制代码
x=np.arange(24).reshape((2,3,4))
print(x.shape)
y = x.transpose((0,1,2))
print(y.shape)
y = x.transpose((0,2,1))
print(y.shape)
y = x.transpose((2,1,0))
print(y.shape)

#(2, 3, 4)
#(2, 3, 4)
#(2, 4, 3)
#(4, 3, 2)

参考网址

https://blog.csdn.net/tequila53/article/details/119183678

https://blog.csdn.net/kuan__/article/details/116987162

说明

说明如下,如有侵权,十分抱歉,可联系本人删除对应内容。

会根据平时使用不断更新博客内容。

相关推荐
CodeCraft Studio36 分钟前
CAD文件处理控件Aspose.CAD教程:使用 Python 将绘图转换为 Photoshop
python·photoshop·cad·aspose·aspose.cad
Python×CATIA工业智造3 小时前
Frida RPC高级应用:动态模拟执行Android so文件实战指南
开发语言·python·pycharm
千宇宙航3 小时前
闲庭信步使用SV搭建图像测试平台:第三十一课——基于神经网络的手写数字识别
图像处理·人工智能·深度学习·神经网络·计算机视觉·fpga开发
onceco3 小时前
领域LLM九讲——第5讲 为什么选择OpenManus而不是QwenAgent(附LLM免费api邀请码)
人工智能·python·深度学习·语言模型·自然语言处理·自动化
天水幼麟4 小时前
动手学深度学习-学习笔记(总)
笔记·深度学习·学习
狐凄4 小时前
Python实例题:基于 Python 的简单聊天机器人
开发语言·python
悦悦子a啊5 小时前
Python之--基本知识
开发语言·前端·python
天水幼麟6 小时前
动手学深度学习-学习笔记【二】(基础知识)
笔记·深度学习·学习
笑稀了的野生俊6 小时前
在服务器中下载 HuggingFace 模型:终极指南
linux·服务器·python·bash·gpu算力
Naiva6 小时前
【小技巧】Python+PyCharm IDE 配置解释器出错,环境配置不完整或不兼容。(小智AI、MCP、聚合数据、实时新闻查询、NBA赛事查询)
ide·python·pycharm