每日Attention学习7——Frequency-Perception Module

模块出处

[link] [code] [ACM MM 23] Frequency Perception Network for Camouflaged Object Detection


模块名称

Frequency-Perception Module (FPM)


模块作用

获取频域信息,更好识别伪装对象


模块结构
模块代码
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class FirstOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(FirstOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2l = torch.nn.Conv2d(in_channels, int(alpha * in_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels, in_channels - int(alpha * in_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        if self.stride ==2:
            x = self.h2g_pool(x)
        X_h2l = self.h2g_pool(x)
        X_h = x
        X_h = self.h2h(X_h)
        X_l = self.h2l(X_h2l)
        return X_h, X_l
    

class OctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(OctaveConv, self).__init__()
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        self.l2l = torch.nn.Conv2d(int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = torch.nn.Conv2d(int(alpha * in_channels), out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = torch.nn.Conv2d(in_channels - int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels - int(alpha * in_channels),
                                   out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2l = self.h2g_pool(X_h)
        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)
        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(X_h2l)
        X_l2h = F.interpolate(X_l2h, (int(X_h2h.size()[2]),int(X_h2h.size()[3])), mode='bilinear')
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l
        return X_h, X_l


class LastOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(LastOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.l2h = torch.nn.Conv2d(int(alpha * out_channels), out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(out_channels - int(alpha * out_channels),
                                   out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
    def forward(self, x):
        X_h, X_l = x
        if self.stride == 2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)
        X_h2h = self.h2h(X_h) 
        X_l2h = self.l2h(X_l) 
        X_l2h = F.interpolate(X_l2h, (int(X_h2h.size()[2]), int(X_h2h.size()[3])), mode='bilinear')
        X_h = X_h2h + X_l2h 
        return X_h
    

class FPM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(FPM, self).__init__()
        self.fir = FirstOctaveConv(in_channels, out_channels, kernel_size)
        self.mid1 = OctaveConv(in_channels, in_channels, kernel_size)
        self.mid2 = OctaveConv(in_channels, out_channels, kernel_size)
        self.lst = LastOctaveConv(in_channels, out_channels, kernel_size)

    def forward(self, x):
        x_h, x_l = self.fir(x)                  
        x_h_1, x_l_1 = self.mid1((x_h, x_l))     
        x_h_2, x_l_2 = self.mid1((x_h_1, x_l_1)) 
        x_h_5, x_l_5 = self.mid2((x_h_2, x_l_2)) 
        x_ret = self.lst((x_h_5, x_l_5))
        return x_ret
    

if __name__ == '__main__':
    x = torch.randn([3, 256, 16, 16])
    fpm = FPM(in_channels=256, out_channels=64)
    out = fpm(x)
    print(out.shape)  # 3, 64, 16, 16

原文表述

具体来说,我们采用八度卷积以端到端的方式自动感知高频和低频信息,从而实现伪装物体检测的在线学习。八度卷积可以有效避免DCT 引起的块状效应,并利用GPU的计算速度优势。此外,它可以轻松插入任意网络。

相关推荐
张较瘦_1 天前
[论文阅读] AI + 软件工程 | LLM救场Serverless开发!SlsReuse框架让函数复用率飙升至91%,还快了44%
论文阅读·人工智能·软件工程
m0_650108241 天前
InstructBLIP:面向通用视觉语言模型的指令微调技术解析
论文阅读·人工智能·q-former·指令微调的视觉语言大模型·零样本跨任务泛化·通用视觉语言模型
做cv的小昊2 天前
VLM经典论文阅读:【综述】An Introduction to Vision-Language Modeling
论文阅读·人工智能·计算机视觉·语言模型·自然语言处理·bert·transformer
m0_650108242 天前
PaLM-E:具身智能的多模态语言模型新范式
论文阅读·人工智能·机器人·具身智能·多模态大语言模型·palm-e·大模型驱动
m0_650108243 天前
PaLM:Pathways 驱动的大规模语言模型 scaling 实践
论文阅读·人工智能·palm·谷歌大模型·大规模语言模型·全面评估与行为分析·scaling效应
小殊小殊3 天前
【论文笔记】视频RAG-Vgent:基于图结构的视频检索推理框架
论文阅读·人工智能·深度学习
有点不太正常3 天前
《ShadowCoT: Cognitive Hijacking for Stealthy Reasoning Backdoors in LLMs》——论文阅读
论文阅读·大模型·agent安全
小殊小殊3 天前
【论文笔记】大型语言模型的知识蒸馏与数据集蒸馏
论文阅读·人工智能·深度学习
SatoshiGogo4 天前
AIGC 论文笔记
论文阅读·aigc
walnut_oyb5 天前
arXiv|SARLANG-1M:用于 SAR 图像理解的视觉-语言建模基准
论文阅读·人工智能·机器学习·计算机视觉·语言模型·自然语言处理