pytorch中常见的维度操作

1、view ;reshape;Flatten:维度合并和分解

2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)

3、transpose;t;permute:维度顺顺序变换(转置)

4、expand;repeat:维度扩展

python 复制代码
import torch

'''
维度变换
1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展
'''
a = torch.rand(4, 1, 32, 32)

'''
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。
'''


def zqb_view():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.view(4, 32 * 32)
    print(a1.shape)  # torch.Size([4, 1024])
    a2 = a1.view(4, 1, 32, 32)
    print(a2.shape)  # torch.Size([4, 1, 32, 32])

    # a3 = a1.view(4,28,28) #RuntimeError: shape '[4, 28, 28]' is invalid for input of size 4096
    #     要保持输出数据与输入数据总量,防止数据污染
    # a4 = a1.view(4,32,32,1) # 逻辑错误,改变了原来数据的存储方式,虽然不会报错,但是数据已经被污染,无法正常使用

    a5 = a.view(-1, 32 * 32)  # torch.Size([4, 1024])  -1表示该维度保持不变
    print(a5.shape)


'''
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,
其作用是在不改变tensor元素数目的情况下改变tensor的shape。
torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。
'''


def zqb_reshape():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.reshape(-1, 32 * 32)  # torch.Size([4, 1024])
    print(a1.shape)
    a2 = a1.reshape(4, 1, 32, 32)  # torch.Size([4, 1, 32, 32])
    print(a2.shape)


'''
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。
(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。
因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

'''


def zqb_Flatten():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.flatten(1, -1)  # torch.Size([4, 1024]) ,
    print(a1.shape)
    a2 = a.flatten(2, -1)  # torch.Size([4, 1, 1024])
    print(a2.shape)


'''
unsqueeze(dim=idx)
表示插入的维度占据输出数据的维度,比如
[4, 1, 32, 32].unsqueeze(0),表示新插入维度占据输出数据的0维度torch.Size([1, 4, 1, 32, 32])
[4, 1, 32, 32].unsqueeze(-1),表示新插入维度占据输出数据的-1维度torch.Size([1, 4, 1, 32, 32,1])
'''


def zqb_unsqueeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.unsqueeze(0)
    print(a1.shape)  # torch.Size([1, 4, 1, 32, 32])
    a2 = a.unsqueeze(-1)
    print(a2.shape)  # torch.Size([4, 1, 32, 32, 1])


'''
squeeze()不给参数,表示删除所有1的维度
squeeze(index) 给参数,删除指定index维度,若刚该维度不为1则不做处理
'''


def zqb_squeeze():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.squeeze()
    print(a1.shape)  # torch.Size([4, 32, 32])
    a2 = a.squeeze(1)
    print(a2.shape)  # torch.Size([4, 32, 32])
    a3 = a.squeeze(0)
    print(a3.shape)  # torch.Size([4, 1, 32, 32])


'''
转置t操作仅仅针对2维数据操作
'''


def zqb_t():
    a = torch.rand(2, 3)
    print(a.shape)  # torch.Size([2, 3])
    a1 = a.t()
    print(a1.shape)  # torch.Size([3, 2])


'''
transpose(dim0,dim1)指定需要调换的两个维度,与顺序无关
对3维及以上的进行操作,输入需要调换的两个维度
'''


def zqb_transpose():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.transpose(0, 1)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])
    a2 = a.transpose(1, 0)
    print(a2.shape)  # torch.Size([1, 4, 32, 32])


'''
permute(*dims),指定新维度的顺序
'''


def zqb_permute():
    print(a.shape)  # torch.Size([4, 1, 32, 32])对应的维度0,1,2,3
    a1 = a.permute(1, 0, 2, 3)
    print(a1.shape)  # torch.Size([1, 4, 32, 32])


'''
只能对维度值为1的维度进行扩展,无需扩展的维度,维度值不变,
对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,
只是原来的基础上创建新的视图并返回,返回的张量内存是不连续的。
类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

'''


def zqb_expand():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.expand(-1, 4, -1, -1)
    print(a1.shape)  # torch.Size([4, 4, 32, 32])


'''
repeat参数*sizes指定了原始张量在各维度上复制的次数。
整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,
而更接近于tile函数的效果。
'''


def zqb_repeat():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.repeat(1, 4, 1, 1)  # torch.Size([4, 4, 32, 32])
    print(a1.shape)
    a2 = a.repeat(4, 1, 1, 1)
    print(a2.shape)  # torch.Size([16, 1, 32, 32])


if __name__ == '__main__':
    zqb_view()
    zqb_reshape()
    zqb_Flatten()
    zqb_unsqueeze()
    zqb_squeeze()
    zqb_t()
    zqb_transpose()
    zqb_permute()
    zqb_expand()
    zqb_repeat()
相关推荐
梦云澜4 分钟前
论文阅读(十):用可分解图模型模拟连锁不平衡
论文阅读·人工智能·深度学习
FL162386312914 分钟前
马铃薯叶子病害检测数据集VOC+YOLO格式1332张9类别
人工智能·深度学习·机器学习
九亿AI算法优化工作室&1 小时前
GWO优化LSBooST回归预测matlab
人工智能·python·算法·机器学习·matlab·数据挖掘·回归
东锋1.32 小时前
Ollama 安装教程:轻松开启本地大语言模型之旅
人工智能
一只昀2 小时前
【产品经理学习案例——AI翻译棒出海业务】
人工智能·ai·产品经理
蓝染k9z3 小时前
在Ubuntu上使用Docker部署DeepSeek
linux·人工智能·ubuntu·docker·deepseek+
python算法(魔法师版)3 小时前
基于机器学习鉴别中药材的方法
深度学习·线性代数·算法·机器学习·支持向量机·数据挖掘·动态规划
小李学AI3 小时前
基于YOLO11的遥感影像山体滑坡检测系统
人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉·yolo11
笨小古4 小时前
保姆级教程:利用Ollama与Open-WebUI本地部署 DeedSeek-R1大模型
人工智能·deepseek
AI浩4 小时前
【Block总结】CPCA,通道优先卷积注意力|即插即用
人工智能·深度学习·目标检测·计算机视觉