F.interpolate
是 PyTorch 中用于对张量(通常是图像数据)进行插值操作的函数,常用于调整张量的大小,例如改变图像的分辨率。它支持多种插值方法,包括最近邻插值、双线性插值和三次插值等。
语法
python
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)
参数
-
input
:- 输入的张量,形状通常为
(N, C, H, W)
或(N, C, D, H, W)
(批次、通道数、高度、宽度 或深度、高度、宽度)。
- 输入的张量,形状通常为
-
size
:- 调整后张量的目标大小,可以是整数元组,例如
(height, width)
。 - 优先级高于
scale_factor
。
- 调整后张量的目标大小,可以是整数元组,例如
-
scale_factor
:- 用于调整大小的比例因子,可以是浮点数或元组(对于高度和宽度分别指定比例)。
- 如果指定了
size
,此参数会被忽略。
-
mode
:- 指定插值方法,常用选项:
'nearest'
:最近邻插值。'linear'
:线性插值(仅适用于 3D 输入)。'bilinear'
:双线性插值(常用于 2D 图像)。'bicubic'
:双三次插值(适用于 2D 图像)。'trilinear'
:三线性插值(适用于 3D 输入)。'area'
:区域插值,用于下采样。
- 指定插值方法,常用选项:
-
align_corners
:- 仅在
mode
为'linear'
,'bilinear'
,'bicubic'
或'trilinear'
时使用。 - 如果为
True
,则输入和输出的角像素对齐。
- 仅在
返回值
调整大小后的张量。
示例代码
1. 将图像从 640x640 调整为 832x832
python
import torch
import torch.nn.functional as F
# 创建一个随机图像张量,形状为 (batch_size=1, channels=3, height=640, width=640)
img = torch.randn(1, 3, 640, 640)
# 使用 F.interpolate 调整分辨率为 832x832
resized_img = F.interpolate(img, size=(832, 832), mode='bilinear', align_corners=False)
print("Original shape:", img.shape)
print("Resized shape:", resized_img.shape)
2. 使用比例调整图像大小
python
# 使用 scale_factor=1.3 对图像尺寸放大 1.3 倍
scaled_img = F.interpolate(img, scale_factor=1.3, mode='bilinear', align_corners=False)
print("Scaled shape:", scaled_img.shape)
3. 下采样为一半大小
python
# 使用 scale_factor=0.5 对图像尺寸缩小 50%
downsampled_img = F.interpolate(img, scale_factor=0.5, mode='area')
print("Downsampled shape:", downsampled_img.shape)
注意事项
-
align_corners
的影响当
align_corners=True
时,插值会在输入和输出张量的角像素之间进行对齐;否则,计算比例时不对齐角像素。通常推荐align_corners=False
,避免形变或偏移。 -
选择插值方法
- 双线性插值(
bilinear
)和双三次插值(bicubic
)通常适用于图像重采样,生成更平滑的结果。 - 最近邻插值(
nearest
)速度快,但结果不够平滑。
- 双线性插值(
-
处理多通道输入
F.interpolate
可直接处理多通道(如 RGB、IR 数据)的张量,不需要额外操作。