【学习笔记】深入浅出详解Pytorch中的View, reshape, unfold,flatten等方法。

文章目录

一、写在前面

最近在解析transformer源码的时候突然看到了unfold?我在想unfold是什么意思?为什么不用reshape,他们的底层逻辑有什么区别呢?于是便相对比一下他们之间的区别,便有了本篇博客,希望对大家有帮助!

二、Reshape

(一)用法

1. torch.reshape(input, shape)

输入是tensor和shape,其中原始shape和目标shape的元素数量要一致。

2. Tensor.reshape(shape) → Tensor

与上述用法一致,只不过这个是直接在tensor的基础上进行reshape。

reshpe是按照顺序进行重新排列组合的。

[16,2] 其实与 [4,2,2,2] 是一样的,只要最后一维的数字是一样,其实结果都是一样的。

(二)代码展示

三、Unfold

(一)torch.unfold 的基本概念

torch.unfold 的作用是将输入张量的某个维度展开为多个滑动窗口。每个窗口包含一个局部区域,这些区域可以用于后续的计算。

  1. 语法:
python 复制代码
torch.Tensor.unfold(dimension, size, step)
  1. 参数:
    • dimension:要展开的维度(整数)。
    • size:滑动窗口的大小(整数)。
    • step:滑动窗口的步长(整数)。
    • 返回值:返回一个新的张量,其中指定维度的每个元素被展开为多个滑动窗口。

(二)torch.unfold 的工作原理

假设我们有一个形状为 (N, C, H, W) 的张量(例如图像数据),我们希望在高度维度(H)上提取滑动窗口。

  • 输入张量:(N, C, H, W)
  • 展开维度:dimension=2(即高度维度 H)
  • 窗口大小:size=k(例如 k=3)
  • 步长:step=s(例如 s=1)

torch.unfold 会将高度维度 H 展开为多个大小为 k 的滑动窗口,每个窗口之间间隔 s。

(三) 示例代码

  • 示例 1:一维张量的展开
python 复制代码
import torch
# 创建一个一维张量
x = torch.arange(10)
print("原始张量:", x)

# 使用 unfold 展开
unfolded = x.unfold(dimension=0, size=3, step=1)
print("展开后的张量:\n", unfolded)
  • 输出:
python 复制代码
原始张量: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
展开后的张量:
 tensor([[0, 1, 2],
         [1, 2, 3],
         [2, 3, 4],
         [3, 4, 5],
         [4, 5, 6],
         [5, 6, 7],
         [6, 7, 8],
         [7, 8, 9]])
  • 示例 2:二维张量的展开(图像处理)
python 复制代码
import torch
# 创建一个二维张量(模拟图像)
    x = torch.arange(16).reshape(1, 1, 4, 4)  # 形状为 (1, 1, 4, 4)
    print("原始张量:\n", x)
    # 使用 unfold 展开
    unfolded_1 = x.unfold(dimension=2, size=3, step=1)
    unfolded_2 = unfolded_1.unfold(dimension=3, size=3, step=1)
    print("展开后的张量形状:", unfolded_1.shape)
    print("展开后的张量:\n", unfolded_1)

    print("展开后的张量形状:", unfolded_2.shape)
    print("展开后的张量:\n", unfolded_2)
  • 输出:
python 复制代码
原始张量:
 tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]]]])
展开后的张量形状: torch.Size([1, 1, 2, 4, 3])
展开后的张量:
 tensor([[[[[ 0,  4,  8],
           [ 1,  5,  9],
           [ 2,  6, 10],
           [ 3,  7, 11]],

          [[ 4,  8, 12],
           [ 5,  9, 13],
           [ 6, 10, 14],
           [ 7, 11, 15]]]]])
展开后的张量形状: torch.Size([1, 1, 2, 2, 3, 3])
展开后的张量:
 tensor([[[[[[ 0,  1,  2],
            [ 4,  5,  6],
            [ 8,  9, 10]],

           [[ 1,  2,  3],
            [ 5,  6,  7],
            [ 9, 10, 11]]],


          [[[ 4,  5,  6],
            [ 8,  9, 10],
            [12, 13, 14]],

           [[ 5,  6,  7],
            [ 9, 10, 11],
            [13, 14, 15]]]]]])

数组的运算主要看最后两维,倒数第二维代表行,倒数第一维代表列。

(四)torch.unfold 的应用场景

  1. 卷积操作

在卷积神经网络(CNN)中,卷积核通过滑动窗口的方式提取图像的局部特征。torch.unfold 可以用于手动实现卷积操作。

  1. 图像处理

在图像处理中,torch.unfold 可以用于提取图像的局部区域(例如提取图像的滑动窗口)。

  1. 时间序列分析

在时间序列分析中,torch.unfold 可以用于提取时间序列的滑动窗口,用于特征提取或模型训练。

(五)注意事项

  1. 维度选择:需要明确指定要展开的维度(dimension)。
  2. 窗口大小和步长:窗口大小(size)和步长(step)的选择会影响展开后的张量形状。
  3. 内存消耗:torch.unfold 可能会生成较大的张量,尤其是在高维数据上使用时,需要注意内存消耗。

(六)总结

  1. torch.unfold 的作用:从张量的某个维度提取滑动窗口。
  2. 常用参数:dimension(展开维度)、size(窗口大小)、step(步长)。
  3. 应用场景:卷积操作、图像处理、时间序列分析等。
  4. 注意事项:选择合适的维度、窗口大小和步长,避免内存消耗过大。

四、View

torch.view 用于返回一个与原始张量共享相同数据存储的新张量,但具有不同的形状。换句话说,view 只是改变了张量的视图(view),而不会复制数据。

(一)用法

  1. Tensor.view(*shape):

    • 参数:

    *shape:新的形状(可以是整数或元组)。

    • 返回值:

    返回一个新的张量,具有指定的形状,并与原始张量共享相同的数据存储。

(二)注意事项

  1. view 返回的张量与原始张量共享相同的数据存储。
  2. 如果原始张量的数据发生变化,view 返回的张量也会随之变化。
  3. view 要求张量的内存必须是连续的(即张量在内存中是连续存储的)。如果内存不连续,view 会抛出错误。

(三)其他方法

  1. view_as
  • 作用:将当前张量转换为与另一个张量相同的形状。
  • 语法:
python 复制代码
torch.Tensor.view_as(other)
  • 参数:

other:目标张量,当前张量的形状将被转换为与 other 相同的形状。

  • 返回值:

返回一个新的张量,具有与 other 相同的形状,并与原始张量共享数据存储。

  • 示例:
python 复制代码
import torch
# 创建一个张量
x = torch.arange(12)
print("原始张量:", x)
# 创建目标张量
other = torch.empty(3, 4)
# 使用 view_as 改变形状
y = x.view_as(other)
print("改变形状后的张量:\n", y)
  • 输出:
python 复制代码
原始张量: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
改变形状后的张量:
 tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]])
  1. view_as_real
  • 作用:将复数张量转换为实数张量。
  • 语法:
python 复制代码
torch.Tensor.view_as_real()
  • 返回值:

返回一个新的实数张量,形状为 (..., 2),其中最后一个维度包含复数的实部和虚部。

  • 示例:
python 复制代码
import torch
# 创建一个复数张量
x = torch.tensor([1 + 2j, 3 + 4j])
print("原始张量:", x)
# 使用 view_as_real 转换为实数张量
y = x.view_as_real()
print("转换后的张量:\n", y)
  • 输出:
python 复制代码
原始张量: tensor([1.+2.j, 3.+4.j])
转换后的张量:
 tensor([[1., 2.],
         [3., 4.]])
  1. view_as_complex
  • 作用:将实数张量转换为复数张量。
  • 语法:
python 复制代码
torch.Tensor.view_as_complex()
  • 返回值:

返回一个新的复数张量,形状为 (...,),其中最后一个维度被解释为复数的实部和虚部。

  • 示例:
python 复制代码
import torch
# 创建一个实数张量
x = torch.tensor([[1., 2.], [3., 4.]])
print("原始张量:\n", x)
# 使用 view_as_complex 转换为复数张量
y = x.view_as_complex()
print("转换后的张量:", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[1., 2.],
         [3., 4.]])
转换后的张量: tensor([1.+2.j, 3.+4.j])
  1. view_as_strided
  • 作用:返回一个具有指定步长和内存布局的张量视图。
  • 语法:
python 复制代码
torch.Tensor.view_as_strided(size, stride)
  • 参数:

size:新的形状(元组)。

stride:新的步长(元组)。

  • 返回值:

返回一个新的张量,具有指定的形状和步长,并与原始张量共享数据存储。

  • 示例:
python 复制代码
import torch
# 创建一个张量
x = torch.arange(9).view(3, 3)
print("原始张量:\n", x)
# 使用 view_as_strided 改变形状和步长
y = x.view_as_strided((2, 2), (1, 2))
print("改变形状和步长后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]])
改变形状和步长后的张量:
 tensor([[0, 2],
         [1, 3]])
  1. view_as_real 和 view_as_complex 的应用场景
  • 复数计算:

在涉及复数计算的任务中,view_as_real 和 view_as_complex 可以用于在复数和实数之间进行转换。

  • 信号处理:

在信号处理中,复数张量常用于表示频域信号,view_as_real 和 view_as_complex 可以用于频域和时域之间的转换。

  1. view_as_strided 的应用场景
  • 自定义内存布局:

在需要自定义内存布局的场景中,view_as_strided 可以用于创建具有特定步长和形状的张量视图。

  • 高效内存访问:

在需要高效访问内存的场景中,view_as_strided 可以用于优化内存访问模式。

五、Flatten

(一)torch.flatten 的基本概念

torch.flatten 的作用是将输入张量的指定维度展平为一维。它可以展平整个张量,也可以只展平部分维度。

  1. 语法:
python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)
  1. 参数:

    • input:输入张量。
    • start_dim:开始展平的维度(整数),默认为 0。
    • end_dim:结束展平的维度(整数),默认为 -1。
    • 返回值:返回一个新的张量,具有展平后的形状。

(二)torch.flatten 的工作原理

  1. torch.flatten 会将指定范围内的维度展平为一维。
  2. 如果 start_dim=0 且 end_dim=-1,则整个张量会被展平为一维。
  3. 如果只展平部分维度,则其他维度保持不变。

(三)示例代码

  • 示例 1:展平整个张量
python 复制代码
import torch
# 创建一个二维张量
x = torch.arange(12).view(3, 4)
print("原始张量:\n", x)
# 使用 flatten 展平整个张量
y = torch.flatten(x)
print("展平后的张量:", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]])
展平后的张量: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
  • 示例 2:展平部分维度
python 复制代码
import torch
# 创建一个三维张量
x = torch.arange(24).view(2, 3, 4)
print("原始张量:\n", x)
# 使用 flatten 展平第二个维度到第三个维度
y = torch.flatten(x, start_dim=1, end_dim=2)
print("展平后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]])
展平后的张量:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
  • 示例 3:展平特定维度
python 复制代码
import torch
# 创建一个四维张量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始张量:\n", x)
# 使用 flatten 展平第三个维度
y = torch.flatten(x, start_dim=2, end_dim=2)
print("展平后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],

          [[ 6,  7],
           [ 8,  9],
           [10, 11]]],


         [[[12, 13],
           [14, 15],
           [16, 17]],

          [[18, 19],
           [20, 21],
           [22, 23]]]])
展平后的张量:
 tensor([[[ 0,  1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10, 11]],

         [[12, 13, 14, 15, 16, 17],
          [18, 19, 20, 21, 22, 23]]])

六、Permute

(一)torch.permute 的基本概念

torch.permute 的作用是将输入张量的维度按照指定的顺序重新排列。它类似于 NumPy 中的 transpose,但更加灵活,可以同时对多个维度进行排列。

  1. 语法:
python 复制代码
torch.Tensor.permute(*dims)
  1. 参数:
    • *dims:新的维度顺序(元组或多个整数)。
    • 返回值:返回一个新的张量,具有重新排列后的维度顺序,并与原始张量共享数据存储。

(二)torch.permute 的工作原理

torch.permute 会将输入张量的维度按照指定的顺序重新排列。

新的维度顺序必须与原始张量的维度数量相同,并且每个维度索引只能出现一次。

(三) 示例代码

  1. 示例 1:二维张量的转置
python 复制代码
import torch
# 创建一个二维张量
x = torch.arange(12).view(3, 4)
print("原始张量:\n", x)

# 使用 permute 转置
y = x.permute(1, 0)  # 将维度 0 和 1 交换
print("转置后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]])
转置后的张量:
 tensor([[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]])
  1. 示例 2:三维张量的维度重排
python 复制代码
import torch
# 创建一个三维张量
x = torch.arange(24).view(2, 3, 4)
print("原始张量:\n", x)

# 使用 permute 重排维度
y = x.permute(2, 0, 1)  # 将维度 0, 1, 2 重排为 2, 0, 1
print("重排后的张量形状:", y.shape)
print("重排后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]])
重排后的张量形状: torch.Size([4, 2, 3])
重排后的张量:
 tensor([[[ 0,  4,  8],
          [12, 16, 20]],

         [[ 1,  5,  9],
          [13, 17, 21]],

         [[ 2,  6, 10],
          [14, 18, 22]],

         [[ 3,  7, 11],
          [15, 19, 23]]])
  1. 示例 3:四维张量的维度重排
python 复制代码
import torch
# 创建一个四维张量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始张量:\n", x)

# 使用 permute 重排维度
y = x.permute(3, 1, 2, 0)  # 将维度 0, 1, 2, 3 重排为 3, 1, 2, 0
print("重排后的张量形状:", y.shape)
print("重排后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],

          [[ 6,  7],
           [ 8,  9],
           [10, 11]]],


         [[[12, 13],
           [14, 15],
           [16, 17]],

          [[18, 19],
           [20, 21],
           [22, 23]]]])
重排后的张量形状: torch.Size([2, 2, 3, 2])
重排后的张量:
 tensor([[[[ 0, 12],
           [ 2, 14],
           [ 4, 16]],

          [[ 6, 18],
           [ 8, 20],
           [10, 22]]],


         [[[ 1, 13],
           [ 3, 15],
           [ 5, 17]],

          [[ 7, 19],
           [ 9, 21],
           [11, 23]]]])

(四) torch.permute 的应用场景

  1. 图像处理

在图像处理中,torch.permute 可以用于调整图像的通道顺序。例如,将 (C, H, W) 的图像张量转换为 (H, W, C)。

  1. 深度学习模型输入

在深度学习中,模型的输入张量通常需要特定的维度顺序。torch.permute 可以用于调整输入张量的维度顺序。

  1. 数据预处理

在数据预处理中,torch.permute 可以用于调整数据的维度顺序,以便进行后续的计算或操作。

如果每一维度都有特定的含义,此时想改变维度的时候用permute。

七、总结

  1. reshape 与 view 几乎一致, 甚至可以说reshape可以代替view;
  2. reshape 与 permute 的区别在于,reshape是按照顺序重新进行排列组合,permute是按照维度进行重新排列组合,如果各个维度都有特定的意义,那么permute会更合适。
  3. unfold是按照某一维度滑动选取数据,新增加一维,新增加的维度的大小为滑动窗口的大小,原始维度会根据滑动窗口和step的大小而变化。
  4. flatten也是按照顺序进行展开,且展开的是某个范围的维度,不是特定的维度。
  5. 不管是几维数组,主要看最后两维,倒数第二维代表行,倒数第一维代表列。
相关推荐
懒大王爱吃狼10 分钟前
Python绘制数据地图-MovingPandas
开发语言·python·信息可视化·python基础·python学习
数据小小爬虫14 分钟前
如何使用Python爬虫按关键字搜索AliExpress商品:代码示例与实践指南
开发语言·爬虫·python
martian66536 分钟前
第17篇:python进阶:详解数据分析与处理
开发语言·python
无码不欢的我40 分钟前
使用vscode在本地和远程服务器端运行和调试Python程序的方法总结
ide·vscode·python
五味香41 分钟前
Java学习,查找List最大最小值
android·java·开发语言·python·学习·golang·kotlin
AIGC大时代42 分钟前
方法建议ChatGPT提示词分享
人工智能·深度学习·chatgpt·aigc·ai写作
糯米导航1 小时前
ChatGPT Prompt 编写指南
人工智能·chatgpt·prompt
金融OG1 小时前
99.8 金融难点通俗解释:净资产收益率(ROE)
大数据·python·线性代数·机器学习·数学建模·金融·矩阵
Damon小智1 小时前
全面评测 DOCA 开发环境下的 DPU:性能表现、机器学习与金融高频交易下的计算能力分析
人工智能·机器学习·金融·边缘计算·nvidia·dpu·doca
赵孝正1 小时前
特征选择(机器学习)
人工智能·机器学习