PyTorch如何通过 torch.unbind 和torch.stack动态调整张量的维度顺序

笔者一篇博客PyTorch 的 torch.unbind 函数详解与进阶应用:中英双语中有一个例子如下:

c 复制代码
# 创建一个 3x2x2 的三维张量
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

# 第一步:沿第 0 维分解为 3 个 2x2 张量
unbind_result = torch.unbind(x, dim=0)

# 第二步:沿第 2 维重新堆叠
stack_result = torch.stack(unbind_result, dim=2)
print("最终结果:", stack_result)

结果

c 复制代码
最终结果:
tensor([[[ 1,  5,  9],
         [ 3,  7, 11]],

        [[ 2,  6, 10],
         [ 4,  8, 12]]])
  • 使用 torch.unbind 沿第 0 维分解。
  • 使用 torch.stack 沿第 2 维重新组合,从而完成了维度转换。

张量的形状在每一步的变化如下:

  • 原始张量形状为 [3, 2, 2]。
  • 分解后,得到 3 个形状为 [2, 2] 的张量。
  • 堆叠时,将这些张量沿新的维度 dim=2 组合,最终形状变为 [2, 2, 3]。

通过这种分解和堆叠方式,我们可以灵活地操作张量的维度和数据布局。

具体是怎么变的,这里记录一下。

这个例子展示了如何通过 torch.unbindtorch.stack 动态调整张量的维度顺序。以下是对这个例子的详细解释,包括每一步的操作和张量形状变化:


1. 初始张量

我们先创建一个形状为 [3, 2, 2] 的张量 x

python 复制代码
x = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]],
                  [[9, 10], [11, 12]]])

张量的内容

c 复制代码
x = [
      [[1, 2],  [3, 4]],    # 第一个"平面"
      [[5, 6],  [7, 8]],    # 第二个"平面"
      [[9, 10], [11, 12]]   # 第三个"平面"
    ]

形状[3, 2, 2]

这里的含义:

  • 第一维度(dim=0,大小为3):有3个"平面"(或者块)。
  • 第二维度(dim=1,大小为2):每个"平面"有两行。
  • 第三维度(dim=2,大小为2):每行有两个元素。

2. 使用 torch.unbind 沿 dim=0 分解

python 复制代码
unbind_result = torch.unbind(x, dim=0)

torch.unbind 的作用是沿着指定的维度(这里是 dim=0)移除这一维度,并返回一个元组,元组中的每个元素都是输入张量在该维度上的切片。

对于我们的例子:

  • x 沿着 dim=0 分解,相当于把张量按"平面"切开。
  • 原始的 3×2×2 张量被分成了 3 个形状为 [2, 2] 的子张量。

unbind_result 的内容

c 复制代码
unbind_result = (
    tensor([[1, 2],  [3, 4]]),  # 第一个平面
    tensor([[5, 6],  [7, 8]]),  # 第二个平面
    tensor([[9, 10], [11, 12]]) # 第三个平面
)

每个切片都是一个形状为 [2, 2] 的二维张量。

这里的维度变化:

  • 原始张量形状 [3, 2, 2] → 切片形状 [2, 2]

3. 使用 torch.stack 沿 dim=2 重新组合

python 复制代码
stack_result = torch.stack(unbind_result, dim=2)

torch.stack 的作用是把一组张量沿着新的维度拼接起来。这里:

  • unbind_result 是一个包含 3 个 [2, 2] 张量的元组。
  • 我们指定 dim=2,意思是在原始张量的最后一维(第三维)增加一个新的维度来进行拼接。
拼接过程
  1. 第一个子张量的每个位置与第二个、第三个子张量的对应位置对齐,按列方向拼接。
  2. 拼接后,原来 [2, 2] 的子张量变成了 [2, 3] 的子张量。

举例说明:

  • 原始三个 [2, 2] 的张量:

    c 复制代码
    tensor([[1, 2], [3, 4]])
    tensor([[5, 6], [7, 8]])
    tensor([[9, 10], [11, 12]])
  • 沿 dim=2 进行拼接后:

    c 复制代码
    [
      [[1, 5, 9], [3, 7, 11]],  # 第一行拼接
      [[2, 6, 10], [4, 8, 12]]  # 第二行拼接
    ]

最终结果

c 复制代码
stack_result = tensor([
    [[ 1,  5,  9], [ 3,  7, 11]],
    [[ 2,  6, 10], [ 4,  8, 12]]
])

形状变化

  • 原始张量 [3, 2, 2] → 分解后的切片 [2, 2] → 拼接后的结果 [2, 2, 3]

4. 形状变化总结

操作 张量内容 张量形状
初始张量 x [3, 2, 2]
使用 torch.unbind(dim=0) 3 个 [2, 2] 的子张量 [2, 2]
使用 torch.stack(dim=2) 拼接为一个新的张量 [2, 2, 3]

5. 为什么维度顺序调整了?

通过 torch.unbindtorch.stack 的组合,实际上我们重新定义了张量的组织方式:

  1. torch.unbinddim=0 的维度移除,分解成多个子张量。
  2. torch.stack 指定新的维度(这里是 dim=2),将这些子张量拼接为一个新维度,从而实现了维度的重新排列。

最终,我们将原来的"平面"维度(dim=0)转移到了列方向(dim=2),实现了动态调整维度顺序的效果。


6. 总结

  • torch.unbind 用于移除一个维度并分解张量
  • torch.stack 用于沿指定的新维度拼接张量
  • 两者结合可以灵活调整张量的维度顺序。

这个例子展示了如何从 [3, 2, 2] 变换到 [2, 2, 3],过程中分解和拼接操作相辅相成,适用于需要动态调整张量维度的高级场景。

后记

2024年12月12日22点28分于上海,基于GPT4o大模型生成。

相关推荐
AI_56784 小时前
AI无人机如何让安全隐患无处遁形
人工智能·无人机
机器之心4 小时前
DeepSeek强势回归,开源IMO金牌级数学模型
人工智能·openai
机器之心4 小时前
华为放出「准万亿级MoE推理」大招,两大杀手级优化技术直接开源
人工智能·openai
大力财经4 小时前
零跑Lafa5正式上市 以“五大硬核实力”开启品牌个性化新篇章
人工智能
ECT-OS-JiuHuaShan4 小时前
否定之否定的辩证法,谁会不承认?但又有多少人说的透?
开发语言·人工智能·数学建模·生活·学习方法·量子计算·拓扑学
软件开发技术深度爱好者4 小时前
基于多个大模型自己建造一个AI智能助手(增强版)
人工智能
c***87194 小时前
Flask:后端框架使用
后端·python·flask
骥龙4 小时前
4.12、隐私保护机器学习:联邦学习在安全数据协作中的应用
人工智能·安全·网络安全
天硕国产存储技术站4 小时前
DualPLP 双重掉电保护赋能 天硕工业级SSD筑牢关键领域安全存储方案
大数据·人工智能·安全·固态硬盘
腾讯云开发者4 小时前
AI独孤九剑:AI没有场景,无法落地?不存在的。
人工智能