PyTorch学习之 torch.squeeze 函数
一、功能
torch.squeeze
的主要作用是从给定的张量 input
中移除所有尺寸为1
的维度。
二、基本语法
python
torch.squeeze(input, dim=None)
三、参数说明
input
(Tensor): 输入的张量。dim
(int, 可选): 指定要移除的尺寸为1的维度- 如果未指定,函数将移除所有尺寸为1的维度。
- 如果指定的维度不为1,则
torch.squeeze
不会对该维度进行操作 - 如果所有维度都不为1且未指定
dim
参数,则返回的张量与输入张量相同
四、返回值
- 返回一个新的张量,移除了指定的尺寸为1的维度。
- ⚠️如果没有可以移除的维度,则返回与输入相同的张量。
五、示例
以下是一些使用 torch.squeeze
的示例,以帮助更好地理解其用法。
示例 1: 移除所有尺寸为1的维度
python
import torch
# 创建一个张量,其形状为 (1, 3, 1, 5)
x = torch.randn(1, 3, 1, 5)
print("原始张量形状:", x.shape)
# 使用 torch.squeeze 移除所有尺寸为1的维度
y = x.squeeze()
print("移除后张量形状:", y.shape)
输出:
原始张量形状: torch.Size([1, 3, 1, 5])
移除后张量形状: torch.Size([3, 5])
示例 2: 移除指定维度(该维度尺寸为1)
python
import torch
# 创建一个张量,其形状为 (1, 3, 1, 5)
x = torch.randn(1, 3, 1, 5)
print("原始张量形状:", x.shape)
# 指定维度移除,尝试移除第0维
y = x.squeeze(0)
print("移除第0维后的张量形状:", y.shape)
# 尝试移除第2维
z = x.squeeze(2)
print("移除第2维后的张量形状:", z.shape)
输出:
原始张量形状: torch.Size([1, 3, 1, 5])
移除第0维后的张量形状: torch.Size([3, 1, 5])
移除第2维后的张量形状: torch.Size([1, 3, 5])