【AAAI2025】风车卷积替代标准卷积,增强了底层特征提取能力

Pinwheel-shaped Convolution and Scale-based Dynamic Loss for Infrared Small Target Detection

基于风车形卷积和尺度动态损失的红外小目标检测

风车形卷积(PConv)模块:

作者提出了一种新颖的风车形卷积(PConv)模块,用于替代标准卷积。该模块更好地符合红外小目标(IRST)的高斯空间分布特性,增强了底层特征提取能力,并显著扩展了感受野,同时仅引入了极少的参数增加。

效果图

代码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
"""风车型卷积,使用了padding再各个方向上实现方向敏感性,增加了参数量"""

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))


class PConv(nn.Module):  
    ''' Pinwheel-shaped Convolution using the Asymmetric Padding method. '''
    
    def __init__(self, c1, c2, k, s):
        super().__init__()

        # self.k = k
        p = [(k, 0, 1, 0), (0, k, 0, 1), (0, 1, k, 0), (1, 0, 0, k)]
        self.pad = [nn.ZeroPad2d(padding=(p[g])) for g in range(4)]
        self.cw = Conv(c1, c2 // 4, (1, k), s=s, p=0)
        self.ch = Conv(c1, c2 // 4, (k, 1), s=s, p=0)
        self.cat = Conv(c2, c2, 2, s=1, p=0)

    def forward(self, x):
        yw0 = self.cw(self.pad[0](x))
        yw1 = self.cw(self.pad[1](x))
        yh0 = self.ch(self.pad[2](x))
        yh1 = self.ch(self.pad[3](x))
        return self.cat(torch.cat([yw0, yw1, yh0, yh1], dim=1))


if __name__ == "__main__":
    x = torch.randn(1, 32, 64, 64).cuda()
    xm = torch.randn(1, 1, 320, 320).cuda()
    model = PConv(c1=32,c2=32,k=3,s=1).cuda()
    #model = Conv(c1=32, c2=32, k=3, s=1).cuda()
    y  = model(x)
    print(y.size())
    print("最大内存占用:", torch.cuda.max_memory_allocated() // 1024 // 1024, "MB")
    # | 模块   | 参数量 | FLOPs   | 特征提取能力       |
    # |--------|--------|---------|--------------------|
    # | Conv   | 3.7K   | 2.4G    | 各向同性特征        |
    # | PConv  | 25.6K  | 3.2G    | 方向敏感特征        |
相关推荐
deephub8 分钟前
AI代理性能提升实战:LangChain+LangGraph内存管理与上下文优化完整指南
人工智能·深度学习·神经网络·langchain·大语言模型·rag
EulerBlind9 分钟前
【运维】SGLang 安装指南
运维·人工智能·语言模型
心之语歌12 分钟前
Spring AI MCP 客户端
人工智能·spring·github
go54631584651 小时前
基于深度学习的食管癌右喉返神经旁淋巴结预测系统研究
图像处理·人工智能·深度学习·神经网络·算法
Blossom.1181 小时前
基于深度学习的图像分类:使用Capsule Networks实现高效分类
人工智能·python·深度学习·神经网络·机器学习·分类·数据挖掘
宇称不守恒4.01 小时前
2025暑期—05神经网络-卷积神经网络
深度学习·神经网络·cnn
想变成树袋熊2 小时前
【自用】NLP算法面经(6)
人工智能·算法·自然语言处理
格林威2 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现沙滩小人检测识别(C#代码UI界面版)
人工智能·深度学习·数码相机·yolo·计算机视觉
checkcheckck2 小时前
spring ai 适配 流式回答、mcp、milvus向量数据库、rag、聊天会话记忆
人工智能
Microvision维视智造2 小时前
从“人工眼”到‘智能眼’:EZ-Vision视觉系统如何重构生产线视觉检测精度?
图像处理·人工智能·重构·视觉检测