Pytorch对tensor进行变换的函数

在 PyTorch 中,经常需要对 tensor 的形状或维度进行转换,常用的函数有 transpose、permute、view 和 reshape。它们的主要区别如下:


1. view

功能说明

  • 作用:改变 tensor 的形状(即维度),但不复制数据,只是重新解释底层数据的布局。
  • 要求 :输入 tensor 必须是连续的(contiguous),否则会报错。如果非连续,需要先调用 tensor.contiguous()
  • 特点:返回的新 tensor 与原 tensor 共享内存(修改其中一个,另一个也会改变)。

示例代码及输出

python 复制代码
import torch

# 创建一个 1D tensor,包含 12 个元素
x = torch.arange(12)
print("原始 x:")
print(x)
# 输出:
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# 用 view 改变形状为 (3, 4)
x_view = x.view(3, 4)
print("\n使用 view 后的 x_view:")
print(x_view)
# 输出:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 修改 x_view 的某个元素,原始 x 的对应元素也会改变
x_view[0, 0] = 100
print("\n修改 x_view 后的 x:")
print(x)
# 输出:
# tensor([100,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11])

2. reshape

功能说明

  • 作用:改变 tensor 的形状,类似于 view,但更灵活。
  • 特点
    • 如果输入 tensor 是连续的,则与 view 效果相同(返回 view,不复制数据)。
    • 如果输入 tensor 非连续,则会复制数据,返回的新 tensor 与原 tensor 不共享内存。

示例代码及输出

(1) 连续 tensor 使用 reshape(与 view 相同)
python 复制代码
# 创建一个连续的 tensor
x1 = torch.arange(12).view(3, 4)
# 使用 reshape 改变形状为 (2, 6)
x1_reshape = x1.reshape(2, 6)
print("连续 tensor 使用 reshape:")
print(x1_reshape)
# 输出:
# tensor([[ 0,  1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10, 11]])
(2) 非连续 tensor 使用 reshape
python 复制代码
# 创建一个 3x4 tensor
x2 = torch.arange(12).view(3, 4)
# 通过 transpose 交换维度后,x2_t 可能非连续
x2_t = x2.transpose(0, 1)
print("\n通过 transpose 后的 x2_t:")
print(x2_t)
# 例如输出:
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])
print("x2_t 是否连续:", x2_t.is_contiguous())
# 输出可能为:False

# 使用 reshape 调整形状为 (2, 6),此时 reshape 会复制数据
x2_reshape = x2_t.reshape(2, 6)
print("\n非连续 tensor 使用 reshape 后:")
print(x2_reshape)
# 输出:
# tensor([[ 0,  4,  8,  1,  5,  9],
#         [ 2,  6, 10,  3,  7, 11]])

3. transpose

功能说明

  • 作用:交换 tensor 中两个指定的维度。
  • 特点
    • 返回的 tensor 仍为原 tensor 的 view,共享内存;
    • 但常常因为交换后内存不再按连续顺序排列,导致 is_contiguous() 返回 False。

示例代码及输出

python 复制代码
# 创建一个 3x4 tensor
x3 = torch.arange(12).view(3, 4)
print("\n原始 x3:")
print(x3)
# 输出:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 使用 transpose 交换第 0 维和第 1 维(适用于二维 tensor)
x3_transposed = x3.transpose(0, 1)
print("\n使用 transpose 后的 x3_transposed:")
print(x3_transposed)
# 输出:
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])

# 检查是否连续
print("x3_transposed 是否连续:", x3_transposed.is_contiguous())
# 输出通常为:False

4. permute

功能说明

  • 作用:按照指定顺序重新排列所有维度,可以看作是对多个维度同时交换顺序的操作。
  • 特点
    • 返回的新 tensor 也是 view(共享数据),但可能非连续。
    • 灵活性更高,当需要对多维 tensor 调整顺序时使用。

示例代码及输出

python 复制代码
# 创建一个三维 tensor,形状 (2, 3, 4)
y = torch.arange(24).view(2, 3, 4)
print("\n原始 y:")
print(y)
# 输出:
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#
#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])

# 使用 permute 调整维度顺序,将原来的 (2, 3, 4) 变为 (4, 2, 3)
y_permuted = y.permute(2, 0, 1)
print("\n使用 permute 后的 y_permuted:")
print(y_permuted)
# 输出:
# tensor([[[ 0,  4,  8],
#          [12, 16, 20]],
#
#         [[ 1,  5,  9],
#          [13, 17, 21]],
#
#         [[ 2,  6, 10],
#          [14, 18, 22]],
#
#         [[ 3,  7, 11],
#          [15, 19, 23]]])

总结对比

  • viewreshape 都可用于改变 tensor 形状:
    • view 要求 tensor 连续,返回的是共享内存的 view。
    • reshape 更灵活,当 tensor 非连续时会自动复制数据,返回新 tensor,内存不共享。
  • transposepermute 用于调整维度顺序:
    • transpose 只交换两个维度,适用于二维或简单交换。
    • permute 可一次性重新排列所有维度,适用于多维 tensor 的任意维度调整。

选择哪个函数取决于你的需求:

  • 如果只是调整形状且确保 tensor 连续,view 速度快且节省内存。
  • 如果不确定 tensor 是否连续或希望避免错误,使用 reshape 更安全。
  • 若仅交换两个维度,使用 transpose;若需要调整多个维度的顺序,使用 permute
相关推荐
月疯2 分钟前
OPENCV摄像头读取视频
人工智能·opencv·音视频
极客天成ScaleFlash6 分钟前
极客天成让统一存储从云原生‘进化’到 AI 原生: 不是版本升级,而是基因重组
人工智能·云原生
王哥儿聊AI11 分钟前
Lynx:新一代个性化视频生成模型,单图即可生成视频,重新定义身份一致性与视觉质量
人工智能·算法·安全·机器学习·音视频·软件工程
_pinnacle_38 分钟前
打开神经网络的黑箱(三) 卷积神经网络(CNN)的模型逻辑
人工智能·神经网络·cnn·黑箱·卷积网络
Ada's41 分钟前
深度学习在自动驾驶上应用(二)
人工智能·深度学习·自动驾驶
张较瘦_1 小时前
[论文阅读] 人工智能 + 软件工程 | 从“人工扒日志”到“AI自动诊断”:LogCoT框架的3大核心创新
论文阅读·人工智能·软件工程
lisw051 小时前
连接蓝牙时“无媒体信号”怎么办?
人工智能·机器学习·微服务
傻啦嘿哟1 小时前
Python SQLite模块:轻量级数据库的实战指南
数据库·python·sqlite
扫地的小何尚1 小时前
深度解析 CUDA-QX 0.4 加速 QEC 与求解器库
人工智能·语言模型·llm·gpu·量子计算·nvidia·cuda
张较瘦_2 小时前
[论文阅读] 人工智能 + 软件工程 | 35篇文献拆解!LLM如何重塑软件配置的生成、验证与运维
论文阅读·人工智能·软件工程