在 PyTorch 中,经常需要对 tensor 的形状或维度进行转换,常用的函数有 transpose、permute、view 和 reshape。它们的主要区别如下:
1. view
功能说明
- 作用:改变 tensor 的形状(即维度),但不复制数据,只是重新解释底层数据的布局。
- 要求 :输入 tensor 必须是连续的(contiguous),否则会报错。如果非连续,需要先调用
tensor.contiguous()
。 - 特点:返回的新 tensor 与原 tensor 共享内存(修改其中一个,另一个也会改变)。
示例代码及输出
python
import torch
# 创建一个 1D tensor,包含 12 个元素
x = torch.arange(12)
print("原始 x:")
print(x)
# 输出:
# tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# 用 view 改变形状为 (3, 4)
x_view = x.view(3, 4)
print("\n使用 view 后的 x_view:")
print(x_view)
# 输出:
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
# 修改 x_view 的某个元素,原始 x 的对应元素也会改变
x_view[0, 0] = 100
print("\n修改 x_view 后的 x:")
print(x)
# 输出:
# tensor([100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
2. reshape
功能说明
- 作用:改变 tensor 的形状,类似于 view,但更灵活。
- 特点 :
- 如果输入 tensor 是连续的,则与 view 效果相同(返回 view,不复制数据)。
- 如果输入 tensor 非连续,则会复制数据,返回的新 tensor 与原 tensor 不共享内存。
示例代码及输出
(1) 连续 tensor 使用 reshape(与 view 相同)
python
# 创建一个连续的 tensor
x1 = torch.arange(12).view(3, 4)
# 使用 reshape 改变形状为 (2, 6)
x1_reshape = x1.reshape(2, 6)
print("连续 tensor 使用 reshape:")
print(x1_reshape)
# 输出:
# tensor([[ 0, 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10, 11]])
(2) 非连续 tensor 使用 reshape
python
# 创建一个 3x4 tensor
x2 = torch.arange(12).view(3, 4)
# 通过 transpose 交换维度后,x2_t 可能非连续
x2_t = x2.transpose(0, 1)
print("\n通过 transpose 后的 x2_t:")
print(x2_t)
# 例如输出:
# tensor([[ 0, 4, 8],
# [ 1, 5, 9],
# [ 2, 6, 10],
# [ 3, 7, 11]])
print("x2_t 是否连续:", x2_t.is_contiguous())
# 输出可能为:False
# 使用 reshape 调整形状为 (2, 6),此时 reshape 会复制数据
x2_reshape = x2_t.reshape(2, 6)
print("\n非连续 tensor 使用 reshape 后:")
print(x2_reshape)
# 输出:
# tensor([[ 0, 4, 8, 1, 5, 9],
# [ 2, 6, 10, 3, 7, 11]])
3. transpose
功能说明
- 作用:交换 tensor 中两个指定的维度。
- 特点 :
- 返回的 tensor 仍为原 tensor 的 view,共享内存;
- 但常常因为交换后内存不再按连续顺序排列,导致
is_contiguous()
返回 False。
示例代码及输出
python
# 创建一个 3x4 tensor
x3 = torch.arange(12).view(3, 4)
print("\n原始 x3:")
print(x3)
# 输出:
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
# 使用 transpose 交换第 0 维和第 1 维(适用于二维 tensor)
x3_transposed = x3.transpose(0, 1)
print("\n使用 transpose 后的 x3_transposed:")
print(x3_transposed)
# 输出:
# tensor([[ 0, 4, 8],
# [ 1, 5, 9],
# [ 2, 6, 10],
# [ 3, 7, 11]])
# 检查是否连续
print("x3_transposed 是否连续:", x3_transposed.is_contiguous())
# 输出通常为:False
4. permute
功能说明
- 作用:按照指定顺序重新排列所有维度,可以看作是对多个维度同时交换顺序的操作。
- 特点 :
- 返回的新 tensor 也是 view(共享数据),但可能非连续。
- 灵活性更高,当需要对多维 tensor 调整顺序时使用。
示例代码及输出
python
# 创建一个三维 tensor,形状 (2, 3, 4)
y = torch.arange(24).view(2, 3, 4)
print("\n原始 y:")
print(y)
# 输出:
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
#
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
# 使用 permute 调整维度顺序,将原来的 (2, 3, 4) 变为 (4, 2, 3)
y_permuted = y.permute(2, 0, 1)
print("\n使用 permute 后的 y_permuted:")
print(y_permuted)
# 输出:
# tensor([[[ 0, 4, 8],
# [12, 16, 20]],
#
# [[ 1, 5, 9],
# [13, 17, 21]],
#
# [[ 2, 6, 10],
# [14, 18, 22]],
#
# [[ 3, 7, 11],
# [15, 19, 23]]])
总结对比
- view 与 reshape 都可用于改变 tensor 形状:
- view 要求 tensor 连续,返回的是共享内存的 view。
- reshape 更灵活,当 tensor 非连续时会自动复制数据,返回新 tensor,内存不共享。
- transpose 和 permute 用于调整维度顺序:
- transpose 只交换两个维度,适用于二维或简单交换。
- permute 可一次性重新排列所有维度,适用于多维 tensor 的任意维度调整。
选择哪个函数取决于你的需求:
- 如果只是调整形状且确保 tensor 连续,
view
速度快且节省内存。 - 如果不确定 tensor 是否连续或希望避免错误,使用
reshape
更安全。 - 若仅交换两个维度,使用
transpose
;若需要调整多个维度的顺序,使用permute
。