Pytorch语义分割(2)--------模型搭建

经典的模型还是Unet,也可以使用torch自带的unet来训练,但为了更好地了解,还是选择自己搭建。

unet.py:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Up, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.block(x)
        out = F.interpolate(x, scale_factor=2)
        return out


class Down(nn.Module):
    def __init__(self, in_channel, out_channel, stride=2):
        super(Down, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)


class UpConcat(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UpConcat, self).__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel+out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        in_map2 = self.up(in_map2)
        out = torch.cat([in_map1, in_map2], dim=1)
        return self.conv2(out)


class MainNet(nn.Module):
    def __init__(self, num_classes):
        super(MainNet, self).__init__()
        self.down1 = Down(3, 64, stride=1)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)

        # self.conv = nn.Conv2d(1024, 512, 3, 1, 1)

        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)

        self.head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feat1 = self.down1(x)       # 3, 512, 512 ---->64, 512, 512
        feat2 = self.down2(feat1)   # 64, 512, 512 ---->128, 256, 256
        feat3 = self.down3(feat2)   # 128, 256, 256 ---->256,128,128
        feat4 = self.down4(feat3)   # 256,128,128 ---> 512,64,64
        feat5 = self.down5(feat4)   # 512,64,64 ----> 1024,32,32
        print("feat5:", feat5.shape)
        # feat5 = self.conv(feat5)

        feat4_up = self.up5concat(feat4, feat5)
        print("feat4_up:", feat4_up.shape)
        feat3_up = self.up4concat(feat3, feat4_up)
        feat2_up = self.up3concat(feat2, feat3_up)
        feat1_up = self.up2concat(feat1, feat2_up)
        print("feat1_up:", feat1_up.shape)

        print(feat1_up.shape, feat2_up.shape, feat3_up.shape, feat4_up.shape)
        return self.head(feat1_up)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512)).to(device)
    model = MainNet(num_classes=3).to(device)

    # print(model)
    # model.apply(inplace_relu)

    out = model(tensor)
    # print(out.shape)
    #
    from torchsummary import torchsummary
    torchsummary.summary(model, (3, 512, 512))
    # # from torchstat import stat
    # # stat(model, (3, 512, 512))
    # from thop import profile
    #
    # flops, params = profile(model, inputs=(tensor,))
    #
    # print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
    # print("params=", str(params / 1e6) + '{}'.format("M"))
    #
    # #FLOPs= 63.406604288G
    # # params= 14.127683M
相关推荐
再不会python就不礼貌了21 分钟前
Ollama 0.4 发布!支持 Llama 3.2 Vision,实现多模态 RAG
人工智能·学习·机器学习·ai·开源·产品经理·llama
大大大反派24 分钟前
ONLYOFFICE 8.2深度测评:集成PDF编辑、数据可视化与AI功能的强大办公套件
人工智能·信息可视化·pdf
DK2215139 分钟前
机器学习系列-----主成分分析(PCA)
人工智能·算法·机器学习
SmallBambooCode2 小时前
【人工智能】阿里云PAI平台DSW实例一键安装Python脚本
linux·人工智能·python·阿里云·debian·脚本·模型训练
顾京2 小时前
基于扩散模型的表单插补
人工智能·深度学习·算法
NoneCoder2 小时前
AI时代IDE解析
ide·人工智能
狂奔solar2 小时前
yelp数据集上试验SVD,SVDPP,PMF,NMF 推荐算法
人工智能·机器学习·推荐算法
武子康2 小时前
大数据-216 数据挖掘 机器学习理论 - KMeans 基于轮廓系数来选择 n_clusters
大数据·人工智能·机器学习·数据挖掘·回归·scikit-learn·kmeans
liupenglove2 小时前
ElasticSearch向量检索技术方案介绍
大数据·人工智能·深度学习·elasticsearch·搜索引擎·自动驾驶
黄焖鸡能干四碗3 小时前
【系统文档】系统安全保障措施,安全运营保障,系统应急预案,系统验收相关资料(word原件)
大数据·人工智能·需求分析·软件需求·规格说明书