【AAAI2025】风车卷积替代标准卷积,增强了底层特征提取能力

Pinwheel-shaped Convolution and Scale-based Dynamic Loss for Infrared Small Target Detection

基于风车形卷积和尺度动态损失的红外小目标检测

风车形卷积(PConv)模块:

作者提出了一种新颖的风车形卷积(PConv)模块,用于替代标准卷积。该模块更好地符合红外小目标(IRST)的高斯空间分布特性,增强了底层特征提取能力,并显著扩展了感受野,同时仅引入了极少的参数增加。

效果图

代码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
"""风车型卷积,使用了padding再各个方向上实现方向敏感性,增加了参数量"""

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))


class PConv(nn.Module):  
    ''' Pinwheel-shaped Convolution using the Asymmetric Padding method. '''
    
    def __init__(self, c1, c2, k, s):
        super().__init__()

        # self.k = k
        p = [(k, 0, 1, 0), (0, k, 0, 1), (0, 1, k, 0), (1, 0, 0, k)]
        self.pad = [nn.ZeroPad2d(padding=(p[g])) for g in range(4)]
        self.cw = Conv(c1, c2 // 4, (1, k), s=s, p=0)
        self.ch = Conv(c1, c2 // 4, (k, 1), s=s, p=0)
        self.cat = Conv(c2, c2, 2, s=1, p=0)

    def forward(self, x):
        yw0 = self.cw(self.pad[0](x))
        yw1 = self.cw(self.pad[1](x))
        yh0 = self.ch(self.pad[2](x))
        yh1 = self.ch(self.pad[3](x))
        return self.cat(torch.cat([yw0, yw1, yh0, yh1], dim=1))


if __name__ == "__main__":
    x = torch.randn(1, 32, 64, 64).cuda()
    xm = torch.randn(1, 1, 320, 320).cuda()
    model = PConv(c1=32,c2=32,k=3,s=1).cuda()
    #model = Conv(c1=32, c2=32, k=3, s=1).cuda()
    y  = model(x)
    print(y.size())
    print("最大内存占用:", torch.cuda.max_memory_allocated() // 1024 // 1024, "MB")
    # | 模块   | 参数量 | FLOPs   | 特征提取能力       |
    # |--------|--------|---------|--------------------|
    # | Conv   | 3.7K   | 2.4G    | 各向同性特征        |
    # | PConv  | 25.6K  | 3.2G    | 方向敏感特征        |
相关推荐
人工智能小豪4 小时前
2025年大模型平台落地实践研究报告|附75页PDF文件下载
大数据·人工智能·transformer·anythingllm·ollama·大模型应用
芯盾时代4 小时前
AI在网络安全领域的应用现状和实践
人工智能·安全·web安全·网络安全
黑鹿0224 小时前
机器学习基础(三) 逻辑回归
人工智能·机器学习·逻辑回归
电鱼智能的电小鱼5 小时前
虚拟现实教育终端技术方案——基于EFISH-SCB-RK3588的全场景国产化替代
linux·网络·人工智能·分类·数据挖掘·vr
天天代码码天天5 小时前
C# Onnx 动漫人物头部检测
人工智能·深度学习·神经网络·opencv·目标检测·机器学习·计算机视觉
Joseit6 小时前
从零打造AI面试系统全栈开发
人工智能·面试·职场和发展
小猪猪_16 小时前
多视角学习、多任务学习,迁移学习
人工智能·迁移学习
飞哥数智坊6 小时前
AI编程实战:Cursor 1.0 上手实测,刀更锋利马更快
人工智能·cursor
vlln6 小时前
【论文解读】ReAct:从思考脱离行动, 到行动反馈思考
人工智能·深度学习·机器学习
qq_430908576 小时前
华为ICT和AI智能应用
人工智能·华为