torch.flatten()
是 PyTorch 中的一个函数,用于将输入张量展平为一维张量。它的语法如下:
python
torch.flatten(input, start_dim=0, end_dim=-1)
input
:要展平的输入张量。start_dim
(可选):指定从哪个维度开始展平。默认为 0。end_dim
(可选):指定从哪个维度结束展平。默认为 -1,表示最后一个维度。
torch.flatten()
函数会将输入张量的指定维度范围内的所有元素展平到一个一维张量中。展平后的张量保持与原始张量相同的数据顺序。例如,如果输入张量是一个 3x4x5 的三维张量,然后你使用 torch.flatten()
函数将它展平,那么结果将是一个包含 60 个元素的一维张量,其中包含原始张量中所有的元素。
以下是一个示例:
python
import torch
# 创建一个3x4x5的张量
input_tensor = torch.randn(3, 4, 5)
# 使用torch.flatten()将其展平为一维张量
output_tensor = torch.flatten(input_tensor)
print(output_tensor.size()) # 输出 torch.Size([60])
在此示例中,input_tensor
是一个形状为 (3, 4, 5) 的三维张量,使用 torch.flatten()
函数将其展平为一个一维张量,并打印出了结果张量的大小。
示例:
python
import torch
# 创建一个2×3x5x5的张量
input_tensor = torch.randn(2, 3, 5, 5)
print(f"原张量的尺寸为:{input_tensor.size()}") # torch.Size([2, 3, 5, 5])
# 使用torch.flatten()从第一个维度开始展平,从第二个维度结束展平
output_tensor = torch.flatten(input_tensor, start_dim=1, end_dim=2)
print(f"经过展平后的张量的尺寸为:{output_tensor.size()}") # torch.Size([2, 15, 5])