Pytorch对tensor进行变换的函数

在 PyTorch 中,经常需要对 tensor 的形状或维度进行转换,常用的函数有 transpose、permute、view 和 reshape。它们的主要区别如下:


1. view

功能说明

  • 作用:改变 tensor 的形状(即维度),但不复制数据,只是重新解释底层数据的布局。
  • 要求 :输入 tensor 必须是连续的(contiguous),否则会报错。如果非连续,需要先调用 tensor.contiguous()
  • 特点:返回的新 tensor 与原 tensor 共享内存(修改其中一个,另一个也会改变)。

示例代码及输出

python 复制代码
import torch

# 创建一个 1D tensor,包含 12 个元素
x = torch.arange(12)
print("原始 x:")
print(x)
# 输出:
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# 用 view 改变形状为 (3, 4)
x_view = x.view(3, 4)
print("\n使用 view 后的 x_view:")
print(x_view)
# 输出:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 修改 x_view 的某个元素,原始 x 的对应元素也会改变
x_view[0, 0] = 100
print("\n修改 x_view 后的 x:")
print(x)
# 输出:
# tensor([100,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11])

2. reshape

功能说明

  • 作用:改变 tensor 的形状,类似于 view,但更灵活。
  • 特点
    • 如果输入 tensor 是连续的,则与 view 效果相同(返回 view,不复制数据)。
    • 如果输入 tensor 非连续,则会复制数据,返回的新 tensor 与原 tensor 不共享内存。

示例代码及输出

(1) 连续 tensor 使用 reshape(与 view 相同)
python 复制代码
# 创建一个连续的 tensor
x1 = torch.arange(12).view(3, 4)
# 使用 reshape 改变形状为 (2, 6)
x1_reshape = x1.reshape(2, 6)
print("连续 tensor 使用 reshape:")
print(x1_reshape)
# 输出:
# tensor([[ 0,  1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10, 11]])
(2) 非连续 tensor 使用 reshape
python 复制代码
# 创建一个 3x4 tensor
x2 = torch.arange(12).view(3, 4)
# 通过 transpose 交换维度后,x2_t 可能非连续
x2_t = x2.transpose(0, 1)
print("\n通过 transpose 后的 x2_t:")
print(x2_t)
# 例如输出:
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])
print("x2_t 是否连续:", x2_t.is_contiguous())
# 输出可能为:False

# 使用 reshape 调整形状为 (2, 6),此时 reshape 会复制数据
x2_reshape = x2_t.reshape(2, 6)
print("\n非连续 tensor 使用 reshape 后:")
print(x2_reshape)
# 输出:
# tensor([[ 0,  4,  8,  1,  5,  9],
#         [ 2,  6, 10,  3,  7, 11]])

3. transpose

功能说明

  • 作用:交换 tensor 中两个指定的维度。
  • 特点
    • 返回的 tensor 仍为原 tensor 的 view,共享内存;
    • 但常常因为交换后内存不再按连续顺序排列,导致 is_contiguous() 返回 False。

示例代码及输出

python 复制代码
# 创建一个 3x4 tensor
x3 = torch.arange(12).view(3, 4)
print("\n原始 x3:")
print(x3)
# 输出:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 使用 transpose 交换第 0 维和第 1 维(适用于二维 tensor)
x3_transposed = x3.transpose(0, 1)
print("\n使用 transpose 后的 x3_transposed:")
print(x3_transposed)
# 输出:
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])

# 检查是否连续
print("x3_transposed 是否连续:", x3_transposed.is_contiguous())
# 输出通常为:False

4. permute

功能说明

  • 作用:按照指定顺序重新排列所有维度,可以看作是对多个维度同时交换顺序的操作。
  • 特点
    • 返回的新 tensor 也是 view(共享数据),但可能非连续。
    • 灵活性更高,当需要对多维 tensor 调整顺序时使用。

示例代码及输出

python 复制代码
# 创建一个三维 tensor,形状 (2, 3, 4)
y = torch.arange(24).view(2, 3, 4)
print("\n原始 y:")
print(y)
# 输出:
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#
#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])

# 使用 permute 调整维度顺序,将原来的 (2, 3, 4) 变为 (4, 2, 3)
y_permuted = y.permute(2, 0, 1)
print("\n使用 permute 后的 y_permuted:")
print(y_permuted)
# 输出:
# tensor([[[ 0,  4,  8],
#          [12, 16, 20]],
#
#         [[ 1,  5,  9],
#          [13, 17, 21]],
#
#         [[ 2,  6, 10],
#          [14, 18, 22]],
#
#         [[ 3,  7, 11],
#          [15, 19, 23]]])

总结对比

  • viewreshape 都可用于改变 tensor 形状:
    • view 要求 tensor 连续,返回的是共享内存的 view。
    • reshape 更灵活,当 tensor 非连续时会自动复制数据,返回新 tensor,内存不共享。
  • transposepermute 用于调整维度顺序:
    • transpose 只交换两个维度,适用于二维或简单交换。
    • permute 可一次性重新排列所有维度,适用于多维 tensor 的任意维度调整。

选择哪个函数取决于你的需求:

  • 如果只是调整形状且确保 tensor 连续,view 速度快且节省内存。
  • 如果不确定 tensor 是否连续或希望避免错误,使用 reshape 更安全。
  • 若仅交换两个维度,使用 transpose;若需要调整多个维度的顺序,使用 permute
相关推荐
晚霞的不甘2 小时前
CANN 支持多模态大模型:Qwen-VL 与 LLaVA 的端侧部署实战
人工智能·神经网络·架构·开源·音视频
华玥作者8 小时前
[特殊字符] VitePress 对接 Algolia AI 问答(DocSearch + AI Search)完整实战(下)
前端·人工智能·ai
AAD555888998 小时前
YOLO11-EfficientRepBiPAN载重汽车轮胎热成像检测与分类_3
人工智能·分类·数据挖掘
王建文go8 小时前
RAG(宠物健康AI)
人工智能·宠物·rag
ALINX技术博客8 小时前
【202601芯动态】全球 FPGA 异构热潮,ALINX 高性能异构新品预告
人工智能·fpga开发·gpu算力·fpga
易营宝8 小时前
多语言网站建设避坑指南:既要“数据同步”,又能“按市场个性化”,别踩这 5 个坑
大数据·人工智能
春日见9 小时前
vscode代码无法跳转
大数据·人工智能·深度学习·elasticsearch·搜索引擎
Drgfd9 小时前
真智能 vs 伪智能:天选 WE H7 Lite 用 AI 人脸识别 + 呼吸灯带,重新定义智能化充电桩
人工智能·智能充电桩·家用充电桩·充电桩推荐
DeniuHe9 小时前
torch.distribution函数详解
pytorch
好家伙VCC10 小时前
### WebRTC技术:实时通信的革新与实现####webRTC(Web Real-TimeComm
java·前端·python·webrtc