Pytorch语义分割(2)--------模型搭建

经典的模型还是Unet,也可以使用torch自带的unet来训练,但为了更好地了解,还是选择自己搭建。

unet.py:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Up, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.block(x)
        out = F.interpolate(x, scale_factor=2)
        return out


class Down(nn.Module):
    def __init__(self, in_channel, out_channel, stride=2):
        super(Down, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)


class UpConcat(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UpConcat, self).__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel+out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        in_map2 = self.up(in_map2)
        out = torch.cat([in_map1, in_map2], dim=1)
        return self.conv2(out)


class MainNet(nn.Module):
    def __init__(self, num_classes):
        super(MainNet, self).__init__()
        self.down1 = Down(3, 64, stride=1)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)

        # self.conv = nn.Conv2d(1024, 512, 3, 1, 1)

        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)

        self.head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feat1 = self.down1(x)       # 3, 512, 512 ---->64, 512, 512
        feat2 = self.down2(feat1)   # 64, 512, 512 ---->128, 256, 256
        feat3 = self.down3(feat2)   # 128, 256, 256 ---->256,128,128
        feat4 = self.down4(feat3)   # 256,128,128 ---> 512,64,64
        feat5 = self.down5(feat4)   # 512,64,64 ----> 1024,32,32
        print("feat5:", feat5.shape)
        # feat5 = self.conv(feat5)

        feat4_up = self.up5concat(feat4, feat5)
        print("feat4_up:", feat4_up.shape)
        feat3_up = self.up4concat(feat3, feat4_up)
        feat2_up = self.up3concat(feat2, feat3_up)
        feat1_up = self.up2concat(feat1, feat2_up)
        print("feat1_up:", feat1_up.shape)

        print(feat1_up.shape, feat2_up.shape, feat3_up.shape, feat4_up.shape)
        return self.head(feat1_up)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512)).to(device)
    model = MainNet(num_classes=3).to(device)

    # print(model)
    # model.apply(inplace_relu)

    out = model(tensor)
    # print(out.shape)
    #
    from torchsummary import torchsummary
    torchsummary.summary(model, (3, 512, 512))
    # # from torchstat import stat
    # # stat(model, (3, 512, 512))
    # from thop import profile
    #
    # flops, params = profile(model, inputs=(tensor,))
    #
    # print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
    # print("params=", str(params / 1e6) + '{}'.format("M"))
    #
    # #FLOPs= 63.406604288G
    # # params= 14.127683M
相关推荐
qq_5290252916 分钟前
Torch.gather
python·深度学习·机器学习
IT古董1 小时前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
凯哥是个大帅比1 小时前
人工智能ACA(五)--深度学习基础
人工智能·深度学习
m0_748232921 小时前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
szxinmai主板定制专家1 小时前
【国产NI替代】基于FPGA的32通道(24bits)高精度终端采集核心板卡
大数据·人工智能·fpga开发
海棠AI实验室1 小时前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
机器懒得学习1 小时前
基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成
人工智能·yolo·目标检测
QQ同步助手2 小时前
如何正确使用人工智能:开启智慧学习与创新之旅
人工智能·学习·百度
AIGC大时代2 小时前
如何使用ChatGPT辅助文献综述,以及如何进行优化?一篇说清楚
人工智能·深度学习·chatgpt·prompt·aigc