经典的分割模型仍然是 UNet。也可以直接使用现成的 UNet 实现来训练，但为了更好地理解其结构，这里选择自己搭建。
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Up(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU, then upsample the result by a factor of 2.

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by the convolution.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(out_channel)
        act = nn.ReLU()
        # Layers stay inside a Sequential called ``block`` so the state_dict
        # keys (block.0.*, block.1.*) remain the same.
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        # The conv changes the channel count; interpolate (default mode
        # 'nearest') then doubles height and width.
        return F.interpolate(self.block(x), scale_factor=2)
class Down(nn.Module):
    """Conv3x3 with configurable stride -> BatchNorm -> ReLU.

    With kernel 3 and padding 1, a stride of 2 (the default) yields an output of
    ceil(H/2) x ceil(W/2); stride=1 keeps the spatial size unchanged.

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by the convolution.
        stride: convolution stride.
    """

    def __init__(self, in_channel, out_channel, stride=2):
        super().__init__()
        layers = [
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
class UpConcat(nn.Module):
    """Decoder stage: upsample the deeper map, concatenate it with the skip map, fuse with two convs.

    Args:
        in_channel: number of channels of the deeper, lower-resolution map
            ``in_map2``.
        out_channel: number of channels of the skip-connection map ``in_map1``;
            also the number of output channels.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel + out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        """Fuse the skip map ``in_map1`` with the upsampled deeper map ``in_map2``.

        Args:
            in_map1: skip-connection tensor with ``out_channel`` channels.
            in_map2: deeper tensor with ``in_channel`` channels, roughly half the
                spatial size of ``in_map1``.

        Returns:
            Tensor with ``out_channel`` channels and the spatial size of ``in_map1``.
        """
        up = self.up(in_map2)
        if up.shape[-2:] != in_map1.shape[-2:]:
            # When the skip map has an odd height/width, a stride-2 stage produced
            # ceil(H/2), and a plain 2x upsample overshoots by one pixel. Resample
            # straight to the skip map's size so torch.cat does not fail.
            up = F.interpolate(in_map2, size=in_map1.shape[-2:], mode="nearest")
        out = torch.cat([in_map1, up], dim=1)
        return self.conv2(out)
class MainNet(nn.Module):
    """U-Net style encoder/decoder with a per-pixel sigmoid head.

    The encoder is five ``Down`` stages (the first with stride 1, the rest
    stride 2). The decoder is four ``UpConcat`` stages that merge each upsampled
    map with the matching encoder skip map. A 1x1 conv maps the 64 decoder
    channels to ``num_classes``.

    NOTE: each stride-2 stage yields ceil(H/2) x ceil(W/2). Input height and
    width divisible by 16 keep every skip map exactly twice the size of the
    deeper map it is concatenated with.

    Args:
        num_classes: number of output channels.
        in_channels: number of channels of the input image (default 3).
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        # Encoder
        self.down1 = Down(in_channels, 64, stride=1)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)
        # Decoder: UpConcat(deeper channels, skip/output channels)
        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)
        # 1x1 projection to class maps; Sigmoid squashes every output element
        # independently into (0, 1).
        self.head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Shape comments are (C, H, W) for an example 512x512 input.
        feat1 = self.down1(x)      # in_channels, 512, 512 -> 64, 512, 512
        feat2 = self.down2(feat1)  # 64, 512, 512 -> 128, 256, 256
        feat3 = self.down3(feat2)  # 128, 256, 256 -> 256, 128, 128
        feat4 = self.down4(feat3)  # 256, 128, 128 -> 512, 64, 64
        feat5 = self.down5(feat4)  # 512, 64, 64 -> 1024, 32, 32

        feat4_up = self.up5concat(feat4, feat5)     # -> 512, 64, 64
        feat3_up = self.up4concat(feat3, feat4_up)  # -> 256, 128, 128
        feat2_up = self.up3concat(feat2, feat3_up)  # -> 128, 256, 256
        feat1_up = self.up2concat(feat1, feat2_up)  # -> 64, 512, 512
        return self.head(feat1_up)                  # -> num_classes, 512, 512
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512), device=device)
    model = MainNet(num_classes=3).to(device)

    # Smoke test: a single forward pass to check shapes. Eval mode keeps the
    # dummy input from updating BatchNorm running stats, and no_grad skips the
    # autograd graph (large at 512x512).
    model.eval()
    with torch.no_grad():
        out = model(tensor)
    print(out.shape)

    # Layer-by-layer summary; torchsummary is an optional third-party package.
    try:
        from torchsummary import torchsummary
    except ImportError:
        print("torchsummary not installed; skipping model summary")
    else:
        torchsummary.summary(model, (3, 512, 512))

    # FLOPs/params can also be measured with the third-party `thop` package:
    #     from thop import profile
    #     flops, params = profile(model, inputs=(tensor,))
    # Previously recorded: FLOPs= 63.406604288G, params= 14.127683M.
    # NOTE(review): these figures look like they come from an earlier variant
    # (the current layers count roughly 18.8M params) -- re-measure.