Daily Attention Learning 15: Cross-Model Grafting Module

##### Module Source

\[CVPR 22\] [\[link\]](https://openaccess.thecvf.com/content/CVPR2022/html/Xie_Pyramid_Grafting_Network_for_One-Stage_High_Resolution_Saliency_Detection_CVPR_2022_paper.html) [\[code\]](https://github.com/iCVTEAM/PGNet) Pyramid Grafting Network for One-Stage High Resolution Saliency Detection

---

##### Module Name

Cross-Model Grafting Module (CMGM)

---

##### Module Purpose

Feature fusion between the Transformer and CNN branches.

---

##### Module Structure

![CMGM module structure](https://i-blog.csdnimg.cn/direct/e889d03727a149fd952730a5de5ae405.jpeg)

---

##### Module Idea

Transformers are better at capturing global context, while CNNs are better at local detail. The simplest way to fuse the two kinds of features is element-wise addition or multiplication. However, addition and multiplication are essentially "local" operations: in regions where both features are highly uncertain, such fusion simply propagates the noise. To address this, the paper proposes the CMGM module, which uses cross-attention to bring in broader contextual information and strengthen the fusion.

---

##### Module Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CMGM(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # queries and values come from the first input x; keys come from the second input y
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.qv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.act = nn.ReLU(inplace=True)
        # note: the sizes below are hard-coded for dim=64 and num_heads=8
        self.conv = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=1)
        self.lnx = nn.LayerNorm(64)
        self.lny = nn.LayerNorm(64)
        self.bn = nn.BatchNorm2d(8)
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, y):
        batch_size = x.shape[0]
        chanel = x.shape[1]
        sc = x
        # flatten both feature maps into token sequences of shape (B, H*W, C)
        x = x.view(batch_size, chanel, -1).permute(0, 2, 1)
        sc1 = x
        x = self.lnx(x)
        y = y.view(batch_size, chanel, -1).permute(0, 2, 1)
        y = self.lny(y)

        B, N, C = x.shape
        # multi-head cross-attention: query/value from x, key from y
        y_k = self.k(y).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        x_qv = self.qv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        x_q, x_v = x_qv[0], x_qv[1]
        y_k = y_k[0]
        attn = (x_q @ y_k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ x_v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = (x + sc1)                                    # residual connection in token space
        x = x.permute(0, 2, 1)
        x = x.view(batch_size, chanel, *sc.size()[2:])   # back to (B, C, H, W)
        x = self.conv2(x) + x                            # convolutional refinement
        # second output: symmetrized attention matrix passed through conv-BN-ReLU
        return x, self.act(self.bn(self.conv(attn + attn.transpose(-1, -2))))


if __name__ == '__main__':
    x = torch.randn([1, 64, 11, 11])
    y = torch.randn([1, 64, 11, 11])
    cmgm = CMGM(dim=64)
    out1, out2 = cmgm(x, y)
    print(out1.shape)  # fused feature: (1, 64, 11, 11)
    print(out2.shape)  # cross-attention matrix: (1, 8, 121, 121)
```
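For reference, the cross-attention that `forward` computes can be summarized as below; the symbols ($Q_x$, $V_x$ from the first input `x`, $K_y$ from the second input `y`, per-head dimension $d$) are my shorthand for the code above, not notation taken from the paper.

$$
\mathrm{Attn} = \mathrm{softmax}\!\left(\frac{Q_x K_y^{\top}}{\sqrt{d}}\right), \qquad Z = \mathrm{Attn}\, V_x
$$

$Z$ is then projected, added to the flattened `x` as a residual, reshaped back to $(B, C, H, W)$ and refined by `conv2`, while the symmetrized map $\mathrm{Attn} + \mathrm{Attn}^{\top}$ is returned as the second output.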
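As a usage note, the two inputs must share the same spatial size, but in a real network they typically come from different backbones at different resolutions. Below is a minimal sketch of aligning them before grafting; it assumes the `CMGM` class from the block above, and the `cnn_feat`/`trans_feat` tensors, their shapes, and the bilinear-interpolation step are my illustrative assumptions rather than the paper's exact pipeline (check the official repo for the actual wiring of which branch feeds which argument).

```python
import torch
import torch.nn.functional as F

# hypothetical feature maps from two branches (names and shapes are illustrative only)
cnn_feat = torch.randn(1, 64, 22, 22)    # CNN branch, higher resolution
trans_feat = torch.randn(1, 64, 11, 11)  # Transformer branch, lower resolution

# resize the Transformer feature to match the CNN feature before grafting
trans_feat = F.interpolate(trans_feat, size=cnn_feat.shape[2:],
                           mode='bilinear', align_corners=False)

cmgm = CMGM(dim=64)                      # CMGM as defined above
fused, cross_attn = cmgm(cnn_feat, trans_feat)
print(fused.shape)       # torch.Size([1, 64, 22, 22])
print(cross_attn.shape)  # torch.Size([1, 8, 484, 484])
```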
