PyTorch 中 reshape 函数用法示例
在 PyTorch 中,reshape
函数用于改变张量的形状,而不改变其中的数据。下面是一些关于 reshape
函数的常见用法示例。
基本语法
torch.reshape(input, shape)
# input: 要重塑的张量。
# shape: 目标形状,可以是一个整数元组或列表。
示例1:将一维张量转为二维张量(重要)
import torch
# 创建一个一维张量
tensor_1d = torch.tensor([1, 2, 3, 4, 5, 6])
# 使用 reshape 将其转为形状为 (2, 3) 的二维张量
tensor_2d = tensor_1d.reshape(2, 3)
print(tensor_2d)
输出:
tensor([[1, 2, 3],
[4, 5, 6]])
示例 2:使用负数维度自动推导形状(重要)
在 reshape 中可以使用 -1 表示自动推导该维度的大小。
# 创建一个一维张量
tensor_1d = torch.tensor([1, 2, 3, 4, 5, 6])
# 使用 -1 自动推导维度
tensor_2d = tensor_1d.reshape(3, -1)
print(tensor_2d)
输出:
tensor([[1, 2],
[3, 4],
[5, 6]])
在这里,-1 的意思是由其他维度的大小推导出来的。
示例 3:将三维张量展平为二维张量
假设有一个形状为 (2, 3, 4) 的三维张量,可以将其展平为形状为 (2, 12) 的二维张量。
# 创建一个三维张量
tensor_3d = torch.randn(2, 3, 4) # 随机生成一个张量
print(tensor_3d)
# 重塑为二维张量
tensor_2d = tensor_3d.reshape(2, -1)
print(tensor_2d)
print(tensor_2d.shape) # 输出应该为 torch.Size([2, 12])
输出:
tensor([[[-2.0344, -0.0268, 1.4198, 0.5537],
[ 2.1429, -0.8317, -1.6704, 0.3521],
[ 0.4205, 0.0552, 1.8191, 0.4051]],
[[-0.5695, 0.2553, -0.8192, -1.3156],
[ 0.8952, -0.6411, 1.0547, 0.7071],
[-0.1367, -2.2702, 0.6299, -0.7946]]])
tensor([[-2.0344, -0.0268, 1.4198, 0.5537, 2.1429, -0.8317, -1.6704, 0.3521,
0.4205, 0.0552, 1.8191, 0.4051],
[-0.5695, 0.2553, -0.8192, -1.3156, 0.8952, -0.6411, 1.0547, 0.7071,
-0.1367, -2.2702, 0.6299, -0.7946]])
torch.Size([2, 12])
示例4:调换维度
如果你想把一个矩阵的行和列互换,可以先使用 reshape 将张量改变形状,再使用 .t() 方法进行转置(若适用)。
# 创建一个二维张量
tensor_2d = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 使用 reshape 先改变形状后,再用 .t() 转置
tensor_transposed = tensor_2d.reshape(3, 2).t() # 先变成 3x2 然后转置
print(tensor_transposed)
输出:
tensor([[1, 4],
[2, 5],
[3, 6]])
总结
- reshape 是用于改变张量形状的工具,数据不变。
- 可以使用 -1 进行自动推导。
- 适用于多维张量的重塑,便于后续的数据处理和建模。