##### 模块出处
"""Strip Convolution Block (SCB) from CoANet.

Source: [TIP 21] CoANet: Connectivity Attention Network for Road Extraction
From Satellite Imagery.
https://ieeexplore.ieee.org/abstract/document/9563125

Module name: Strip Convolution Block (SCB).

Purpose: multi-directional strip feature extraction.

Structure: a 1x1 convolution reduces the input to ``in_channels // 4``
channels; four parallel strip-convolution branches (horizontal 1x9,
vertical 9x1, and two oblique branches computed on sheared copies of the
feature map) each produce ``in_channels // 8`` channels; the branches are
concatenated and fused by a 1x1 convolution into ``n_filters`` channels.

Features:
* PSP-like design: four parallel branches extract information along
  different dimensions.
* Compared with classic horizontal/vertical strip convolutions, two
  oblique-direction strip convolutions are added to better learn slanted
  lines.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class SCB(nn.Module):
    """Strip Convolution Block: four directional strip convs fused by a 1x1 conv.

    Args:
        in_channels: Channels of the input feature map. Must be at least 8 so
            that each of the four strip branches gets ``in_channels // 8 >= 1``
            channels.
        n_filters: Channels of the output feature map.

    Maps ``(B, in_channels, H, W)`` to ``(B, n_filters, H, W)``; ``H`` and
    ``W`` may differ.

    Raises:
        ValueError: If ``in_channels < 8``.
    """

    def __init__(self, in_channels, n_filters):
        super().__init__()
        if in_channels < 8:
            raise ValueError(
                f"SCB needs in_channels >= 8 (got {in_channels}): "
                "each strip branch outputs in_channels // 8 channels"
            )
        mid_channels = in_channels // 4     # width after the 1x1 reduction
        branch_channels = in_channels // 8  # output width of each strip branch
        # Width of the torch.cat in forward(). Deriving it from the four
        # branches keeps bn2/conv3 consistent for any in_channels; when
        # in_channels % 8 == 0 it equals 2 * (in_channels // 4).
        cat_channels = 4 * branch_channels

        # Attribute names define the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()

        # 1x9 (horizontal) and 9x1 (vertical) strips; padding keeps H and W.
        self.deconv1 = nn.Conv2d(
            mid_channels, branch_channels, (1, 9), padding=(0, 4)
        )
        self.deconv2 = nn.Conv2d(
            mid_channels, branch_channels, (9, 1), padding=(4, 0)
        )
        # Oblique strips: applied to the maps produced by h_transform and
        # v_transform respectively.
        self.deconv3 = nn.Conv2d(
            mid_channels, branch_channels, (9, 1), padding=(4, 0)
        )
        self.deconv4 = nn.Conv2d(
            mid_channels, branch_channels, (1, 9), padding=(0, 4)
        )

        self.bn2 = nn.BatchNorm2d(cat_channels)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(cat_channels, n_filters, 1)
        self.bn3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        """Run the block on ``x`` of shape ``(B, in_channels, H, W)``."""
        x = self.relu1(self.bn1(self.conv1(x)))

        x1 = self.deconv1(x)  # horizontal strip
        x2 = self.deconv2(x)  # vertical strip
        # Oblique strips: shear, run a straight strip conv, then unshear.
        # NOTE(review): with these shears both deconv3 (vertical kernel on
        # the row-sheared map) and deconv4 (horizontal kernel on the
        # column-sheared map) aggregate input pixels along lines
        # row + col = const, i.e. the same anti-diagonal direction; the main
        # diagonal is not covered. Left unchanged so outputs for existing
        # weights do not change -- confirm whether that was intended.
        x3 = self.inv_h_transform(self.deconv3(self.h_transform(x)))
        x4 = self.inv_v_transform(self.deconv4(self.v_transform(x)))

        x = torch.cat((x1, x2, x3, x4), 1)
        x = self.relu2(self.bn2(x))
        return self.relu3(self.bn3(self.conv3(x)))

    def h_transform(self, x):
        """Shear rows: row ``i`` is shifted right by ``i``, zero-filled.

        ``(B, C, H, W) -> (B, C, H, W + H - 1)``; for square inputs this is
        ``(B, C, H, 2W - 1)``.
        """
        b, c, h, w = x.shape
        # Pad each row with H zeros, flatten, and drop the final H zeros;
        # re-splitting into rows of length W + H - 1 then moves row i right
        # by exactly i. Padding by H (rather than W) keeps the element count
        # equal to H * (W + H - 1), so non-square inputs work.
        x = F.pad(x, (0, h))
        x = x.reshape(b, c, -1)[..., :-h]
        return x.reshape(b, c, h, w + h - 1)

    def inv_h_transform(self, x):
        """Undo ``h_transform``: ``(B, C, H, W + H - 1) -> (B, C, H, W)``."""
        b, c, h, sheared_w = x.shape
        w = sheared_w - h + 1
        # Append H zeros to the flattened map and re-split into rows of
        # length W + H: each row's data then starts at column 0 again.
        x = F.pad(x.reshape(b, c, -1), (0, h))
        x = x.reshape(b, c, h, sheared_w + 1)
        return x[..., :w]

    def v_transform(self, x):
        """Shear columns: column ``j`` is shifted down by ``j``, zero-filled.

        ``(B, C, H, W) -> (B, C, H + W - 1, W)``; the transpose of the row
        shear in ``h_transform``.
        """
        return self.h_transform(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)

    def inv_v_transform(self, x):
        """Undo ``v_transform``: ``(B, C, H + W - 1, W) -> (B, C, H, W)``."""
        return self.inv_h_transform(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)


if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])
    scb = SCB(in_channels=64, n_filters=64)
    out = scb(x)
    print(out.shape)  # [1, 64, 44, 44]
    # Non-square inputs are supported as well.
    print(scb(torch.randn([1, 64, 32, 48])).shape)  # [1, 64, 32, 48]