PyTorch如何通过 torch.unbind 和torch.stack动态调整张量的维度顺序

笔者一篇博客PyTorch 的 torch.unbind 函数详解与进阶应用:中英双语中有一个例子如下:

c 复制代码
# 创建一个 3x2x2 的三维张量
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

# 第一步:沿第 0 维分解为 3 个 2x2 张量
unbind_result = torch.unbind(x, dim=0)

# 第二步:沿第 2 维重新堆叠
stack_result = torch.stack(unbind_result, dim=2)
print("最终结果:", stack_result)

结果

c 复制代码
最终结果:
tensor([[[ 1,  5,  9],
         [ 3,  7, 11]],

        [[ 2,  6, 10],
         [ 4,  8, 12]]])
  • 使用 torch.unbind 沿第 0 维分解。
  • 使用 torch.stack 沿第 2 维重新组合,从而完成了维度转换。

张量的形状在每一步的变化如下:

  • 原始张量形状为 [3, 2, 2]。
  • 分解后,得到 3 个形状为 [2, 2] 的张量。
  • 堆叠时,将这些张量沿新的维度 dim=2 组合,最终形状变为 [2, 2, 3]。

通过这种分解和堆叠方式,我们可以灵活地操作张量的维度和数据布局。

具体是怎么变的,这里记录一下。

这个例子展示了如何通过 torch.unbindtorch.stack 动态调整张量的维度顺序。以下是对这个例子的详细解释,包括每一步的操作和张量形状变化:


1. 初始张量

我们先创建一个形状为 [3, 2, 2] 的张量 x

python 复制代码
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

张量的内容

c 复制代码
x = [
      [[1, 2],  [3, 4]],    # 第一个"平面"
      [[5, 6],  [7, 8]],    # 第二个"平面"
      [[9, 10], [11, 12]]   # 第三个"平面"
    ]

形状[3, 2, 2]

这里的含义:

  • 第一维度(dim=0,大小为3):有3个"平面"(或者块)。
  • 第二维度(dim=1,大小为2):每个"平面"有两行。
  • 第三维度(dim=2,大小为2):每行有两个元素。

2. 使用 torch.unbind 沿 dim=0 分解

python 复制代码
unbind_result = torch.unbind(x, dim=0)

torch.unbind 的作用是沿着指定的维度(这里是 dim=0)移除这一维度,并返回一个元组,元组中的每个元素都是输入张量在该维度上的切片。

对于我们的例子:

  • x 沿着 dim=0 分解,相当于把张量按"平面"切开。
  • 原始的 3×2×2 张量被分成了 3 个形状为 [2, 2] 的子张量。

unbind_result 的内容

c 复制代码
unbind_result = (
    tensor([[1, 2],  [3, 4]]),  # 第一个平面
    tensor([[5, 6],  [7, 8]]),  # 第二个平面
    tensor([[9, 10], [11, 12]]) # 第三个平面
)

每个切片都是一个形状为 [2, 2] 的二维张量。

这里的维度变化:

  • 原始张量形状 [3, 2, 2] → 切片形状 [2, 2]

3. 使用 torch.stack 沿 dim=2 重新组合

python 复制代码
stack_result = torch.stack(unbind_result, dim=2)

torch.stack 的作用是把一组张量沿着新的维度拼接起来。这里:

  • unbind_result 是一个包含 3 个 [2, 2] 张量的元组。
  • 我们指定 dim=2,意思是在原始张量的最后一维(第三维)增加一个新的维度来进行拼接。
拼接过程
  1. 第一个子张量的每个位置与第二个、第三个子张量的对应位置对齐,按列方向拼接。
  2. 拼接后,原来 [2, 2] 的子张量变成了 [2, 3] 的子张量。

举例说明:

  • 原始三个 [2, 2] 的张量:

    c 复制代码
    tensor([[1, 2], [3, 4]])
    tensor([[5, 6], [7, 8]])
    tensor([[9, 10], [11, 12]])
  • 沿 dim=2 进行拼接后:

    c 复制代码
    [
      [[1, 5, 9], [3, 7, 11]],  # 第一行拼接
      [[2, 6, 10], [4, 8, 12]]  # 第二行拼接
    ]

最终结果

c 复制代码
stack_result = tensor([
    [[ 1,  5,  9], [ 3,  7, 11]],
    [[ 2,  6, 10], [ 4,  8, 12]]
])

形状变化

  • 原始张量 [3, 2, 2] → 分解后的切片 [2, 2] → 拼接后的结果 [2, 2, 3]

4. 形状变化总结

操作 张量内容 张量形状
初始张量 x [3, 2, 2]
使用 torch.unbind(dim=0) 3 个 [2, 2] 的子张量 [2, 2]
使用 torch.stack(dim=2) 拼接为一个新的张量 [2, 2, 3]

5. 为什么维度顺序调整了?

通过 torch.unbindtorch.stack 的组合,实际上我们重新定义了张量的组织方式:

  1. torch.unbinddim=0 的维度移除,分解成多个子张量。
  2. torch.stack 指定新的维度(这里是 dim=2),将这些子张量拼接为一个新维度,从而实现了维度的重新排列。

最终,我们将原来的"平面"维度(dim=0)转移到了列方向(dim=2),实现了动态调整维度顺序的效果。


6. 总结

  • torch.unbind 用于移除一个维度并分解张量
  • torch.stack 用于沿指定的新维度拼接张量
  • 两者结合可以灵活调整张量的维度顺序。

这个例子展示了如何从 [3, 2, 2] 变换到 [2, 2, 3],过程中分解和拼接操作相辅相成,适用于需要动态调整张量维度的高级场景。

后记

2024年12月12日22点28分于上海,基于GPT4o大模型生成。

相关推荐
TGC达成共识2 分钟前
解锁处暑健康生活
人工智能·科技·其他·安全·生活·美食·风景
猫头虎20 分钟前
什么是AI+?什么是人工智能+?
人工智能·ai·prompt·aigc·数据集·ai编程·mcp
站大爷IP20 分钟前
Python多线程与多进程性能对比:从原理到实战的深度解析
python
聚客AI21 分钟前
💡为什么你的RAG回答总是胡言乱语?致命瓶颈在数据预处理层
人工智能·langchain·llm
东方佑29 分钟前
Python音频分析与线性回归:探索声音中的数学之美
python·音视频·线性回归
彭军辉32 分钟前
什么是AI宠物
人工智能
siliconstorm.ai43 分钟前
穿越周期:AIoT产业的真实突破口与实践路径
大数据·人工智能
爱喝奶茶的企鹅1 小时前
Ethan独立开发新品速递 | 2025-08-27
人工智能
武子康1 小时前
AI-调查研究-59-机器人 行业职业地图:发展路径、技能要求与薪资全解读
人工智能·gpt·程序人生·ai·职场和发展·机器人·个人开发
大视码垛机1 小时前
大视码垛机器人:以技术优势撬动工业码垛升级
人工智能·机器人·自动化·制造