比如改一个多输入模块,我们需要记录输入1的通道,输入2的通道,Conv_reduce的输入通道
YOLO中这个模块接受层1和层2的作为输入,那么层1和层2的输出通道肯定是知道的,所以现在只需要在yaml里面标记整个模块的输出通道即可。
python
class AF(nn.Module):
def __init__(self,c1,c2,dim1,dim2):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_atten = nn.Sequential(
nn.Conv2d(c1, c1,1),
nn.Sigmoid()
)
self.conv_redu = nn.Conv2d(c1, c2, kernel_size=1, bias=False)
self.conv1 = nn.Conv2d(dim1, 1, 1, 1)
self.conv2 = nn.Conv2d(dim2, 1, 1, 1)
self.nonlin = nn.Sigmoid()
def forward(self, x):
output = torch.cat(x,1)
att = self.conv_atten(self.avg_pool(output))
#print(att.shape)
output = output * att
output = self.conv_redu(output)
#print(output.shape)
att = self.conv1(x[0]) + self.conv2(x[1])
att = self.nonlin(att)
#print(att.shape)
output = output * att
return output
html
- [[-1, 6], 1, AF, [32]] # cat backbone P4
例如这条yaml,接受第6层和上一层的输入,输出通道数为32。这里参数为什么是一个?因为这里只需要给出输出通道数即可,其他参数可以再网络的记录中得到。
python
elif m is AF:
c1 = sum(ch[x] for x in f)
c3 = ch[f[0]]
c4 = ch[f[1]]
c2 = args[0]
args = [c1,c2,c3,c4]
print(args)
f是一个表表示来自那一层,这里的f里面就保存的内容相当于【-1,6】的索引,ch是每一层的输出通道数,ch[层索引]不就得到某层的输出通道了。这里随便借助一个中间变量,c1,c2,c3,c4,记录参数后,合成列表【c1,c2,c3,c4】
python
torch.nn.Sequential(*(m(*args))
m相当于类名称,加入类名为AF,不就相当于AF(c1,c2,c3,c4)吗