PyTorch维度操作的函数介绍

在 PyTorch 中,操作张量的维度是常见的需求,特别是在处理多维数据时。PyTorch 提供了一系列函数来操作张量的维度,包括改变维度顺序、添加或删除维度、扩展维度等。下面是一些常用的维度操作函数及其示例代码。

1. view()

  • 作用 :重新调整张量的形状(维度),但不改变其数据内容。view() 是基于张量的原始内存布局进行操作的,要求重新调整的形状能与原始数据兼容。

  • 示例

    import torch

    创建一个形状为 [2, 3, 4] 的张量

    tensor = torch.randn(2, 3, 4)

    调整为形状为 [6, 4] 的张量

    reshaped = tensor.view(6, 4)
    print(reshaped.shape) # 输出: torch.Size([6, 4])

2. permute()

  • 作用:重新排列张量的维度顺序。

  • 示例

    import torch

    创建一个形状为 [2, 3, 4] 的张量

    tensor = torch.randn(2, 3, 4)

    交换第一个维度和第二个维度,得到形状为 [3, 2, 4] 的张量

    permuted = tensor.permute(1, 0, 2)
    print(permuted.shape) # 输出: torch.Size([3, 2, 4])

3. unsqueeze()

  • 作用:在指定位置插入一个大小为 1 的新维度。

  • 示例

    import torch

    创建一个形状为 [3, 4] 的张量

    tensor = torch.randn(3, 4)

    在第 0 维添加一个新维度,结果形状为 [1, 3, 4]

    unsqueezed = tensor.unsqueeze(0)
    print(unsqueezed.shape) # 输出: torch.Size([1, 3, 4])

4. squeeze()

  • 作用:移除张量中所有大小为 1 的维度。

  • 示例

    import torch

    创建一个形状为 [1, 3, 1, 4] 的张量

    tensor = torch.randn(1, 3, 1, 4)

    移除所有大小为 1 的维度,结果形状为 [3, 4]

    squeezed = tensor.squeeze()
    print(squeezed.shape) # 输出: torch.Size([3, 4])

5. transpose()

  • 作用:交换张量的两个指定维度。

  • 示例

    import torch

    创建一个形状为 [2, 3, 4] 的张量

    tensor = torch.randn(2, 3, 4)

    交换第 1 维和第 2 维,结果形状为 [2, 4, 3]

    transposed = tensor.transpose(1, 2)
    print(transposed.shape) # 输出: torch.Size([2, 4, 3])

6. expand()

  • 作用:将张量的某些维度扩展为更大的尺寸,不会复制数据,而是通过广播机制扩展。

  • 示例

    import torch

    创建一个形状为 [2, 1, 4] 的张量

    tensor = torch.randn(2, 1, 4)

    扩展第 1 维到大小为 3,结果形状为 [2, 3, 4]

    expanded = tensor.expand(2, 3, 4)
    print(expanded.shape) # 输出: torch.Size([2, 3, 4])

7. repeat()

  • 作用:沿着指定的维度重复张量的元素。

  • 示例

    import torch

    创建一个形状为 [2, 3] 的张量

    tensor = torch.randn(2, 3)

    沿着第 0 维和第 1 维分别重复 2 次和 3 次,结果形状为 [4, 9]

    repeated = tensor.repeat(2, 3)
    print(repeated.shape) # 输出: torch.Size([4, 9])

8. cat()

  • 作用:在指定维度上连接多个张量。

  • 示例

    import torch

    创建两个形状为 [2, 3] 的张量

    tensor1 = torch.randn(2, 3)
    tensor2 = torch.randn(2, 3)

    在第 0 维连接,结果形状为 [4, 3]

    concatenated = torch.cat([tensor1, tensor2], dim=0)
    print(concatenated.shape) # 输出: torch.Size([4, 3])

9. stack()

  • 作用:在新的维度上堆叠多个张量。

  • 示例

    import torch

    创建两个形状为 [2, 3] 的张量

    tensor1 = torch.randn(2, 3)
    tensor2 = torch.randn(2, 3)

    在新的第 0 维堆叠,结果形状为 [2, 2, 3]

    stacked = torch.stack([tensor1, tensor2], dim=0)
    print(stacked.shape) # 输出: torch.Size([2, 2, 3])

总结

PyTorch 提供了丰富的维度操作函数,使得张量的操作非常灵活。在处理多维数据时,合理使用这些函数可以极大地简化代码,并提高数据处理的效率。

相关推荐
jndingxin4 分钟前
OpenCV特征检测(1)检测图像中的线段的类LineSegmentDe()的使用
人工智能·opencv·计算机视觉
@月落14 分钟前
alibaba获得店铺的所有商品 API接口
java·大数据·数据库·人工智能·学习
z千鑫23 分钟前
【人工智能】如何利用AI轻松将java,c++等代码转换为Python语言?程序员必读
java·c++·人工智能·gpt·agent·ai编程·ai工具
MinIO官方账号41 分钟前
从 HDFS 迁移到 MinIO 企业对象存储
人工智能·分布式·postgresql·架构·开源
aWty_1 小时前
机器学习--K-Means
人工智能·机器学习·kmeans
草莓屁屁我不吃1 小时前
AI大语言模型的全面解读
人工智能·语言模型·自然语言处理·chatgpt
农民小飞侠1 小时前
python AutoGen接入开源模型xLAM-7b-fc-r,测试function calling的功能
开发语言·python
战神刘玉栋1 小时前
《程序猿之设计模式实战 · 观察者模式》
python·观察者模式·设计模式
敲代码不忘补水1 小时前
Python 项目实践:简单的计算器
开发语言·python·json·项目实践
WPG大大通1 小时前
有奖直播 | onsemi IPM 助力汽车电气革命及电子化时代冷热管理
大数据·人工智能·汽车·方案·电气·大大通·研讨会