Pytorch基础:Tensor的permute方法

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html


在Pytorch中,permute是Tensor的一个重要方法,同时它也是一个torch模块中的一个函数,它们的语法如下所示。

复制代码
Tensor.permute(*dims) → Tensor
torch.permute(input, dims) → Tensor

input (Tensor) -- the input tensor
dims (tuple of int) -- The desired ordering of dimensions

官方的解释是:返回原始张量输入的视图,并对其维度进行转置。这里返回视图指的是一个新的tensor对象,但新旧tensor对象内的数据共享存储(即是同一个对象),返回的新对象可能会变得不连续,这样就无法对新对象使用view方法。

可以看几个例子以更好的理解:

复制代码
import torch
 
# 创建一个张量
x = torch.rand(3, 3, 3)

 
# 使用permute操作,倒置三个维度
y = x.permute(2, 1, 0)
 
print(x, y)
 
tensor([[[0.9701, 0.7507, 0.8002],
         [0.5876, 0.1460, 0.0386],
         [0.5126, 0.1538, 0.5863]],

        [[0.8500, 0.8774, 0.2415],
         [0.1053, 0.5650, 0.7321],
         [0.8260, 0.1564, 0.7447]],

        [[0.5131, 0.7111, 0.3469],
         [0.6031, 0.8140, 0.9770],
         [0.7578, 0.0223, 0.5515]]])

tensor([[[0.9701, 0.8500, 0.5131],
         [0.5876, 0.1053, 0.6031],
         [0.5126, 0.8260, 0.7578]],

        [[0.7507, 0.8774, 0.7111],
         [0.1460, 0.5650, 0.8140],
         [0.1538, 0.1564, 0.0223]],

        [[0.8002, 0.2415, 0.3469],
         [0.0386, 0.7321, 0.9770],
         [0.5863, 0.7447, 0.5515]]])
 
print(id(x),id(y))
4554479952 4811331200 # 说明两个张量对象不同
 
print(x.storage().data_ptr(), y.storage().data_ptr())
4830094080 4830094080 # 说明两个张量对象里面保存的数据存储是共享的
 
print(id(x[0,0]), id(y[0,0])) 
4570943952 4570943952 # 进一步说明两个张量对象里面保存的数据存储是共享的
 
y[0, 0] = 7
print(x, y)
tensor([[[7.0000, 0.7507, 0.8002],
         [0.5876, 0.1460, 0.0386],
         [0.5126, 0.1538, 0.5863]],

        [[7.0000, 0.8774, 0.2415],
         [0.1053, 0.5650, 0.7321],
         [0.8260, 0.1564, 0.7447]],

        [[7.0000, 0.7111, 0.3469],
         [0.6031, 0.8140, 0.9770],
         [0.7578, 0.0223, 0.5515]]]) 

tensor([[[7.0000, 7.0000, 7.0000],
         [0.5876, 0.1053, 0.6031],
         [0.5126, 0.8260, 0.7578]],

        [[0.7507, 0.8774, 0.7111],
         [0.1460, 0.5650, 0.8140],
         [0.1538, 0.1564, 0.0223]],

        [[0.8002, 0.2415, 0.3469],
         [0.0386, 0.7321, 0.9770],
         [0.5863, 0.7447, 0.5515]]])# 说明对新tensor的更改影响了原tensor
 
print(x.is_contiguous(), y.is_contiguous())  
True False # 说明x是连续的,y不是连续的

以上的内容,类似于之前在关于python中列表的浅拷贝中说到的那样,对新列表内部嵌套的列表中的元素的更改会影响原列表。如下所示。 列表的浅拷贝

​​​

复制代码
import copy
my_list = [1, 2, [1, 2]]
 
your_list = list(my_list)  #工厂函数
his_list = my_list[:]      #切片操作
her_list = copy.copy(my_list)    #copy模块的copy函数
 
your_list[2][0] = 3
print(my_list)
print(your_list)
print(his_list)
print(her_list)
 
his_list[2][1] = 4
print(my_list)
print(your_list)
print(his_list)
print(her_list)
 
her_list[2].append(5)
print(my_list)
print(your_list)
print(his_list)
print(her_list)
 
 
 
输出
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 2]]
[1, 2, [3, 2]]
 
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
[1, 2, [3, 4]]
 
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]
[1, 2, [3, 4, 5]]

但不一样的是,在这里甚至对tensor中非嵌套的内容的修改也会导致另一个tensor受到影响,如下所示。

复制代码
import torch

# 创建一个张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 使用permute操作
y = x.permute(0, 1)

print(x, y)

tensor([[1, 2, 3],
        [4, 5, 6]])
tensor([[1, 4],
        [2, 5],
        [3, 6]])

x[0] = torch.tensor[4, 4, 4] # 改变其中一个tensor的第0个元素

print(x, y)

tensor([[4, 4, 4],
        [4, 5, 6]])
tensor([[4, 4],
        [4, 5],
        [4, 6]])
相关推荐
吾日三省吾码6 分钟前
GitHub Copilot (Gen-AI) 很有用,但不是很好
人工智能·github·copilot
一颗橘子宣布成为星球21 分钟前
Unity AI-使用Ollama本地大语言模型运行框架运行本地Deepseek等模型实现聊天对话(一)
人工智能·unity·语言模型·游戏引擎
南 阳39 分钟前
从微服务到AI服务:Nacos 3.0如何重构下一代动态治理体系?
人工智能·微服务·云原生·重构
fmingzh1 小时前
NVIDIA高级辅助驾驶安全与技术读后感
人工智能·安全·自动驾驶
PXM的算法星球1 小时前
【软件工程】面向对象编程(OOP)概念详解
java·python·软件工程
qsmyhsgcs2 小时前
Java程序员转人工智能入门学习路线图(2025版)
java·人工智能·学习·机器学习·算法工程师·人工智能入门·ai算法工程师
A林玖2 小时前
【机器学习】朴素贝叶斯
人工智能·算法·机器学习
六边形战士DONK2 小时前
神经网络基础[损失函数,bp算法,梯度下降算法 ]
人工智能·神经网络·算法
IT从业者张某某2 小时前
机器学习-08-时序数据分析预测
人工智能·机器学习·数据分析
袁煦丞2 小时前
AI视频生成神器Wan 2.1:cpolar内网穿透实验室第596个成功挑战
人工智能·程序员·远程工作