pytorch中常见的维度操作

1、view ;reshape;Flatten:维度合并和分解

2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)

3、transpose;t;permute:维度顺顺序变换(转置)

4、expand;repeat:维度扩展

python 复制代码
import torch

'''
维度变换
1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展
'''
a = torch.rand(4, 1, 32, 32)

'''
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。
'''


def zqb_view():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.view(4, 32 * 32)
    print(a1.shape)  # torch.Size([4, 1024])
    a2 = a1.view(4, 1, 32, 32)
    print(a2.shape)  # torch.Size([4, 1, 32, 32])

    # a3 = a1.view(4,28,28) #RuntimeError: shape '[4, 28, 28]' is invalid for input of size 4096
    #     要保持输出数据与输入数据总量,防止数据污染
    # a4 = a1.view(4,32,32,1) # 逻辑错误,改变了原来数据的存储方式,虽然不会报错,但是数据已经被污染,无法正常使用

    a5 = a.view(-1, 32 * 32)  # torch.Size([4, 1024])  -1表示该维度保持不变
    print(a5.shape)


'''
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,
其作用是在不改变tensor元素数目的情况下改变tensor的shape。
torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。
'''


def zqb_reshape():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.reshape(-1, 32 * 32)  # torch.Size([4, 1024])
    print(a1.shape)
    a2 = a1.reshape(4, 1, 32, 32)  # torch.Size([4, 1, 32, 32])
    print(a2.shape)


'''
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。
(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。
因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

'''


def zqb_Flatten():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.flatten(1, -1)  # torch.Size([4, 1024]) ,
    print(a1.shape)
    a2 = a.flatten(2, -1)  # torch.Size([4, 1, 1024])
    print(a2.shape)


'''
unsqueeze(dim=idx)
表示插入的维度占据输出数据的维度,比如
[4, 1, 32, 32].unsqueeze(0),表示新插入维度占据输出数据的0维度torch.Size([1, 4, 1, 32, 32])
[4, 1, 32, 32].unsqueeze(-1),表示新插入维度占据输出数据的-1维度torch.Size([1, 4, 1, 32, 32,1])
'''


def zqb_unsqueeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.unsqueeze(0)
    print(a1.shape)  # torch.Size([1, 4, 1, 32, 32])
    a2 = a.unsqueeze(-1)
    print(a2.shape)  # torch.Size([4, 1, 32, 32, 1])


'''
squeeze()不给参数,表示删除所有1的维度
squeeze(index) 给参数,删除指定index维度,若刚该维度不为1则不做处理
'''


def zqb_squeeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.squeeze()
    print(a1.shape)  # torch.Size([4, 32, 32])
    a2 = a.squeeze(1)
    print(a2.shape)  # torch.Size([4, 32, 32])
    a3 = a.squeeze(0)
    print(a3.shape)  # torch.Size([4, 1, 32, 32])


'''
转置t操作仅仅针对2维数据操作
'''


def zqb_t():
    a = torch.rand(2, 3)
    print(a.shape)  # torch.Size([2, 3])
    a1 = a.t()
    print(a1.shape)  # torch.Size([3, 2])


'''
transpose(dim0,dim1)指定需要调换的两个维度,与顺序无关
对3维及以上的进行操作,输入需要调换的两个维度
'''


def zqb_transpose():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.transpose(0, 1)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])
    a2 = a.transpose(1, 0)
    print(a2.shape)  # torch.Size([1, 4, 32, 32])


'''
permute(*dims),指定新维度的顺序
'''


def zqb_permute():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.permute(1, 0, 2, 3)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])


'''
只能对维度值为1的维度进行扩展,无需扩展的维度,维度值不变,
对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,
只是原来的基础上创建新的视图并返回,返回的张量内存是不连续的。
类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

'''


def zqb_expand():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.expand(-1, 4, -1, -1)
    print(a1.shape)  # torch.Size([4, 4, 32, 32])


'''
repeat参数*sizes指定了原始张量在各维度上复制的次数。
整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,
而更接近于tile函数的效果。
'''


def zqb_repeat():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.repeat(1, 4, 1, 1)  # torch.Size([4, 4, 32, 32])
    print(a1.shape)
    a2 = a.repeat(4, 1, 1, 1)
    print(a2.shape)  # torch.Size([16, 1, 32, 32])


if __name__ == '__main__':
    zqb_view()
    zqb_reshape()
    zqb_Flatten()
    zqb_unsqueeze()
    zqb_squeeze()
    zqb_t()
    zqb_transpose()
    zqb_permute()
    zqb_expand()
    zqb_repeat()
相关推荐
边缘计算社区24 分钟前
首个!艾灵参编的工业边缘计算国家标准正式发布
大数据·人工智能·边缘计算
游客52034 分钟前
opencv中的各种滤波器简介
图像处理·人工智能·python·opencv·计算机视觉
一位小说男主35 分钟前
编码器与解码器:从‘乱码’到‘通话’
人工智能·深度学习
深圳南柯电子1 小时前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ1 小时前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
biter00881 小时前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
吃个糖糖1 小时前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉
qq_529025291 小时前
Torch.gather
python·深度学习·机器学习