【无标题】PyTorch 常用算子说明

1.增加维度

print(a.unsqueeze(0).shape) # 在0号维度位置插入一个维度

print(a.unsqueeze(-1).shape) # 在最后插入一个维度

print(a.unsqueeze(3).shape) # 在3号维度位置插入一个维度

2.删减维度

a = torch.Tensor(1, 4, 1, 9)

print(a.squeeze().shape) # 能删除的都删除掉

print(a.squeeze(0).shape) # 尝试删除0号维度,ok

3.维度扩展(expand)

b = torch.rand(32)

f = torch.rand(4, 32, 14, 14)

先进行维度增加

b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)

print(b.shape)

再进行维度扩展

b = b.expand(4, -1, 14, 14) # -1表示这个维度保持不变,这里写32也可以

print(b.shape)

输出:

torch.Size(1, 32, 1, 1)

torch.Size(4, 32, 14, 14)

4.维度重复(repeat)

print(b.shape)

维度重复,32这里不想进行重复,所以就相当于"重复至1次"

b = b.repeat(4, 1, 14, 14)

print(b.shape)

输出:

torch.Size(1, 32, 1, 1)

torch.Size(4, 32, 14, 14)

5.转置

只适用于dim=2的Tensor。

c = torch.Tensor(2, 4)

print(c.t().shape)

输出:

torch.Size(4, 2)

  1. 维度交换

d = torch.Tensor(6, 3, 1, 2)

print(d.transpose(1, 3).contiguous().shape) # 1号维度和3号维度交换

输出:

torch.Size(6, 2, 1, 3)

7.permute

h = torch.rand(4, 3, 6, 7)

print(h.permute(0, 2, 3, 1).shape)

输出:

torch.Size(4, 6, 7, 3)

8.gather

1)input:输入

2)dim:维度,常用的为0和1

3)index:索引位置

a=t.arange(0,16).view(4,4)

print(a)

index_1=t.LongTensor(\[3,2,1,0])

b=a.gather(0,index_1)

print(b)

index_2=t.LongTensor(\[0,1,2,3]).t()#tensor转置操作:(a)T=a.t()

c=a.gather(1,index_2)

print(c)

outout输出:

tensor(\[ 0, 1, 2, 3,

4, 5, 6, 7,

8, 9, 10, 11,

12, 13, 14, 15])

tensor(\[12, 9, 6, 3])

tensor(\[ 0,

5,

10,

15])

在gather中,我们是通过index对input进行索引把对应的数据提取出来的,而dim决定了索引的方式。

9.Chunk

torch.chunk(tensor, chunks, dim=0)

在给定维度(轴)上将输入张量进行分块儿

直接用上面的数据来举个例子:

l, m, n = x.chunk(3, 0) # 在 0 维上拆分成 3 份

l.size(), m.size(), n.size()

(torch.Size(1, 10, 6), torch.Size(1, 10, 6), torch.Size(1, 10, 6))

u, v = x.chunk(2, 0) # 在 0 维上拆分成 2 份

u.size(), v.size()

(torch.Size(2, 10, 6), torch.Size(1, 10, 6))

10.Stack

合并新增(stack)

stack需要保证两个Tensor的shape是一致的。

c = torch.rand(4, 3, 32, 32)

d = torch.rand(4, 3, 32, 32)

print(torch.stack(c, d, dim=2).shape)

print(torch.stack(c, d, dim=0).shape)

运行结果:

torch.Size(4, 3, 2, 32, 32)

torch.Size(2, 4, 3, 32, 32)

11.View

Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。

a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,

13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])

a4 = a3.view(4, -1)

a5 = a3.view(2, 3, -1)

输出:

#a3

tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,

19, 20, 21, 22, 23, 24])

#a4

tensor(\[ 1, 2, 3, 4, 5, 6,

7, 8, 9, 10, 11, 12,

13, 14, 15, 16, 17, 18,

19, 20, 21, 22, 23, 24])

#a5

tensor(\[\[ 1, 2, 3, 4,

5, 6, 7, 8,

9, 10, 11, 12],

\[13, 14, 15, 16,

17, 18, 19, 20,

21, 22, 23, 24]])

12.reshape

返回与 input张量数据大小一样、给定 shape的张量。如果可能,返回的是input 张量的视图,否则返回的是其拷贝。

a1 = torch.tensor(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)

a2 = torch.reshape(a1, (3, 4))

print(a1.shape)

print(a1)

print(a2.shape)

print(a2)

运行结果:

torch.Size(12)

tensor( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)

torch.Size(3, 4)

tensor(\[ 1, 2, 3, 4,

5, 6, 7, 8,

9, 10, 11, 12])

同view函数,也可以自动推断维度:a4 = torch.reshape(a1, (-1, 6))

相关推荐
装不满的克莱因瓶4 分钟前
掌握 RNN 与 LSTM 模型结构
人工智能·python·rnn·深度学习·神经网络·ai·lstm
何以解忧,唯有..14 分钟前
Python包管理工具pip:从入门到精通
开发语言·python·pip
努力学习_小白30 分钟前
ResNeXt-50——学习记录
pytorch·深度学习·学习
金銀銅鐵31 分钟前
用 Tkinter 实现简单的猜数字游戏
后端·python
Kobebryant-Manba43 分钟前
记录动手学深度学习基础知识
人工智能·深度学习
copyer_xyf1 小时前
Python 模块与包的导入导出
前端·后端·python
ice8130331811 小时前
【Python】Matplotlib折线图绘制
开发语言·python·matplotlib
copyer_xyf1 小时前
Python venv 虚拟环境
前端·后端·python
LaughingZhu1 小时前
Product Hunt 每日热榜 | 2026-06-04
人工智能·经验分享·深度学习·神经网络·产品运营
君为先-bey2 小时前
JointDiT:使用扩散变换器增强RGB-深度联合建模
人工智能·深度学习·计算机视觉·扩散模型·图像生成