Daily Attention Learning 18: Grouped Attention Gate

##### Module Source

\[ICLR 25 Submission\] [\[link\]](https://openreview.net/forum?id=BefqqrgdZ1) UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation

*** ** * ** ***

##### Module Name

Grouped Attention Gate (GAG)

*** ** * ** ***

##### Module Function

Lightweight feature fusion

*** ** * ** ***

##### Module Structure

![Grouped Attention Gate structure](https://i-blog.csdnimg.cn/direct/a1631ac9f4b94bbc83d1c80f106cbd2c.jpeg)

*** ** * ** ***

##### Module Features

* Group convolutions process the inputs before fusion, making the module lighter than one built on standard convolutions
* The fused coarse feature is treated as a spatial attention map and multiplied with the encoder feature, which realizes the "Gate" in the module's name
* Beyond plain feature fusion, the module can also be viewed as an enhancement block that uses auxiliary (decoder) features to strengthen the encoder features; a usage sketch follows the code block below

*** ** * ** ***

##### Module Code

```python
import torch
import torch.nn as nn


class GAG(nn.Module):
    def __init__(self, F_g, F_l, F_int, kernel_size=1, groups=1):
        super(GAG, self).__init__()
        # Grouping is disabled for pointwise (1x1) convolutions
        if kernel_size == 1:
            groups = 1
        # Project the gating (decoder) feature with a grouped convolution
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=kernel_size, stride=1,
                      padding=kernel_size // 2, groups=groups, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # Project the encoder feature with a grouped convolution
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=kernel_size, stride=1,
                      padding=kernel_size // 2, groups=groups, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # Collapse the fused feature into a single-channel spatial attention map
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.activation = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.activation(g1 + x1)  # fuse decoder and encoder projections
        psi = self.psi(psi)             # spatial attention map in [0, 1]
        return x * psi                  # gate the encoder feature


if __name__ == '__main__':
    x1 = torch.randn([1, 64, 44, 44])  # gating (decoder) feature
    x2 = torch.randn([1, 64, 44, 44])  # encoder feature
    gag = GAG(F_g=64, F_l=64, F_int=64 // 2, kernel_size=3, groups=64 // 2)
    out = gag(x1, x2)
    print(out.shape)  # [1, 64, 44, 44]
```

*** ** * ** ***
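Since GAG takes a decoder feature `g` and an encoder feature `x`, a natural place for it is on a U-Net skip connection. The sketch below shows one plausible integration, assuming the `GAG` class from the block above is in scope; the `GAGSkipFusion` wrapper, the bilinear upsampling via `F.interpolate`, and the concatenation-plus-convolution fusion are illustrative choices, not the authors' implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GAGSkipFusion(nn.Module):
    """Hypothetical decoder stage: upsample the decoder feature, use it to gate
    the encoder skip feature via GAG, then fuse the two by concatenation.
    Assumes the GAG class defined above is importable/in scope."""

    def __init__(self, dec_ch, enc_ch, out_ch):
        super().__init__()
        # GAG with halved intermediate channels and grouped 3x3 convs, as in the demo above
        self.gag = GAG(F_g=dec_ch, F_l=enc_ch, F_int=enc_ch // 2,
                       kernel_size=3, groups=enc_ch // 2)
        self.fuse = nn.Sequential(
            nn.Conv2d(dec_ch + enc_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, dec_feat, enc_feat):
        # Bring the decoder feature to the encoder's spatial resolution
        dec_up = F.interpolate(dec_feat, size=enc_feat.shape[2:],
                               mode='bilinear', align_corners=False)
        # Gate the encoder skip feature with the upsampled decoder feature
        enc_gated = self.gag(dec_up, enc_feat)
        # Concatenate gated skip feature and decoder feature, then fuse
        return self.fuse(torch.cat([dec_up, enc_gated], dim=1))


if __name__ == '__main__':
    dec = torch.randn(1, 128, 22, 22)  # coarser decoder feature
    enc = torch.randn(1, 64, 44, 44)   # encoder skip feature
    stage = GAGSkipFusion(dec_ch=128, enc_ch=64, out_ch=64)
    print(stage(dec, enc).shape)       # torch.Size([1, 64, 44, 44])
```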
