"""Patch-based Graph Reasoning (PGR) module.

Module source
    [NC 25] "Graph-based context learning network for infrared small target
    detection", https://doi.org/10.1016/j.neucom.2024.128949

Module structure
    Figure: https://i-blog.csdnimg.cn/direct/b5d0b1dc6fef4b78ba3b0d1872d5dbbb.png

Module features
    * Uses a graph structure to better capture the global context of features.
    * Combines the graph structure with feature patching so that global and
      local features complement each other.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class graph(nn.Module):
    """Graph reasoning applied to every patch, scaled by per-patch weights.

    Args:
        p2: Number of patches along each spatial axis (``p2 ** 2`` in total).
        nIn: Channel count of the feature map (and of every patch).
        N: Number of graph nodes each patch is projected onto.
    """

    def __init__(self, p2=4, nIn=64, N=16):
        super(graph, self).__init__()
        self.p2 = p2
        self.N = N
        # Produces the N node-assignment maps of a patch.
        self.conv30 = nn.Sequential(
            nn.Conv2d(nIn, self.N, kernel_size=3, stride=1, padding=1, groups=1),
            nn.ReLU(inplace=True),
        )
        # Mixes information along the channel axis of the node features.
        self.conv10 = nn.Sequential(
            nn.Conv1d(nIn, nIn, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )
        # Mixes information along the node axis of the node features.
        self.conv11 = nn.Sequential(
            nn.Conv1d(self.N, self.N, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )
        # One pooled cell per patch.  The grid must be (p2, p2): ADP_weight
        # reshapes the result to p2 ** 2 cells, so the former fixed (8, 8)
        # grid only worked for p2 == 8.  The attribute keeps its historical
        # name even though it is an average pool.
        self.adaptivemax = nn.AdaptiveAvgPool2d((p2, p2))
        # Bottleneck over the patch axis ending in a sigmoid gate.
        self.conv12 = nn.Sequential(
            nn.Conv1d(p2 ** 2, p2, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(p2, p2, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(p2, p2 ** 2, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def ADP_weight(self, x):
        """Return sigmoid gates of shape ``(b, p2 ** 2, C, 1, 1)`` for ``x``.

        ``x`` is the full ``(b, C, H, W)`` feature map.  Pooled cells are
        flattened row-major, so gate ``k`` belongs to the patch in grid row
        ``k // p2`` and grid column ``k % p2``.
        """
        b, C, _, _ = x.shape
        pooled = self.adaptivemax(x)                      # (b, C, p2, p2)
        pooled = pooled.reshape(b, C, self.p2 ** 2)       # (b, C, p2^2)
        pooled = torch.transpose(pooled, 1, 2)            # (b, p2^2, C)
        gates = self.conv12(pooled)                       # (b, p2^2, C)
        return gates.unsqueeze(-1).unsqueeze(-1)

    def graph_convolution(self, fs, x):
        """Project each patch onto N nodes, reason there, and project back.

        Args:
            fs: Stacked patches of shape ``(b * p2 ** 2, C, h, w)``.
            x: The full feature map; only its batch size is read.

        Returns:
            Tensor of shape ``(b, p2 ** 2, C, h, w)``.
        """
        b = x.shape[0]
        _, C, h, w = fs.shape
        L = h * w
        assign = self.conv30(fs).reshape(-1, self.N, L)            # (n, N, L)
        pixels = torch.transpose(fs.reshape(-1, C, L), 1, 2)       # (n, L, C)
        nodes = torch.bmm(assign, pixels)                          # (n, N, C)
        nodes = self.conv11(nodes)                                 # (n, N, C)
        nodes = self.conv10(torch.transpose(nodes, 1, 2))          # (n, C, N)
        # Scatter node features back to pixels with the same assignment maps.
        back = torch.bmm(torch.transpose(assign, 1, 2),
                         torch.transpose(nodes, 1, 2))             # (n, L, C)
        back = torch.transpose(back, 1, 2)                         # (n, C, L)
        return back.reshape(b, self.p2 ** 2, C, h, w)

    def forward(self, fs, x):
        """Return the graph-reasoned patches multiplied by their gates."""
        reasoned = self.graph_convolution(fs, x)
        weight = self.ADP_weight(x)
        return weight * reasoned


class PGR(nn.Module):
    """Patch-based Graph Reasoning block.

    The input is cut into a ``p2 x p2`` grid of equally sized patches, every
    patch goes through :class:`graph`, the patches are stitched back in place
    and passed through a 1x1 conv + BatchNorm + ReLU.

    Args:
        p2: Number of patches along each spatial axis.
        nIn: Channel count of the input feature map.
        nOut: Channel count produced by the final 1x1 convolution.
        add: When true, the input is added to the output (residual).

    Raises:
        ValueError: If ``add`` is true while ``nIn != nOut``.
    """

    def __init__(self, p2=4, nIn=32, nOut=32, add=True):
        super(PGR, self).__init__()
        if add and nIn != nOut:
            raise ValueError(
                "PGR with add=True requires nIn == nOut, got nIn=%d, nOut=%d"
                % (nIn, nOut)
            )
        self.p2 = p2
        self.N = nIn // 4
        self.add = add
        self.graph0 = graph(p2, nIn, self.N)
        # The stitched map still has nIn channels, so the conv must accept
        # nIn inputs (identical to the previous layer whenever nIn == nOut).
        self.conv31 = nn.Sequential(
            nn.Conv2d(nIn, nOut, kernel_size=1, stride=1),
            nn.BatchNorm2d(nOut),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _split_patches(x, p2):
        """Cut ``(b, C, H, W)`` into ``(b * p2 ** 2, C, H // p2, W // p2)``.

        Patches are ordered row-major per sample (index ``row * p2 + col``).
        Rows/columns beyond ``p2 * (H // p2)`` / ``p2 * (W // p2)`` are dropped.
        """
        b, C, H, W = x.shape
        h, w = H // p2, W // p2
        grid = x[:, :, : p2 * h, : p2 * w].reshape(b, C, p2, h, p2, w)
        return grid.permute(0, 2, 4, 1, 3, 5).reshape(b * p2 * p2, C, h, w)

    @staticmethod
    def _merge_patches(patches, p2, H, W):
        """Stitch ``(b, p2 ** 2, C, h, w)`` patches back into ``(b, C, H, W)``.

        Inverse of :meth:`_split_patches`; positions not covered by any patch
        (when ``H`` or ``W`` is not a multiple of ``p2``) are left at zero.
        """
        b, _, C, h, w = patches.shape
        grid = patches.reshape(b, p2, p2, C, h, w).permute(0, 3, 1, 4, 2, 5)
        merged = grid.reshape(b, C, p2 * h, p2 * w)
        if p2 * h == H and p2 * w == W:
            return merged
        out = patches.new_zeros((b, C, H, W))
        out[:, :, : p2 * h, : p2 * w] = merged
        return out

    def forward(self, x):
        """Apply patch-wise graph reasoning to ``x`` of shape ``(b, C, H, W)``.

        Returns a tensor with the same spatial size as ``x`` and ``nOut``
        channels.

        Raises:
            ValueError: If ``H`` or ``W`` is smaller than ``p2``.
        """
        _, _, H, W = x.shape
        if H < self.p2 or W < self.p2:
            raise ValueError(
                "input spatial size (%d, %d) is smaller than p2=%d"
                % (H, W, self.p2)
            )
        # Row-major patch indices keep every patch in its own slot and line
        # up with the pooled-cell order used for the gates in graph.ADP_weight.
        # The tensors are derived from x, so they follow its device and dtype.
        fs = self._split_patches(x, self.p2)
        fs6 = self.graph0(fs, x)
        out = self._merge_patches(fs6, self.p2, H, W)
        out = self.conv31(out)
        if self.add:
            out = out + x
        return out


if __name__ == '__main__':
    # Use the GPU when one is present, otherwise run the demo on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn([1, 64, 44, 44], device=device)
    pgr = PGR(p2=8, nIn=64, nOut=64).to(device)
    out = pgr(x)
    print(out.shape)  # [1, 64, 44, 44]