每日Attention学习7——Frequency-Perception Module

模块出处

[link] [code] [ACM MM 23] Frequency Perception Network for Camouflaged Object Detection


模块名称

Frequency-Perception Module (FPM)


模块作用

获取频域信息,更好识别伪装对象


模块结构
模块代码
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class FirstOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(FirstOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2l = torch.nn.Conv2d(in_channels, int(alpha * in_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels, in_channels - int(alpha * in_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        if self.stride ==2:
            x = self.h2g_pool(x)
        X_h2l = self.h2g_pool(x)
        X_h = x
        X_h = self.h2h(X_h)
        X_l = self.h2l(X_h2l)
        return X_h, X_l
    

class OctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(OctaveConv, self).__init__()
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        self.l2l = torch.nn.Conv2d(int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = torch.nn.Conv2d(int(alpha * in_channels), out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = torch.nn.Conv2d(in_channels - int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels - int(alpha * in_channels),
                                   out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2l = self.h2g_pool(X_h)
        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(X_h2l)
        X_l2h = F.interpolate(X_l2h, (int(X_h2h.size()[2]),int(X_h2h.size()[3])), mode='bilinear')
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l
        return X_h, X_l


class LastOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(LastOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.l2h = torch.nn.Conv2d(int(alpha * out_channels), out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(out_channels - int(alpha * out_channels),
                                   out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h) 
        X_l2h = self.l2h(X_l) 
        X_l2h = F.interpolate(X_l2h, (int(X_h2h.size()[2]), int(X_h2h.size()[3])), mode='bilinear')
        X_h = X_h2h + X_l2h 
        return X_h
    

class FPM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(FPM, self).__init__()
        self.fir = FirstOctaveConv(in_channels, out_channels, kernel_size)
        self.mid1 = OctaveConv(in_channels, in_channels, kernel_size)
        self.mid2 = OctaveConv(in_channels, out_channels, kernel_size)
        self.lst = LastOctaveConv(in_channels, out_channels, kernel_size)

    def forward(self, x):
        x_h, x_l = self.fir(x)                  
        x_h_1, x_l_1 = self.mid1((x_h, x_l))     
        x_h_2, x_l_2 = self.mid1((x_h_1, x_l_1)) 
        x_h_5, x_l_5 = self.mid2((x_h_2, x_l_2)) 
        x_ret = self.lst((x_h_5, x_l_5))
        return x_ret
    

if __name__ == '__main__':
    x = torch.randn([3, 256, 16, 16])
    fpm = FPM(in_channels=256, out_channels=64)
    out = fpm(x)
    print(out.shape)  # 3, 64, 16, 16

原文表述

具体来说,我们采用八度卷积以端到端的方式自动感知高频和低频信息,从而实现伪装物体检测的在线学习。八度卷积可以有效避免DCT 引起的块状效应,并利用GPU的计算速度优势。此外,它可以轻松插入任意网络。

相关推荐
传说故事3 小时前
【论文阅读】AWR:Simple and scalable off-policy RL
论文阅读·强化学习
传说故事3 小时前
【论文阅读】通过homeostasis RL学习合成综合机器人行为
论文阅读·人工智能·机器人·具身智能
数智工坊4 小时前
【VarifocalNet(VFNet)论文阅读】:IoU-aware稠密目标检测,把定位质量塞进分类得分
论文阅读·人工智能·深度学习·目标检测·计算机视觉·分类·cnn
STLearner1 天前
AI论文速读 | QuitoBench:支付宝高质量开源时间序列预测基准测试集
大数据·论文阅读·人工智能·深度学习·学习·机器学习·开源
数智工坊1 天前
【Anchor DETR论文阅读】:基于锚点查询设计的Transformer检测器,50epoch收敛且速度精度双升
论文阅读·深度学习·transformer
数智工坊1 天前
【DAB-DETR论文阅读】:动态锚框作为更优查询,彻底解决DETR训练收敛慢难题
网络·论文阅读·人工智能·深度学习·cnn
DuHz2 天前
论文精读:大语言模型 (Large Language Models, LLM) —— 一项调查
论文阅读·人工智能·深度学习·算法·机器学习·计算机视觉·语言模型
锅挤2 天前
来一篇儿:《Saliency Attack: Towards Imperceptible Black-box Adversarial Attack》
论文阅读
Chunyyyen2 天前
【第四十二周】论文阅读
论文阅读·学习
数智工坊3 天前
【DETR论文阅读】端到端目标检测新范式:Transformer改写检测 pipeline
论文阅读·目标检测·transformer