pytorch中常见的维度操作

1、view ;reshape;Flatten:维度合并和分解

2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)

3、transpose;t;permute:维度顺顺序变换(转置)

4、expand;repeat:维度扩展

python 复制代码
import torch

'''
维度变换
1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展
'''
a = torch.rand(4, 1, 32, 32)

'''
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。
'''


def zqb_view():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.view(4, 32 * 32)
    print(a1.shape)  # torch.Size([4, 1024])
    a2 = a1.view(4, 1, 32, 32)
    print(a2.shape)  # torch.Size([4, 1, 32, 32])

    # a3 = a1.view(4,28,28) #RuntimeError: shape '[4, 28, 28]' is invalid for input of size 4096
    #     要保持输出数据与输入数据总量,防止数据污染
    # a4 = a1.view(4,32,32,1) # 逻辑错误,改变了原来数据的存储方式,虽然不会报错,但是数据已经被污染,无法正常使用

    a5 = a.view(-1, 32 * 32)  # torch.Size([4, 1024])  -1表示该维度保持不变
    print(a5.shape)


'''
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,
其作用是在不改变tensor元素数目的情况下改变tensor的shape。
torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。
'''


def zqb_reshape():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.reshape(-1, 32 * 32)  # torch.Size([4, 1024])
    print(a1.shape)
    a2 = a1.reshape(4, 1, 32, 32)  # torch.Size([4, 1, 32, 32])
    print(a2.shape)


'''
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。
(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。
因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

'''


def zqb_Flatten():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.flatten(1, -1)  # torch.Size([4, 1024]) ,
    print(a1.shape)
    a2 = a.flatten(2, -1)  # torch.Size([4, 1, 1024])
    print(a2.shape)


'''
unsqueeze(dim=idx)
表示插入的维度占据输出数据的维度,比如
[4, 1, 32, 32].unsqueeze(0),表示新插入维度占据输出数据的0维度torch.Size([1, 4, 1, 32, 32])
[4, 1, 32, 32].unsqueeze(-1),表示新插入维度占据输出数据的-1维度torch.Size([1, 4, 1, 32, 32,1])
'''


def zqb_unsqueeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.unsqueeze(0)
    print(a1.shape)  # torch.Size([1, 4, 1, 32, 32])
    a2 = a.unsqueeze(-1)
    print(a2.shape)  # torch.Size([4, 1, 32, 32, 1])


'''
squeeze()不给参数,表示删除所有1的维度
squeeze(index) 给参数,删除指定index维度,若刚该维度不为1则不做处理
'''


def zqb_squeeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.squeeze()
    print(a1.shape)  # torch.Size([4, 32, 32])
    a2 = a.squeeze(1)
    print(a2.shape)  # torch.Size([4, 32, 32])
    a3 = a.squeeze(0)
    print(a3.shape)  # torch.Size([4, 1, 32, 32])


'''
转置t操作仅仅针对2维数据操作
'''


def zqb_t():
    a = torch.rand(2, 3)
    print(a.shape)  # torch.Size([2, 3])
    a1 = a.t()
    print(a1.shape)  # torch.Size([3, 2])


'''
transpose(dim0,dim1)指定需要调换的两个维度,与顺序无关
对3维及以上的进行操作,输入需要调换的两个维度
'''


def zqb_transpose():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.transpose(0, 1)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])
    a2 = a.transpose(1, 0)
    print(a2.shape)  # torch.Size([1, 4, 32, 32])


'''
permute(*dims),指定新维度的顺序
'''


def zqb_permute():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.permute(1, 0, 2, 3)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])


'''
只能对维度值为1的维度进行扩展,无需扩展的维度,维度值不变,
对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,
只是原来的基础上创建新的视图并返回,返回的张量内存是不连续的。
类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

'''


def zqb_expand():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.expand(-1, 4, -1, -1)
    print(a1.shape)  # torch.Size([4, 4, 32, 32])


'''
repeat参数*sizes指定了原始张量在各维度上复制的次数。
整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,
而更接近于tile函数的效果。
'''


def zqb_repeat():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.repeat(1, 4, 1, 1)  # torch.Size([4, 4, 32, 32])
    print(a1.shape)
    a2 = a.repeat(4, 1, 1, 1)
    print(a2.shape)  # torch.Size([16, 1, 32, 32])


if __name__ == '__main__':
    zqb_view()
    zqb_reshape()
    zqb_Flatten()
    zqb_unsqueeze()
    zqb_squeeze()
    zqb_t()
    zqb_transpose()
    zqb_permute()
    zqb_expand()
    zqb_repeat()
相关推荐
AI即插即用1 小时前
即插即用系列 | ECCV 2024 WTConv:利用小波变换实现超大感受野的卷积神经网络
图像处理·人工智能·深度学习·神经网络·计算机视觉·cnn·视觉检测
愚公搬代码1 小时前
【愚公系列】《扣子开发 AI Agent 智能体应用》003-扣子 AI 应用开发平台介绍(选择扣子的理由)
人工智能
lhrimperial2 小时前
AI工程化实践指南:从入门到落地
人工智能
jifengzhiling2 小时前
零极点对消:原理、作用与风险
人工智能·算法
哥布林学者2 小时前
吴恩达深度学习课程四:计算机视觉 第三周:检测算法 (一)目标定位与特征点检测
深度学习·ai
科技看点2 小时前
想帮帮服务智能体荣获2025 EDGE AWARDS「最佳AI创新应用」大奖
人工智能
m0_704887892 小时前
DAY 40
人工智能·深度学习
Katecat996632 小时前
【海滩垃圾检测与分类识别-基于改进YOLO13-seg-iRMB模型】
人工智能·数据挖掘
程序员佳佳2 小时前
2025年大模型终极横评:GPT-5.2、Banana Pro与DeepSeek V3.2实战硬核比拼(附统一接入方案)
服务器·数据库·人工智能·python·gpt·api