Pytorch基础:Tensor的reshape方法

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


在Pytorch中,reshape是Tensor的一个重要方法,它与Numpy中的reshape类似,用于返回一个改变了形状但数据和数据顺序和原来一致的新Tensor对象。注意:此时返回的数据对象并不一定是新的,这取决于应用此方法的Tensor是否是连续的。

reshape方法的语法如下所示:

Tensor.reshape(*shape) → Tensor
shape (tuple of ints or int...) - the desired shape

reshape的用法如下所示:

import torch
# 创建一个张量
x = torch.randn(3, 4)
tensor([[ 0.1961, -0.9038,  0.9196, -1.1851],
        [ 1.1321,  0.3153,  0.3485,  0.7977],
        [-0.5279,  0.2062, -0.4224, -0.3993]])

# 使用reshape方法将其重新塑造为2行6列的形状
y = x.reshape(2, 6) 
y = x.reshape((2,6)) #两种形式均可,y = x.reshape([2,6])也可
tensor([[ 0.1961, -0.9038,  0.9196, -1.1851,  1.1321,  0.3153],
        [ 0.3485,  0.7977, -0.5279,  0.2062, -0.4224, -0.3993]])

可以看到,给出的参数既可以是多个整数(其中每个整数代表一个维度的大小,而整数的数量代表维度的数量),也可以是一个元组或是列表(其中每个元素代表一个维度的大小,而元素数量代表维度的数量)。而且reshape不改变Tensor中数据的排列顺序(指的是从上到下从左到右遍历的顺序),只改变形状,这也就对reshape各维度大小的乘积有要求,要与原Tensor一致。在上例中即3*4=2*6。

另外reshape还有一个trick,即某一维的实参可以是-1,此时会自动根据原Tensor大小和给出的其他维度参数的大小,推断出这一维度的大小,举例如下:

import torch
# 创建一个张量
x = torch.randn(3, 4)
tensor([[ 0.1961, -0.9038,  0.9196, -1.1851],
        [ 1.1321,  0.3153,  0.3485,  0.7977],
        [-0.5279,  0.2062, -0.4224, -0.3993]])

# 使用reshape方法将其重新塑造为6行n列的形状,n为自动推断出的值
y = x.reshape(6, -1)
tensor([[ 0.1961, -0.9038],
        [ 0.9196, -1.1851],
        [ 1.1321,  0.3153],
        [ 0.3485,  0.7977],
        [-0.5279,  0.2062],
        [-0.4224, -0.3993]])

# 使用reshape方法将其重新塑造为(2,2,n)的形状,n为自动推断出的值
y = x.reshape(2, 2, -1)
tensor([[[ 0.1961, -0.9038,  0.9196],
         [-1.1851,  1.1321,  0.3153]],

        [[ 0.3485,  0.7977, -0.5279],
         [ 0.2062, -0.4224, -0.3993]]])

# 不能在两个维度都指定-1,这时无法推断出唯一结果
y = x.reshape(2, -1, -1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: only one dimension can be inferred

除此之外,还可以使用torch.reshape()函数,这与使用reshape方式效果一致,torch.reshape()的语法如下所示。

torch.reshape(input, shape) → Tensor
input (Tensor) -- the tensor to be reshaped
shape (tuple of python:int) -- the new shape


import torch
# 创建一个张量
x = torch.randn(3, 4)
tensor([[ 0.1961, -0.9038,  0.9196, -1.1851],
        [ 1.1321,  0.3153,  0.3485,  0.7977],
        [-0.5279,  0.2062, -0.4224, -0.3993]])

# 使用reshape函数将其重新塑造为6行n列的形状,n为自动推断出的值
y = torch.reshape(x, (6, -1))
tensor([[ 0.1961, -0.9038],
        [ 0.9196, -1.1851],
        [ 1.1321,  0.3153],
        [ 0.3485,  0.7977],
        [-0.5279,  0.2062],
        [-0.4224, -0.3993]])
相关推荐
Tom Boom7 分钟前
1.11.信息系统的分类【DSS】
人工智能·算法·机器学习·职场和发展·分类·数据挖掘·系统架构
扫地僧98511 分钟前
MuMu-LLaMA:通过大型语言模型进行多模态音乐理解和生成(Python代码实现+论文)
人工智能·语言模型·llama
skywalk816313 分钟前
Trae 是一款由 AI 驱动的 IDE,让编程更加愉悦和高效。国际版集成了 GPT-4 和 Claude 3.5,国内版集成了DeepSeek-r1
人工智能·trae
WenGyyyL19 分钟前
使用OpenCV和MediaPipe库——驼背检测(姿态监控)
人工智能·python·opencv·算法·计算机视觉·numpy
梓羽玩Python32 分钟前
开源版Manus来了!14.7k标星的OpenManus,让AI替你全自动执行任务!
人工智能·github
蹦蹦跳跳真可爱58933 分钟前
Python----数据分析(Matplotlib四:Figure的用法,创建Figure对象,常用的Figure对象的方法)
python·数据分析·matplotlib
广拓科技33 分钟前
中国视频生成 AI 开源潮:腾讯阿里掀技术普惠革命,重塑内容创作格局
人工智能·开源
dr李四维42 分钟前
Java在小米SU7 Ultra汽车中的技术赋能
java·人工智能·安卓·智能驾驶·互联·小米su7ultra·hdfs架构
guanshiyishi43 分钟前
ABeam 德硕 | 中国汽车市场(1)——正在推进电动化的中国汽车市场
人工智能·物联网·汽车
思茂信息44 分钟前
CST直角反射器 --- 距离多普勒(RD图), 毫米波汽车雷达ADAS
前端·人工智能·5g·汽车·无人机·软件工程