每日Attention学习7——Frequency-Perception Module

模块出处

[link] [code] [ACM MM 23] Frequency Perception Network for Camouflaged Object Detection


模块名称

Frequency-Perception Module (FPM)


模块作用

获取频域信息,更好识别伪装对象


模块结构
模块代码
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class FirstOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(FirstOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2l = torch.nn.Conv2d(in_channels, int(alpha * in_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels, in_channels - int(alpha * in_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        if self.stride ==2:
            x = self.h2g_pool(x)
        X_h2l = self.h2g_pool(x)
        X_h = x
        X_h = self.h2h(X_h)
        X_l = self.h2l(X_h2l)
        return X_h, X_l
    

class OctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(OctaveConv, self).__init__()
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        self.l2l = torch.nn.Conv2d(int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = torch.nn.Conv2d(int(alpha * in_channels), out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = torch.nn.Conv2d(in_channels - int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels - int(alpha * in_channels),
                                   out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2l = self.h2g_pool(X_h)
        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(X_h2l)
        X_l2h = F.interpolate(X_l2h, (int(X_h2h.size()[2]),int(X_h2h.size()[3])), mode='bilinear')
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l
        return X_h, X_l


class LastOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(LastOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.l2h = torch.nn.Conv2d(int(alpha * out_channels), out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(out_channels - int(alpha * out_channels),
                                   out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h) 
        X_l2h = self.l2h(X_l) 
        X_l2h = F.interpolate(X_l2h, (int(X_h2h.size()[2]), int(X_h2h.size()[3])), mode='bilinear')
        X_h = X_h2h + X_l2h 
        return X_h
    

class FPM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(FPM, self).__init__()
        self.fir = FirstOctaveConv(in_channels, out_channels, kernel_size)
        self.mid1 = OctaveConv(in_channels, in_channels, kernel_size)
        self.mid2 = OctaveConv(in_channels, out_channels, kernel_size)
        self.lst = LastOctaveConv(in_channels, out_channels, kernel_size)

    def forward(self, x):
        x_h, x_l = self.fir(x)                  
        x_h_1, x_l_1 = self.mid1((x_h, x_l))     
        x_h_2, x_l_2 = self.mid1((x_h_1, x_l_1)) 
        x_h_5, x_l_5 = self.mid2((x_h_2, x_l_2)) 
        x_ret = self.lst((x_h_5, x_l_5))
        return x_ret
    

if __name__ == '__main__':
    x = torch.randn([3, 256, 16, 16])
    fpm = FPM(in_channels=256, out_channels=64)
    out = fpm(x)
    print(out.shape)  # 3, 64, 16, 16

原文表述

具体来说,我们采用八度卷积以端到端的方式自动感知高频和低频信息,从而实现伪装物体检测的在线学习。八度卷积可以有效避免DCT 引起的块状效应,并利用GPU的计算速度优势。此外,它可以轻松插入任意网络。

相关推荐
有Li17 小时前
解剖学引导的全身PET-CT乳腺癌分割与跨模态自对齐/文献速递-基于深度学习的图像配准与疾病诊断
论文阅读·人工智能·深度学习·文献·医学生
s1ckrain17 小时前
【论文阅读】Towards Learning a Generalist Model for Embodied Navigation
论文阅读·多模态·具身智能
有Li1 天前
用于CBCT到CT合成的纹理保留扩散模型/文献速递-基于人工智能的医学影像技术
论文阅读·人工智能·深度学习·计算机视觉·文献
CV-杨帆2 天前
论文阅读:arixv 2026 Reasoning Models Generate Societies of Thought
论文阅读
YMWM_2 天前
论文阅读“MV-UMI: A Scalable Multi-View Interface for Cross-Embodiment Learning“
论文阅读·umi
YMWM_2 天前
论文阅读“Tactile-reactive gripper with an active palm for dexterous manipulation“
论文阅读·palm·tactile gripper
CV-杨帆3 天前
论文阅读:2026 techrxiv Jailbreak-as-a-Service: The Emerging Threat Landscape
论文阅读
张较瘦_4 天前
[论文阅读] 软件测试 | 跨语言模糊测试大揭秘:C++/Rust/Python谁更胜一筹?
c++·论文阅读·rust
青衫码上行4 天前
Redis常用数据类型操作命令
java·数据库·论文阅读·redis·学习
蓝田生玉1234 天前
qwen论文阅读笔记
论文阅读·笔记