Pytorch维度转换操作:view,reshape,permute,flatten函数详解

引言

Pytorch中常见的维度转换函数有view, reshape, permute, flatten。本文将详细介绍这几个函数的作用与使用方式,并给出了具体的代码示例,希望能够帮助大家。

常见的维度有四维:比如(batch, channel, height, width);三维:比如(b,n,c);二维:比如(h,w)。下面介绍如何使用上述函数进行维度之间的转换。

1. view函数

作用

tensor.view() 可以用来调整张量的形状,这对于在网络层之间传递数据或者在处理图像数据时非常有用。需要注意的是,新的形状必须与原始张量的元素数量一致。

参数

size (tuple of ints) -- 新的大小应该与原张量元素数量相匹配。可以指定一个尺寸为 -1 的维度来自动计算合适的大小。

代码示例:

将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)转为三维(Batch,N,Channel)形式。

python 复制代码
import torch
# view使用示例
x = torch.randn(16,3,64,64) # B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])
B, C, H, W = x.size()

# 转为BNC
x = x.view(B, -1, C)
# 或者 x = x.view(B, H*W, C)
print(x.shape) #torch.Size([16, 4096, 3])

torch.randn() 是 PyTorch 中的一个函数,用于生成一个填充了从标准正态分布(均值为 0,方差为 1)中随机抽取的数字的张量。

2. permute函数

作用

permute() 函数用于改变张量的维度顺序。它接受一个新的维度顺序作为参数,并返回一个新的张量,其维度顺序按照给定的顺序排列。

参数说明

参数:一个元组,表示新的维度顺序。例如,对于一个形状为 (10, 3, 32, 32) 的张量,permute(0, 2, 3, 1) 表示新的维度顺序为 (10, 32, 32, 3)。其中0,1,2,3分别表示4个维度(10, 3, 32, 32)的索引。

代码示例:

依然将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)转为三维(Batch,N,Channel)形式。

python 复制代码
import torch
# permute使用示例:permute转换唯独顺序
x = torch.randn(16,3,64,64) # B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])

# 16,3,64,64的维度索引分别为0,1,2,3
dim_change = x.permute(0,2,3,1) # 转为 B,H,W,C
# 然后将中间两个通道索引为[1,2]展平
out = dim_change.flatten(start_dim=1,end_dim=2)
print(out.shape) #torch.Size([16, 4096, 3])

flatten() 方法用于展平张量的一个或多个维度。它可以接受两个可选参数:start_dim:从哪个维度开始展平,默认为 0。

end_dim:到哪个维度结束展平,默认为 -1,表示直到最后一个维度。

此处的作用是将第二个和第三个维度进行展平。start_dim=1 表示从第二个维度(即 64)开始展平。end_dim=2 表示到第三个维度(即 64)结束展平。展平后的结果为 (16, 4096, 3),其中 4096= 64 * 64。

通过这些步骤,你可以将原始张量从 (16,3,64,64) 转换为 (16, 4096, 3)。

3. Reshape函数

torch.reshape() 可以改变张量的形状,而不改变张量中的数据。与view函数的作用类似。

注意事项:新旧形状的元素总数必须相同。

python 复制代码
import torch

# 创建一个简单的张量
x = torch.randn(4, 3)
print("Original tensor:")
print(x)

# 使用 torch.reshape() 来改变张量的形状
# 将 (4, 3) 的张量转换成 (2, 6) 的张量
reshaped_x = torch.reshape(x, (2, 6))
print("\nReshaped tensor:")
print(reshaped_x)

# 如果不确定某个维度的大小,可以使用 -1 让 PyTorch 自动计算
# 这里将 (4, 3) 转换为 (12,) 的一维张量
flat_x = torch.reshape(x, (-1))
print("\nFlattened tensor:")
print(flat_x)

# 更复杂的形状变换
# 将 (4, 3) 转换为 (3, 4) 的张量
complex_reshaped_x = torch.reshape(x, (3, 4))
print("\nComplex reshaped tensor:")
print(complex_reshaped_x)

4. flatten函数

torch.flatten 是 PyTorch 库中的一个函数,用于将一个多维张量转换为一维张量或降低其维度。

torch.flatten参数说明

input: 这是要被展平的张量。这是必需的参数。

start_dim (可选): 指定从哪个维度开始展平。默认值为 0,这意味着展平将从第一个维度(通常是批量大小)开始。如果你希望保留前几个维度并只展平后续的维度,你可以设置这个参数。

end_dim (可选): 指定展平到哪个维度结束。默认值为 -1,这表示展平将一直持续到最后一个维度。如果只想展平中间的一部分维度,可以设置这个参数来指定结束维度。

**注意:**当 start_dim 和 end_dim 都没有被显式地指定时,torch.flatten 将会展平除了第一个维度之外的所有维度,通常第一个维度是批量大小,会被保留以便于批次处理。

代码示例:

举个例子,假设有一个形状为 [batch_size, channels, height, width] 的四维张量,如果你想将其展平为 [batch_size, channels * height * width] 的二维张量,你可以直接调用 torch.flatten 而不需要额外的参数。但是,如果你想保留通道维度,并展平高度和宽度维度,你可以设置 start_dim=1 和 end_dim=2。

python 复制代码
import torch

# 创建一个形状为 [8, 3, 64, 64] 的随机张量
x = torch.randn(8, 3, 64, 64)

# 展平除了第一个维度外的所有维度
y = torch.flatten(x)
print(y.shape)  # 输出: torch.Size([8, 12288])

# 只展平第二和第三个维度[也就是最后两个维度],0,1,2,3
z = torch.flatten(x, 1, 2)
print(z.shape)  # 输出: torch.Size([8, 3, 4096])
相关推荐
IT古董2 小时前
【第三章:神经网络原理详解与Pytorch入门】02.深度学习框架PyTorch入门-(5)PyTorch 实战——使用 RNN 进行人名分类
pytorch·深度学习·神经网络
机器学习之心3 小时前
小波增强型KAN网络 + SHAP可解释性分析(Pytorch实现)
人工智能·pytorch·python·kan网络
Green1Leaves10 小时前
pytorch学习-11卷积神经网络(高级篇)
pytorch·学习·cnn
灵智工坊LingzhiAI12 小时前
人体坐姿检测系统项目教程(YOLO11+PyTorch+可视化)
人工智能·pytorch·python
William.csj1 天前
Pytorch/CUDA——flash-attn 库编译的 gcc 版本问题
pytorch·cuda
Green1Leaves2 天前
pytorch学习-9.多分类问题
人工智能·pytorch·学习
摸爬滚打李上进2 天前
重生学AI第十六集:线性层nn.Linear
人工智能·pytorch·python·神经网络·机器学习
HuashuiMu花水木2 天前
PyTorch笔记1----------Tensor(张量):基本概念、创建、属性、算数运算
人工智能·pytorch·笔记
喝过期的拉菲3 天前
如何使用 Pytorch Lightning 启用早停机制
pytorch·lightning·早停机制
kk爱闹3 天前
【挑战14天学完python和pytorch】- day01
android·pytorch·python