每日Attention学习3——Cross-level Feature Fusion

模块出处

[link] [code] [PR 23] Cross-level Feature Aggregation Network for Polyp Segmentation


模块名称

Cross-level Feature Fusion (CFF)


模块作用

双级特征融合


模块结构

模块代码
python 复制代码
import torch
import torch.nn as nn


class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
    

class CFF(nn.Module):
    def __init__(self, in_channel1, in_channel2, out_channel):
        self.init__ = super(CFF, self).__init__()
        act_fn         = nn.ReLU(inplace=True)
                
        self.layer0    = BasicConv2d(in_channel1, out_channel // 2, 1)
        self.layer1    = BasicConv2d(in_channel2, out_channel // 2, 1)
        
        self.layer3_1  = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, padding=1),  nn.BatchNorm2d(out_channel // 2),act_fn)
        self.layer3_2  = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=3, stride=1, padding=1),  nn.BatchNorm2d(out_channel // 2),act_fn)
        
        self.layer5_1  = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(out_channel // 2),act_fn)
        self.layer5_2  = nn.Sequential(nn.Conv2d(out_channel, out_channel // 2, kernel_size=5, stride=1, padding=2),  nn.BatchNorm2d(out_channel // 2),act_fn)
        
        self.layer_out = nn.Sequential(nn.Conv2d(out_channel // 2, out_channel, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channel),act_fn)


    def forward(self, x0, x1):
        x0_1  = self.layer0(x0)
        x1_1  = self.layer1(x1)
        x_3_1 = self.layer3_1(torch.cat((x0_1,  x1_1),  dim=1))    
        x_5_1 = self.layer5_1(torch.cat((x1_1,  x0_1),  dim=1))
        x_3_2 = self.layer3_2(torch.cat((x_3_1, x_5_1), dim=1))
        x_5_2 = self.layer5_2(torch.cat((x_5_1, x_3_1), dim=1))
        out   = self.layer_out(x0_1 + x1_1 + torch.mul(x_3_2, x_5_2))
        return out
    
if __name__ == '__main__':
    x1 = torch.randn([1, 256, 16, 16])
    x2 = torch.randn([1, 512, 16, 16])
    cff = CFF(in_channel1=256, in_channel2=512, out_channel=64)
    out = cff(x1, x2)
    print(out.shape)  # 1, 64, 16, 16

原文表述

利用特征提取网络可以获得不同分辨率的多级特征。因此,有效整合多级特征非常重要,这可以提高不同尺度特征的表示能力。因此,我们提出了一个 CFF模块来融合相邻的两个特征,然后将其输入分割网络。

相关推荐
张较瘦_3 小时前
[论文阅读] AI | 用机器学习给深度学习库“体检”:大幅提升测试效率的新思路
论文阅读·人工智能·机器学习
m0_6501082418 小时前
IntNet:面向协同自动驾驶的通信驱动多智能体强化学习框架
论文阅读·marl·多智能体系统·网联自动驾驶·意图共享·自适应通讯·端到端协同
m0_650108241 天前
Raw2Drive:基于对齐世界模型的端到端自动驾驶强化学习方案
论文阅读·机器人·强化学习·端到端自动驾驶·双流架构·引导机制·mbrl自动驾驶
快降重科研小助手1 天前
前瞻与规范:AIGC降重API的技术演进与负责任使用
论文阅读·aigc·ai写作·降重·降ai·快降重
源于花海2 天前
IEEE TIE期刊论文学习——基于元学习与小样本重训练的锂离子电池健康状态估计方法
论文阅读·元学习·电池健康管理·并行网络·小样本重训练
m0_650108242 天前
UniDrive-WM:自动驾驶领域的统一理解、规划与生成世界模型
论文阅读·自动驾驶·轨迹规划·感知、规划与生成融合·场景理解·未来图像生成
蓝田生玉1232 天前
LLaMA论文阅读笔记
论文阅读·笔记·llama
*西瓜2 天前
基于深度学习的视觉水位识别技术与装备
论文阅读·深度学习
大模型最新论文速读2 天前
BAR-RAG: 通过边界感知训练让单轮 RAG 效果媲美深度研究
论文阅读·人工智能·深度学习·机器学习·自然语言处理
觉醒大王3 天前
科研新手如何读文献?从“乱读”到“会读”
论文阅读·笔记·深度学习·学习·自然语言处理·学习方法