Pytorch对tensor进行变换的函数

在 PyTorch 中,经常需要对 tensor 的形状或维度进行转换,常用的函数有 transpose、permute、view 和 reshape。它们的主要区别如下:


1. view

功能说明

  • 作用:改变 tensor 的形状(即维度),但不复制数据,只是重新解释底层数据的布局。
  • 要求 :输入 tensor 必须是连续的(contiguous),否则会报错。如果非连续,需要先调用 tensor.contiguous()
  • 特点:返回的新 tensor 与原 tensor 共享内存(修改其中一个,另一个也会改变)。

示例代码及输出

python 复制代码
import torch

# 创建一个 1D tensor,包含 12 个元素
x = torch.arange(12)
print("原始 x:")
print(x)
# 输出:
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# 用 view 改变形状为 (3, 4)
x_view = x.view(3, 4)
print("\n使用 view 后的 x_view:")
print(x_view)
# 输出:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 修改 x_view 的某个元素,原始 x 的对应元素也会改变
x_view[0, 0] = 100
print("\n修改 x_view 后的 x:")
print(x)
# 输出:
# tensor([100,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11])

2. reshape

功能说明

  • 作用:改变 tensor 的形状,类似于 view,但更灵活。
  • 特点
    • 如果输入 tensor 是连续的,则与 view 效果相同(返回 view,不复制数据)。
    • 如果输入 tensor 非连续,则会复制数据,返回的新 tensor 与原 tensor 不共享内存。

示例代码及输出

(1) 连续 tensor 使用 reshape(与 view 相同)
python 复制代码
# 创建一个连续的 tensor
x1 = torch.arange(12).view(3, 4)
# 使用 reshape 改变形状为 (2, 6)
x1_reshape = x1.reshape(2, 6)
print("连续 tensor 使用 reshape:")
print(x1_reshape)
# 输出:
# tensor([[ 0,  1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10, 11]])
(2) 非连续 tensor 使用 reshape
python 复制代码
# 创建一个 3x4 tensor
x2 = torch.arange(12).view(3, 4)
# 通过 transpose 交换维度后,x2_t 可能非连续
x2_t = x2.transpose(0, 1)
print("\n通过 transpose 后的 x2_t:")
print(x2_t)
# 例如输出:
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])
print("x2_t 是否连续:", x2_t.is_contiguous())
# 输出可能为:False

# 使用 reshape 调整形状为 (2, 6),此时 reshape 会复制数据
x2_reshape = x2_t.reshape(2, 6)
print("\n非连续 tensor 使用 reshape 后:")
print(x2_reshape)
# 输出:
# tensor([[ 0,  4,  8,  1,  5,  9],
#         [ 2,  6, 10,  3,  7, 11]])

3. transpose

功能说明

  • 作用:交换 tensor 中两个指定的维度。
  • 特点
    • 返回的 tensor 仍为原 tensor 的 view,共享内存;
    • 但常常因为交换后内存不再按连续顺序排列,导致 is_contiguous() 返回 False。

示例代码及输出

python 复制代码
# 创建一个 3x4 tensor
x3 = torch.arange(12).view(3, 4)
print("\n原始 x3:")
print(x3)
# 输出:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 使用 transpose 交换第 0 维和第 1 维(适用于二维 tensor)
x3_transposed = x3.transpose(0, 1)
print("\n使用 transpose 后的 x3_transposed:")
print(x3_transposed)
# 输出:
# tensor([[ 0,  4,  8],
#         [ 1,  5,  9],
#         [ 2,  6, 10],
#         [ 3,  7, 11]])

# 检查是否连续
print("x3_transposed 是否连续:", x3_transposed.is_contiguous())
# 输出通常为:False

4. permute

功能说明

  • 作用:按照指定顺序重新排列所有维度,可以看作是对多个维度同时交换顺序的操作。
  • 特点
    • 返回的新 tensor 也是 view(共享数据),但可能非连续。
    • 灵活性更高,当需要对多维 tensor 调整顺序时使用。

示例代码及输出

python 复制代码
# 创建一个三维 tensor,形状 (2, 3, 4)
y = torch.arange(24).view(2, 3, 4)
print("\n原始 y:")
print(y)
# 输出:
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#
#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])

# 使用 permute 调整维度顺序,将原来的 (2, 3, 4) 变为 (4, 2, 3)
y_permuted = y.permute(2, 0, 1)
print("\n使用 permute 后的 y_permuted:")
print(y_permuted)
# 输出:
# tensor([[[ 0,  4,  8],
#          [12, 16, 20]],
#
#         [[ 1,  5,  9],
#          [13, 17, 21]],
#
#         [[ 2,  6, 10],
#          [14, 18, 22]],
#
#         [[ 3,  7, 11],
#          [15, 19, 23]]])

总结对比

  • viewreshape 都可用于改变 tensor 形状:
    • view 要求 tensor 连续,返回的是共享内存的 view。
    • reshape 更灵活,当 tensor 非连续时会自动复制数据,返回新 tensor,内存不共享。
  • transposepermute 用于调整维度顺序:
    • transpose 只交换两个维度,适用于二维或简单交换。
    • permute 可一次性重新排列所有维度,适用于多维 tensor 的任意维度调整。

选择哪个函数取决于你的需求:

  • 如果只是调整形状且确保 tensor 连续,view 速度快且节省内存。
  • 如果不确定 tensor 是否连续或希望避免错误,使用 reshape 更安全。
  • 若仅交换两个维度,使用 transpose;若需要调整多个维度的顺序,使用 permute
相关推荐
hie9889425 分钟前
MATLAB锂离子电池伪二维(P2D)模型实现
人工智能·算法·matlab
晨同学032728 分钟前
opencv的颜色通道问题 & rgb & bgr
人工智能·opencv·计算机视觉
路来了28 分钟前
Python小工具之PDF合并
开发语言·windows·python
蓝婷儿38 分钟前
Python 机器学习核心入门与实战进阶 Day 3 - 决策树 & 随机森林模型实战
人工智能·python·机器学习
大千AI助手40 分钟前
PageRank:互联网的马尔可夫链平衡态
人工智能·机器学习·贝叶斯·mc·pagerank·条件概率·马尔科夫链
AntBlack1 小时前
拖了五个月 ,不当韭菜体验版算是正式发布了
前端·后端·python
小和尚同志1 小时前
Cline | Cline + Grok3 免费 AI 编程新体验
人工智能·aigc
我就是全世界1 小时前
TensorRT-LLM:大模型推理加速的核心技术与实践优势
人工智能·机器学习·性能优化·大模型·tensorrt-llm
.30-06Springfield1 小时前
决策树(Decision tree)算法详解(ID3、C4.5、CART)
人工智能·python·算法·决策树·机器学习
我不是哆啦A梦1 小时前
破解风电运维“百模大战”困局,机械版ChatGPT诞生?
运维·人工智能·python·算法·chatgpt