【学习笔记】深入浅出详解Pytorch中的View, reshape, unfold,flatten等方法。

文章目录

一、写在前面

最近在解析transformer源码的时候突然看到了unfold?我在想unfold是什么意思?为什么不用reshape,他们的底层逻辑有什么区别呢?于是便相对比一下他们之间的区别,便有了本篇博客,希望对大家有帮助!

二、Reshape

(一)用法

1. torch.reshape(input, shape)

输入是tensor和shape,其中原始shape和目标shape的元素数量要一致。

2. Tensor.reshape(shape) → Tensor

与上述用法一致,只不过这个是直接在tensor的基础上进行reshape。

reshpe是按照顺序进行重新排列组合的。

[16,2] 其实与 [4,2,2,2] 是一样的,只要最后一维的数字是一样,其实结果都是一样的。

(二)代码展示

三、Unfold

(一)torch.unfold 的基本概念

torch.unfold 的作用是将输入张量的某个维度展开为多个滑动窗口。每个窗口包含一个局部区域,这些区域可以用于后续的计算。

  1. 语法:
python 复制代码
torch.Tensor.unfold(dimension, size, step)
  1. 参数:
    • dimension:要展开的维度(整数)。
    • size:滑动窗口的大小(整数)。
    • step:滑动窗口的步长(整数)。
    • 返回值:返回一个新的张量,其中指定维度的每个元素被展开为多个滑动窗口。

(二)torch.unfold 的工作原理

假设我们有一个形状为 (N, C, H, W) 的张量(例如图像数据),我们希望在高度维度(H)上提取滑动窗口。

  • 输入张量:(N, C, H, W)
  • 展开维度:dimension=2(即高度维度 H)
  • 窗口大小:size=k(例如 k=3)
  • 步长:step=s(例如 s=1)

torch.unfold 会将高度维度 H 展开为多个大小为 k 的滑动窗口,每个窗口之间间隔 s。

(三) 示例代码

  • 示例 1:一维张量的展开
python 复制代码
import torch
# 创建一个一维张量
x = torch.arange(10)
print("原始张量:", x)

# 使用 unfold 展开
unfolded = x.unfold(dimension=0, size=3, step=1)
print("展开后的张量:\n", unfolded)
  • 输出:
python 复制代码
原始张量: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
展开后的张量:
 tensor([[0, 1, 2],
         [1, 2, 3],
         [2, 3, 4],
         [3, 4, 5],
         [4, 5, 6],
         [5, 6, 7],
         [6, 7, 8],
         [7, 8, 9]])
  • 示例 2:二维张量的展开(图像处理)
python 复制代码
import torch
# 创建一个二维张量(模拟图像)
    x = torch.arange(16).reshape(1, 1, 4, 4)  # 形状为 (1, 1, 4, 4)
    print("原始张量:\n", x)
    # 使用 unfold 展开
    unfolded_1 = x.unfold(dimension=2, size=3, step=1)
    unfolded_2 = unfolded_1.unfold(dimension=3, size=3, step=1)
    print("展开后的张量形状:", unfolded_1.shape)
    print("展开后的张量:\n", unfolded_1)

    print("展开后的张量形状:", unfolded_2.shape)
    print("展开后的张量:\n", unfolded_2)
  • 输出:
python 复制代码
原始张量:
 tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]]]])
展开后的张量形状: torch.Size([1, 1, 2, 4, 3])
展开后的张量:
 tensor([[[[[ 0,  4,  8],
           [ 1,  5,  9],
           [ 2,  6, 10],
           [ 3,  7, 11]],

          [[ 4,  8, 12],
           [ 5,  9, 13],
           [ 6, 10, 14],
           [ 7, 11, 15]]]]])
展开后的张量形状: torch.Size([1, 1, 2, 2, 3, 3])
展开后的张量:
 tensor([[[[[[ 0,  1,  2],
            [ 4,  5,  6],
            [ 8,  9, 10]],

           [[ 1,  2,  3],
            [ 5,  6,  7],
            [ 9, 10, 11]]],


          [[[ 4,  5,  6],
            [ 8,  9, 10],
            [12, 13, 14]],

           [[ 5,  6,  7],
            [ 9, 10, 11],
            [13, 14, 15]]]]]])

数组的运算主要看最后两维,倒数第二维代表行,倒数第一维代表列。

(四)torch.unfold 的应用场景

  1. 卷积操作

在卷积神经网络(CNN)中,卷积核通过滑动窗口的方式提取图像的局部特征。torch.unfold 可以用于手动实现卷积操作。

  1. 图像处理

在图像处理中,torch.unfold 可以用于提取图像的局部区域(例如提取图像的滑动窗口)。

  1. 时间序列分析

在时间序列分析中,torch.unfold 可以用于提取时间序列的滑动窗口,用于特征提取或模型训练。

(五)注意事项

  1. 维度选择:需要明确指定要展开的维度(dimension)。
  2. 窗口大小和步长:窗口大小(size)和步长(step)的选择会影响展开后的张量形状。
  3. 内存消耗:torch.unfold 可能会生成较大的张量,尤其是在高维数据上使用时,需要注意内存消耗。

(六)总结

  1. torch.unfold 的作用:从张量的某个维度提取滑动窗口。
  2. 常用参数:dimension(展开维度)、size(窗口大小)、step(步长)。
  3. 应用场景:卷积操作、图像处理、时间序列分析等。
  4. 注意事项:选择合适的维度、窗口大小和步长,避免内存消耗过大。

四、View

torch.view 用于返回一个与原始张量共享相同数据存储的新张量,但具有不同的形状。换句话说,view 只是改变了张量的视图(view),而不会复制数据。

(一)用法

  1. Tensor.view(*shape):

    • 参数:

    *shape:新的形状(可以是整数或元组)。

    • 返回值:

    返回一个新的张量,具有指定的形状,并与原始张量共享相同的数据存储。

(二)注意事项

  1. view 返回的张量与原始张量共享相同的数据存储。
  2. 如果原始张量的数据发生变化,view 返回的张量也会随之变化。
  3. view 要求张量的内存必须是连续的(即张量在内存中是连续存储的)。如果内存不连续,view 会抛出错误。

(三)其他方法

  1. view_as
  • 作用:将当前张量转换为与另一个张量相同的形状。
  • 语法:
python 复制代码
torch.Tensor.view_as(other)
  • 参数:

other:目标张量,当前张量的形状将被转换为与 other 相同的形状。

  • 返回值:

返回一个新的张量,具有与 other 相同的形状,并与原始张量共享数据存储。

  • 示例:
python 复制代码
import torch
# 创建一个张量
x = torch.arange(12)
print("原始张量:", x)
# 创建目标张量
other = torch.empty(3, 4)
# 使用 view_as 改变形状
y = x.view_as(other)
print("改变形状后的张量:\n", y)
  • 输出:
python 复制代码
原始张量: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
改变形状后的张量:
 tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]])
  1. view_as_real
  • 作用:将复数张量转换为实数张量。
  • 语法:
python 复制代码
torch.Tensor.view_as_real()
  • 返回值:

返回一个新的实数张量,形状为 (..., 2),其中最后一个维度包含复数的实部和虚部。

  • 示例:
python 复制代码
import torch
# 创建一个复数张量
x = torch.tensor([1 + 2j, 3 + 4j])
print("原始张量:", x)
# 使用 view_as_real 转换为实数张量
y = x.view_as_real()
print("转换后的张量:\n", y)
  • 输出:
python 复制代码
原始张量: tensor([1.+2.j, 3.+4.j])
转换后的张量:
 tensor([[1., 2.],
         [3., 4.]])
  1. view_as_complex
  • 作用:将实数张量转换为复数张量。
  • 语法:
python 复制代码
torch.Tensor.view_as_complex()
  • 返回值:

返回一个新的复数张量,形状为 (...,),其中最后一个维度被解释为复数的实部和虚部。

  • 示例:
python 复制代码
import torch
# 创建一个实数张量
x = torch.tensor([[1., 2.], [3., 4.]])
print("原始张量:\n", x)
# 使用 view_as_complex 转换为复数张量
y = x.view_as_complex()
print("转换后的张量:", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[1., 2.],
         [3., 4.]])
转换后的张量: tensor([1.+2.j, 3.+4.j])
  1. view_as_strided
  • 作用:返回一个具有指定步长和内存布局的张量视图。
  • 语法:
python 复制代码
torch.Tensor.view_as_strided(size, stride)
  • 参数:

size:新的形状(元组)。

stride:新的步长(元组)。

  • 返回值:

返回一个新的张量,具有指定的形状和步长,并与原始张量共享数据存储。

  • 示例:
python 复制代码
import torch
# 创建一个张量
x = torch.arange(9).view(3, 3)
print("原始张量:\n", x)
# 使用 view_as_strided 改变形状和步长
y = x.view_as_strided((2, 2), (1, 2))
print("改变形状和步长后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]])
改变形状和步长后的张量:
 tensor([[0, 2],
         [1, 3]])
  1. view_as_real 和 view_as_complex 的应用场景
  • 复数计算:

在涉及复数计算的任务中,view_as_real 和 view_as_complex 可以用于在复数和实数之间进行转换。

  • 信号处理:

在信号处理中,复数张量常用于表示频域信号,view_as_real 和 view_as_complex 可以用于频域和时域之间的转换。

  1. view_as_strided 的应用场景
  • 自定义内存布局:

在需要自定义内存布局的场景中,view_as_strided 可以用于创建具有特定步长和形状的张量视图。

  • 高效内存访问:

在需要高效访问内存的场景中,view_as_strided 可以用于优化内存访问模式。

五、Flatten

(一)torch.flatten 的基本概念

torch.flatten 的作用是将输入张量的指定维度展平为一维。它可以展平整个张量,也可以只展平部分维度。

  1. 语法:
python 复制代码
torch.flatten(input, start_dim=0, end_dim=-1)
  1. 参数:

    • input:输入张量。
    • start_dim:开始展平的维度(整数),默认为 0。
    • end_dim:结束展平的维度(整数),默认为 -1。
    • 返回值:返回一个新的张量,具有展平后的形状。

(二)torch.flatten 的工作原理

  1. torch.flatten 会将指定范围内的维度展平为一维。
  2. 如果 start_dim=0 且 end_dim=-1,则整个张量会被展平为一维。
  3. 如果只展平部分维度,则其他维度保持不变。

(三)示例代码

  • 示例 1:展平整个张量
python 复制代码
import torch
# 创建一个二维张量
x = torch.arange(12).view(3, 4)
print("原始张量:\n", x)
# 使用 flatten 展平整个张量
y = torch.flatten(x)
print("展平后的张量:", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]])
展平后的张量: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
  • 示例 2:展平部分维度
python 复制代码
import torch
# 创建一个三维张量
x = torch.arange(24).view(2, 3, 4)
print("原始张量:\n", x)
# 使用 flatten 展平第二个维度到第三个维度
y = torch.flatten(x, start_dim=1, end_dim=2)
print("展平后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]])
展平后的张量:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
  • 示例 3:展平特定维度
python 复制代码
import torch
# 创建一个四维张量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始张量:\n", x)
# 使用 flatten 展平第三个维度
y = torch.flatten(x, start_dim=2, end_dim=2)
print("展平后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],

          [[ 6,  7],
           [ 8,  9],
           [10, 11]]],


         [[[12, 13],
           [14, 15],
           [16, 17]],

          [[18, 19],
           [20, 21],
           [22, 23]]]])
展平后的张量:
 tensor([[[ 0,  1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10, 11]],

         [[12, 13, 14, 15, 16, 17],
          [18, 19, 20, 21, 22, 23]]])

六、Permute

(一)torch.permute 的基本概念

torch.permute 的作用是将输入张量的维度按照指定的顺序重新排列。它类似于 NumPy 中的 transpose,但更加灵活,可以同时对多个维度进行排列。

  1. 语法:
python 复制代码
torch.Tensor.permute(*dims)
  1. 参数:
    • *dims:新的维度顺序(元组或多个整数)。
    • 返回值:返回一个新的张量,具有重新排列后的维度顺序,并与原始张量共享数据存储。

(二)torch.permute 的工作原理

torch.permute 会将输入张量的维度按照指定的顺序重新排列。

新的维度顺序必须与原始张量的维度数量相同,并且每个维度索引只能出现一次。

(三) 示例代码

  1. 示例 1:二维张量的转置
python 复制代码
import torch
# 创建一个二维张量
x = torch.arange(12).view(3, 4)
print("原始张量:\n", x)

# 使用 permute 转置
y = x.permute(1, 0)  # 将维度 0 和 1 交换
print("转置后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]])
转置后的张量:
 tensor([[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]])
  1. 示例 2:三维张量的维度重排
python 复制代码
import torch
# 创建一个三维张量
x = torch.arange(24).view(2, 3, 4)
print("原始张量:\n", x)

# 使用 permute 重排维度
y = x.permute(2, 0, 1)  # 将维度 0, 1, 2 重排为 2, 0, 1
print("重排后的张量形状:", y.shape)
print("重排后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]])
重排后的张量形状: torch.Size([4, 2, 3])
重排后的张量:
 tensor([[[ 0,  4,  8],
          [12, 16, 20]],

         [[ 1,  5,  9],
          [13, 17, 21]],

         [[ 2,  6, 10],
          [14, 18, 22]],

         [[ 3,  7, 11],
          [15, 19, 23]]])
  1. 示例 3:四维张量的维度重排
python 复制代码
import torch
# 创建一个四维张量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始张量:\n", x)

# 使用 permute 重排维度
y = x.permute(3, 1, 2, 0)  # 将维度 0, 1, 2, 3 重排为 3, 1, 2, 0
print("重排后的张量形状:", y.shape)
print("重排后的张量:\n", y)
  • 输出:
python 复制代码
原始张量:
 tensor([[[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],

          [[ 6,  7],
           [ 8,  9],
           [10, 11]]],


         [[[12, 13],
           [14, 15],
           [16, 17]],

          [[18, 19],
           [20, 21],
           [22, 23]]]])
重排后的张量形状: torch.Size([2, 2, 3, 2])
重排后的张量:
 tensor([[[[ 0, 12],
           [ 2, 14],
           [ 4, 16]],

          [[ 6, 18],
           [ 8, 20],
           [10, 22]]],


         [[[ 1, 13],
           [ 3, 15],
           [ 5, 17]],

          [[ 7, 19],
           [ 9, 21],
           [11, 23]]]])

(四) torch.permute 的应用场景

  1. 图像处理

在图像处理中,torch.permute 可以用于调整图像的通道顺序。例如,将 (C, H, W) 的图像张量转换为 (H, W, C)。

  1. 深度学习模型输入

在深度学习中,模型的输入张量通常需要特定的维度顺序。torch.permute 可以用于调整输入张量的维度顺序。

  1. 数据预处理

在数据预处理中,torch.permute 可以用于调整数据的维度顺序,以便进行后续的计算或操作。

如果每一维度都有特定的含义,此时想改变维度的时候用permute。

七、总结

  1. reshape 与 view 几乎一致, 甚至可以说reshape可以代替view;
  2. reshape 与 permute 的区别在于,reshape是按照顺序重新进行排列组合,permute是按照维度进行重新排列组合,如果各个维度都有特定的意义,那么permute会更合适。
  3. unfold是按照某一维度滑动选取数据,新增加一维,新增加的维度的大小为滑动窗口的大小,原始维度会根据滑动窗口和step的大小而变化。
  4. flatten也是按照顺序进行展开,且展开的是某个范围的维度,不是特定的维度。
  5. 不管是几维数组,主要看最后两维,倒数第二维代表行,倒数第一维代表列。
相关推荐
Srlua4 分钟前
辅助任务改进社交帖子多模态分类
人工智能·python
兔子的洋葱圈5 分钟前
Python的3D可视化库【vedo】2-5 (plotter模块) 坐标转换、场景导出、添加控件
python·3d·数据可视化
drebander12 分钟前
基于 Python 将 PDF 转 Markdown 并拆解为 JSON,支持自定义标题处理
python·pdf·json
L_cl23 分钟前
【NLP 15、深度学习处理文本】
人工智能·深度学习
孔汤姆32 分钟前
渗透测试学习笔记(五)网络
网络·笔记·学习
2401_8711510744 分钟前
十二月第14讲:使用Python实现两组数据纵向排序
开发语言·python·算法
知新_ROL44 分钟前
通过解调使用正则化相位跟踪技术进行相位解包裹
人工智能·算法·机器学习
十月ooOO1 小时前
Docker 学习笔记(持续更新)
笔记·学习·docker
一位小说男主1 小时前
可解释性方法:从理论到实践的深度剖析(续上文)
人工智能·深度学习·机器学习
Cachel wood1 小时前
Vue.js前端框架教程5:Vue数据拷贝和数组函数
linux·前端·vue.js·python·阿里云·前端框架·云计算