关于pytorch张量维度转换大全

关于pytorch张量维度转换大全

  • [1 tensor.view()](#1 tensor.view())

  • [2 tensor.reshape()](#2 tensor.reshape())

  • [3 tensor.squeeze()和tensor.unsqueeze()](#3 tensor.squeeze()和tensor.unsqueeze())

    • [3.1 tensor.squeeze() 降维](#3.1 tensor.squeeze() 降维)
    • [3.2 tensor.unsqueeze(idx)升维](#3.2 tensor.unsqueeze(idx)升维)
  • [4 tensor.permute()](#4 tensor.permute())

  • [5 torch.cat([a,b],dim)](#5 torch.cat([a,b],dim))

  • [6 tensor.expand()](#6 tensor.expand())

  • [7 tensor.narrow(dim, start, len)](#7 tensor.narrow(dim, start, len))

  • [8 tensor.resize_()](#8 tensor.resize_())

  • [9 tensor.repeat()](#9 tensor.repeat())

  • 参考:

    view() 转换维度

    reshape() 转换维度

    permute() 坐标系变换

    squeeze()/unsqueeze() 降维/升维

    expand() 扩张张量

    narraw() 缩小张量

    resize_() 重设尺寸

    repeat(), unfold() 重复张量

    cat(), stack() 拼接张量

1 tensor.view()

view() 用于改变张量的形状 ,但不会改变张量中的元素值
用法1:

例如,你可以使用view 将一个形状是(2,3)的张量变换成(3,2)的张量;

import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.view(3, 2)    

上面的操作相当于,先把形状为**(2,3)的tensor展平,变成(1,6),然后再变成(3,2).**

用法2:

转换前后张量中的元素个数不变。view()中若存在某一维的维度是-1 ,则表示该维的维度根据总元素个数和其他维度尺寸自适应调整 。注意,view()中最多只能有一个维度的维数设置成-1

z = x.view(-1,2)

举例子:

在卷积神经网络中,经常会在全连接层用到view进行张量的维度拉伸:

假设输入特征是BCH*W 的4维张量,其中B表示batchsize,C表示特征通道数,H和W表示特征的高和宽,在将特征送入全连接层之前,会用.view将转换为B*(CHW)的2维张量 ,即保持batch不变,但将每个特征转换为一维向量。

2 tensor.reshape()

reshape()与view()使用方法相同。

3 tensor.squeeze()和tensor.unsqueeze()

3.1 tensor.squeeze() 降维

(1)若squeeze()括号内为空,则将张量中所有维度为1的维数进行压缩 ,如将1,2,1,9的张量降维到2,9维;若维度中无1维的维数,则保持源维度不变,如将23 4维的张量进行squeeze,则转换后维度不会变。

(2)若squeeze(idx) ,则将张量中对应的第idx维 的维度进行压缩,如1,2,1,9的张量做squeeze(2),则会降维到1,2,9维的张量;若第idx维度的维数不为1,则squeeze后维度不会变化。

例如:

3.2 tensor.unsqueeze(idx)升维

在第idx维进行升维,将tensor由原本的维度n,升维至n+1维 。如张量的维度维2*3,经unsqueeze(0)后,变为1,2,3维度的张量。

4 tensor.permute()

坐标系转换,即矩阵转置 ,使用方法与numpy array的transpose相同 。permute()括号内的参数数字指的是各维度的索引值。permute是深度学习中经常需要使用的技巧,一般的会将BCHW的特征张量 ,通过转置转化为BHWC的特征张量 ,即将特征深度转换到最后一个维度,通过调用**tensor.permute(0, 2, 3, 1)**实现。
torch.transpose只能操作2D矩阵的转置,而permute()函数可以对任意高维矩阵进行转置;

简单理解:permute()相当于可以同时操作tensor的若干维度,transpose只能同时作用于tensor的两个维度。

permute和view/reshape虽然都能将张量转化为特定的维度,但原理完全不同,注意区分。view和reshape处理后,张量中元素顺序都不会有变化,而permute转置后元素的排列会发生变化,因为坐标系变化了。

5 torch.cat([a,b],dim)

在第dim维度进行张量拼接 ,要注意维度保持一致

假设a为h1w1的二维张量,b为h2 w2的二维张量,torch.cat(a,b,0)表示在第一维进行拼接 ,即在列方向拼接 ,所以w1和w2必须相等。torch.cat(a,b,1)表示在第二维进行拼接,即在行方向拼接,所以h1和h2必须相等

假设a为c1h1 w1的二维张量,b为c2h2 w2的二维张量,torch.cat(a,b,0)表示在第一维进行拼接,即在特征的通道维度进行拼接,其他维度必须保持一致,即w1=w2,h1=h2。torch.cat(a,b,1)表示在第二维进行拼接,即在列方向拼接,必须保证w1=w2,c1=c2;torch.cat(a,b,2)表示在第三维进行拼接,即在行方向拼接,必须保证h1=h2,c1=c2;

6 tensor.expand()

扩展张量 ,通过值复制的方式,将单个维度扩大为更大的尺寸 。使用expand()函数不会使原tensor改变,需要将结果重新赋值。下面是具体的实例:

以二维张量为例:tensor是1n或n 1维的张量,分别调用tensor.expand(s, n)或tensor.expand(n, s)在行方向和列方向进行扩展。
expand()的填入参数是size

7 tensor.narrow(dim, start, len)

narrow()函数起到了筛选一定维度上的数据作用.

python 复制代码
torch.narrow(input, dim, start, length)->Tensor

input是需要切片的张量,dim是切片维度,start是开始的索引,length是切片长度,实际应用如下:

8 tensor.resize_()

尺寸变化,将tensor截断为resize_后的维度.

9 tensor.repeat()

tensor.repeat(a,b)将tensor整体在行方向复制a份,在列方向上复制b份

参考:

pytorch中与tensor维度变化相关的函数(持续更新) - weili21的文章 - 知乎
https://zhuanlan.zhihu.com/p/438099006

【pytorch tensor张量维度转换(tensor维度转换)】
https://blog.csdn.net/x_yan033/article/details/104965077

相关推荐
计算机编程-吉哥几秒前
计算机毕业设计 基于Python的社交音乐分享平台的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档
python·django·毕业设计·课程设计·毕业论文·计算机毕业设计选题·音乐分享平台
shinelord明7 分钟前
【Python】Python知识总结浅析
开发语言·人工智能·python
S0linteeH14 分钟前
Windows 11 的 24H2 更新將帶來全新 Copilot+ AI PC 功能
人工智能·copilot
AI大模型_学习君25 分钟前
大模型书籍强烈安利:《掌握NLP:从基础到大语言模型》(附PDF)
人工智能·深度学习·机器学习·语言模型·自然语言处理·pdf·ai大模型
雷神乐乐38 分钟前
Python常用函数
开发语言·python
AIGC安琪1 小时前
[ComfyUI]Flux:开源可商用F1!Apache2开源OpenFLUX1模型,已去蒸馏可微调
人工智能·stable diffusion·开源·aigc·midjourney·ai绘画·flux
背水1 小时前
pillow常用知识
人工智能·计算机视觉·pillow
AI人工智能+1 小时前
浅析人脸活体检测技术的实现过程及其应用领域
人工智能·计算机视觉
凭栏落花侧1 小时前
回归分析在数据挖掘中的应用简析
人工智能·数据挖掘·回归
model20052 小时前
android + tflite 分类APP开发-1
python·tflite·model maker