PyTorch中grid_sample的使用方法

官方文档

首先Pytorch中grid_sample函数的接口声明如下:

复制代码
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
  • input : 输入tensor, shape为 [N, C, H_in, W_in]
  • grid: 一个field flow, shape为[N, H_out, W_out, 2],最后一个维度是每个grid(H_out_i, W_out_i)在input的哪个位置的邻域去采点。数值范围被归一化到[-1,1]。

这里的input和output就是输入的图片,或者是网络中的feature map。关键的处理过程在于grid,grid的最后一维的大小为2,即表示input中pixel的位置信息 (x,y) ,这里一般会将x和y的取值范围归一化到 [−1,1] 之间, (−1,−1) 表示input左上角的像素的坐标,(1,1) 表示input右下角的像素的坐标,对于超出这个范围的坐标(x,y),函数将会根据参数_padding_mode_的设定进行不同的处理。

  • padding_mode='zeros':对于越界的位置在网格中采用pixel value=0进行填充。
  • padding_mode='border':对于越界的位置在网格中采用边界的pixel value进行填充。
  • padding_mode='reflection':对于越界的位置在网格中采用关于边界的对称值进行填充。

对于mode='bilinear'参数,则定义了在input中指定位置的pixel value中进行插值的方法,为什么需要插值呢?因为前面我们说了,grid中表示的位置信息x和y的取值范围在 [−1,1] 之间,这就意味着我们要根据一个浮点型的坐标值在input中对pixel value进行采样,mode有nearest和bilinear两种模式。

  • nearest就是直接采用与 (x,y) 距离最近处的像素值来填充grid
  • bilinear则是采用双线性插值的方法来进行填充,总之其与nearest的区别就是nearest只考虑最近点的pixel value,而bilinear则采用(x,y)周围的四个pixel value进行加权平均值来填充grid。

双线性插值:

举例:

python 复制代码
import torch
from torch.nn import functional as F


inp = torch.ones(1, 128, 4, 4)

# 目的是得到一个 长宽为20的tensor
out_h = 20
out_w = 20
grid_x, grid_y = torch.meshgrid(
        torch.linspace(-1, 1, out_h),
        torch.linspace(-1, 1, out_w)
    )
# grid 最后一维度表示在input采样的位置(x,y),y表示图像纵轴,x表示横轴,grid顺序应该先x递增,后y递增
grid = torch.stack((grid_y, grid_x), dim=-1).unsqueeze(0) # (out_h, out_w, 2)
# F.grid_sample -> input:(N,C,Hin,Win), grid:(N,Hout,Wout,2), output:(N,C,Hout,Wout)
# outp = F.grid_sample(features, grid, align_corners=True, mode='bilinear')
outp = F.grid_sample(inp, grid, align_corners=True, mode='nearest')
print(outp.shape) # torch.Size([1, 128, 20, 20])

对图像,特征进行采样用以上grid才不会图像位置错误

相关推荐
一只幸运猫.3 分钟前
2026Java 后端面试完整版|八股简答 + AI 大模型集成技术(最新趋势)
人工智能·面试·职场和发展
Promise微笑10 分钟前
2026年国产替代油介损测试仪:油介损全场景解决方案与技术演进
大数据·网络·人工智能
深海鱼在掘金16 分钟前
深入浅出 LangChain —— 第三章:模型抽象层
人工智能·langchain·agent
生信碱移16 分钟前
PACells:这个方法可以鉴定疾病/预后相关的重要细胞亚群,作者提供的代码流程可以学习起来了,甚至兼容转录组与 ATAC 两种数据类型!
人工智能·学习·算法·机器学习·数据挖掘·数据分析·r语言
workflower25 分钟前
具身智能行业应用-生活服务业
大数据·人工智能·机器人·动态规划·生活
蜡台32 分钟前
Python包管理工具pip完全指南-----2
linux·windows·python
Mr.朱鹏34 分钟前
【Python 进阶 | 第四篇】Psycopg3 + Flask 实现 PostgreSQL CRUD 全流程:从连接池到RESTful接口
python·postgresql·flask·virtualenv·fastapi·pip·tornado
GitCode官方38 分钟前
基于昇腾 MindSpeed LLM 玩转 DeepSeekV4-Flash 模型的预训练复现部署
人工智能·开源·atomgit
大刘讲IT1 小时前
AI重塑企业信息价值标准:从“系统供给”到“用户定义”的企业数字化新范式
人工智能·经验分享·ai·制造
流年似水~1 小时前
MCP协议实战:从零搭建一个让Claude能“看见“数据库的工具服务
数据库·人工智能·程序人生·ai·ai编程