PyTorch_张量形状操作

搭建模型时,数据都是基于张量形式的表示,网络层与层之间很多都是以不同的shape的方式进行表现和运算。

对张量形状的操作,以便能够更好处理网络各层之间的数据连接。


reshape 函数的用法

reshape 函数可以再保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状,在神经网络中经常使用该函数来调节数据的形状,以适配不同网络层之间的数据传递。

python 复制代码
import torch 
import numpy as np 

def test01():
    torch.manual_seed(0)
    data = torch.randint(0, 10, [4, 5])
    
    # 查看张量的形状
    print(data.shape, data.shape[0], data.shape[1])  # shape属性可以查看张量的形状
    print(data.size(), data.size(0), data.size(1))  # size()方法可以查看张量的形状

    # 修改张量的形状
    new_data = data.reshape(2, 10) # 两行十列
    print(new_data)

    # 注意:转换之后的形状元素个数得等于原来张量的元素个数,不然就报错。上面创建data就是4*5=20个元素

    # 使用 -1 代替省略的形状
    new_data = data.reshape(-1, 10) # -1表示自动计算行数
    print(new_data.shape)  # torch.Size([2, 10])
    print(new_data)

    new_data = data.reshape(2, -1) # -1表示自动计算列数
    print(new_data)

if __name__ == "__main__":
    test01() 

transpose 和 permute 函数的使用

transpose 函数可以实现交换张量形状的指定维度。

例如:一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换,将张量的形状变为 (2, 4, 3)。

permute 函数可以一次交换更多的维度。

本质上都是在修改数据的维度。

python 复制代码
import torch 
import numpy as np 

# transpose 函数
def test01():
    torch.manual_seed(0)
    data = torch.randint(0, 10, [3, 4, 5])

    # new_data = data.reshape(4, 3, 5)  # 重新计算维度
    # print(new_data.shape)

    # 直接交换两个维度的值
    new_data = torch.transpose(data, 0, 1) # 只是将这两个位置进行交换。0表示第0个维度,1表示第1个维度
    print(new_data.shape)

    # 缺点:transpose 一次只能交换两个维度
    # 把数据的形状变成 (4, 5, 3)
    # 进行第一次交换:(4, 3, 5)
    # 进行第二次交换:(4, 5, 3)
    new_data = torch.transpose(data, 0, 1)
    new_data = torch.transpose(new_data, 1, 2)
    print(new_data.shape)

# permute 函数
def test02():
    torch.manual_seed(0)
    data = torch.randint(0, 10, [3, 4, 5])

    # permute 函数可以一次性交换多个维度
    new_data = torch.permute(data, [1, 2, 0])
    print(new_data.shape)

if __name__ == "__main__":
    test02() 

view 和 contigous 函数的用法

view 函数可以用于修改张量的形状,但是其用法比较局限,只能用于存储在整块内存中的张量。

在 PyTorch 中,有些张量是由不同的数据块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理。

例如:一个张量经过了 transpose 或者 permute 函数的处理之后,就无法使用 view 函数进行形状操作。

python 复制代码
import torch 
import numpy as np 

# view 函数的使用
def test01():
    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    data = data.view(3, 2)
    print(data.shape)

    # 通过 is_contigous 函数来判断张量是否是连续内存空间 (整块的内存)
    print(data.is_contiguous())

# view 函数使用注意
def test02():
    # 当张量经过 transpose 或者 permute 函数之后,内存空间基本不连续
    # 此时,必须先把空间连续,才能使用 view 函数进行张量形状操作

    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    data = torch.transpose(data, 0, 1)
    print(data.is_contiguous())

    # data = data.view(2, 3)  # 这是报错的
    data = data.contiguous().view(2, 3)
    print(data)

if __name__ == "__main__":
    test02() 

squeeze 和 unsqueeze 函数的用法

squeeze 函数用删除 shape 为 1 的维度。

unsqueeze 在每个维度添加1,以增加数据的形状。

python 复制代码
import torch 
import numpy as np 

# squeeze 函数使用
def test01():
    data = torch.randint(0, 10, [1, 3, 1, 5])
    print(data.shape)

    # 维度压缩,默认去掉所有的1的维度
    new_data = data.squeeze()
    print(new_data.shape)

    # 指定去掉某个1的维度
    new_data = data.squeeze(2)
    print(new_data.shape)

# unsqueeze 函数使用
def test02():
    data = torch.randint(0, 10, [3, 5])
    print(data.shape)

    new_data = data.unsqueeze(0)
    print(new_data)

if __name__ == "__main__":
    test01() 

总结

  1. reshape 函数可以在保证张量数据不变的前提下改变数据的维度
  2. transpose 函数可以实现交换张量形状的指定维度,permute 可以一次交换更多的维度
  3. view 函数也可以用于修改张量的形状,但是它要求被转换的张量内存必须连续,所以一般配合 contiguous 函数使用。
  4. squeeze 和 unsqueeze 函数可以用来增加或者减少维度。
相关推荐
小许学java7 分钟前
Spring AI快速入门以及项目的创建
java·开发语言·人工智能·后端·spring·ai编程·spring ai
人工智能技术派23 分钟前
Qwen-Audio:一种新的大规模音频-语言模型
人工智能·语言模型·音视频
lpfasd12328 分钟前
从OpenAI发布会看AI未来:中国就业市场的重构与突围
人工智能·重构
春末的南方城市1 小时前
清华&字节开源HuMo: 打造多模态可控的人物视频,输入文字、图片、音频,生成电影级的视频,Demo、代码、模型、数据全开源。
人工智能·深度学习·机器学习·计算机视觉·aigc
whltaoin1 小时前
Java 后端与 AI 融合:技术路径、实战案例与未来趋势
java·开发语言·人工智能·编程思想·ai生态
中杯可乐多加冰1 小时前
smardaten AI + 无代码开发实践:基于自然语言交互快速开发【苏超赛事管理系统】
人工智能
Hy行者勇哥1 小时前
数据中台的数据源与数据处理流程
大数据·前端·人工智能·学习·个人开发
xiaohanbao091 小时前
Transformer架构与NLP词表示演进
python·深度学习·神经网络
岁月宁静1 小时前
AI 时代,每个程序员都该拥有个人提示词库:从效率工具到战略资产的蜕变
前端·人工智能·ai编程
双向331 小时前
Trae Solo+豆包Version1.6+Seedream4.0打造"AI识菜通"
人工智能