pytorch代码实现之CoordConv卷积

CoordConv卷积

在深度学习领域,几乎没有什么想法能像卷积那样产生如此大的影响。对于任何涉及像素或空间表示的问题,普遍的直觉认为卷积神经网络可能是合适的。在本文中,我们通过看似平凡的坐标变换问题展示了一个惊人的反例,该问题只需要学习(x, y)笛卡尔空间中的坐标与单热像素空间中的坐标之间的映射。虽然卷积网络似乎适合这项任务,但我们表明它们失败得很明显。

CoordConv的工作原理是通过使用额外的坐标通道让卷积访问自己的输入坐标。在不牺牲普通卷积的计算和参数效率的情况下,CoordConv允许网络根据最终任务的需要学习完全的平移不变性或不同程度的平移依赖性。CoordConv解决了坐标变换问题,具有很好的泛化性,比convolution的参数少10-100倍,速度快150倍。

原文地址:An intriguing failing of convolutional neural networks and the CoordConv solution

代码实现:

matlab 复制代码
class AddCoords(nn.Module):
    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret

class CoordConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, with_r=False):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        in_channels += 2
        if with_r:
            in_channels += 1
        self.conv = Conv(in_channels, out_channels, k=kernel_size, s=stride)

    def forward(self, x):
        x = self.addcoords(x)
        x = self.conv(x)
        return x
相关推荐
_妲己几秒前
SD的细分功能包括重绘,图像处理、放大等扩散模型应用
人工智能·python·深度学习·机器学习·stable diffusion·comfyui·ai工作流
程途拾光1583 分钟前
企业组织架构图导出Word 在线编辑免费工具
大数据·论文阅读·人工智能·信息可视化·架构·word·流程图
AI浩3 分钟前
MODA:首个用于航空图像中多光谱目标检测的挑战性基准
人工智能·目标检测·目标跟踪
小热茶5 分钟前
浮点数计算专题【五、 IEEE 754 浮点乘法算法详解---基于RISCV的FP32乘法指令在五级流水线的运行分析与SystemC实现】
人工智能·嵌入式硬件·算法·systemc
一只乔哇噻5 分钟前
java后端工程师+AI大模型开发进修ing(研一版‖day63)
java·开发语言·人工智能·python·语言模型
Giser探索家5 分钟前
卫星遥感数据核心参数解析:空间分辨率与时间分辨率
大数据·图像处理·人工智能·深度学习·算法·计算机视觉
小白学大数据6 分钟前
从爬取到分析:使用 Pandas 处理头条问答数据
开发语言·爬虫·python·pandas
微盛企微增长小知识6 分钟前
2025企业微信智能表格使用全指南:AI驱动的数据管理实战
大数据·人工智能·企业微信
jghhh016 分钟前
基于 MATLAB 的光照不均匀图像增强
opencv·计算机视觉·matlab
分布式存储与RustFS7 分钟前
MinIO替代方案精选:RustFS深度评测与选型指南
人工智能·rust·开源项目·对象存储·minio·企业存储·rustfs