基于U-Net的轨道缺陷识别(分割)(Pytorch 简单实现)

系统环境

  • Python 3.6
  • pytorch 1.8
  • GPU 1050Ti

网络结构

本文是一个U-net的简单实现过程,首先是两层卷积层,再经过四次下采样,最后经过四次上采样。 代码结构如下

python 复制代码
# unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    # 两层卷积层
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            # 受一系列的神经网络模块或层作为参数,然后按照它们在参数中的顺序依次应用
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # 定义二维卷积  通道数量  卷积核的大小  padding 0填充数量 保证图片大小一致
            nn.BatchNorm2d(out_channels),  # 批归一化 加速训练过程并提高模型的稳定性
            nn.ReLU(inplace=True),  # 激活函数 ReLu
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),        # 二维最大池化层,将输入特征图的大小减半,采样出最显著的特征
            DoubleConv(in_channels, out_channels)  # 上面的卷积
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):          # bilinear:插值使用双线性插值方法,考虑了相邻四个像素的权重,以生成新像素值。
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)  # 是通过学习的方式进行上采样,它通过卷积操作实现。 stride=2是保证输出特征图的尺寸是输入的两倍。

        self.conv = DoubleConv(in_channels, out_channels)   # 卷积 双层

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        # 计算 x2 和 x1 在高度和宽度上的尺寸差异
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])  # high
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])  # width

        # x1 的高度和宽度两侧进行零填充,以使其尺寸与 x2 一致
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # 将 x2 和经过上采样后的 x1 沿着通道维度拼接在一起
        x = torch.cat([x2, x1], dim=1)

        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

具体的网络结构如下:

整体上,U-Net 架构通过合并下采样路径和上采样路径的特征图,实现了对图像的端到端分割。下采样路径用于提取图像的高级特征,而上采样路径用于还原分辨率和细节。

txt 复制代码
网络的具体结构,图片展示的可能存在错误
UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down2): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down3): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down4): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (up1): Up(
    (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up2): Up(
    (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up3): Up(
    (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up4): Up(
    (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (outc): OutConv(
    (conv): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
  )
)

数据准备 Dataloader

用于加载图像数据,并在训练过程中进行数据增强,这样的数据集类可以被用于 PyTorch 的数据加载器(DataLoader)中,用于批量加载和训练深度学习模型。在训练过程中,每个样本都经过图像增强,以增加模型对不同变换的鲁棒性。

Python 复制代码
# datalodaer.py
import torch

import cv2

import os

import glob

from torch.utils.data import Dataset

import random

class SelfDataSet(Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, '*.jpg'))

    # 将图片进行翻转
    def augment(self, image, flipcode):
        flip = cv2.flip(image, flipcode)
        return flip

    def __getitem__(self, index):
        #读取图片和标签
        image_path = self.imgs_path[index]
        label_path = image_path.replace('image', 'label')
        image = cv2.imread(image_path) #RGB 3通道图片
        label = cv2.imread(label_path)
        # 将数据转为单通道的图片
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        #label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        #对图片进行预处理preprocess
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])

        if label.max() > 1:
            label = label/255
        #图像增强
        flipcode = random.choice([-1, 0, 1, 2])
        if(flipcode != 2):
            image = self.augment(image, flipcode)
            label = self.augment(label, flipcode)
        return image, label

    def __len__(self):
        return len(self.imgs_path)

训练网络

首先检查是否有可用的 GPU,并将模型移到 GPU 上。然后创建 U-Net 模型实例,指定数据集路径,调用 Train_Unet 函数进行训练,最后关闭日志文件并绘制训练过程中的损失曲线。

Python 复制代码
import torch
import torch.optim
from dataloader import SelfDataSet
from log import Logger
from plot import plot_picture
import os
import torch.nn as nn
import sys
from unet import UNet
from torch.utils.data import Dataset
from torch import optim, utils
import time

Unet_train_txt = Logger('Unet_train.txt')

def Train_Unet(net,device,data_path,batch_size=3,epochs=40,lr=0.0001):
    #加载数据集
    train_dataset = SelfDataSet(data_path)
    train_loader = utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    #定义优化算法
    opt = optim.Adam((net.parameters()))
    #定义损失函数
    loss_fun = nn.BCEWithLogitsLoss()
    bes_los = float('inf')

    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        i = 0
        begin = time.clock()
        for image, label in train_loader:
            opt.zero_grad()
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred=net(image)
            loss = loss_fun(pred, label)
            loss.backward()
            i = i + 1
            running_loss = running_loss+loss.item()
            opt.step()
        end = time.clock()
        loss_avg_epoch = running_loss/i
        Unet_train_txt.write(str(format(loss_avg_epoch, '.4f')) + '\n')
        print('epoch: %d avg loss: %f time:%d s' % (epoch, loss_avg_epoch, end - begin))
        if loss_avg_epoch < bes_los:
            bes_los = loss_avg_epoch
            state = {'net': net.state_dict(), 'opt': opt.state_dict(), 'epoch': epoch}
            torch.save(state, 'model_pth')



if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(1, 1,  bilinear=False)
    # print(net)
    net.to(device=device)
    # 对应数据集的路径
    data_path = './rsdd/train/image'    
    Train_Unet(net, device, data_path, epochs=40, batch_size=1)
    Unet_train_txt.close()
    plot_picture('Unet_train.txt')

训练的损失函数

测试过程

Python 复制代码
test.py
import glob
import numpy as np
import torch
import os
import cv2
from unet import UNet

if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道,分类为1。
    net = UNet(n_channels=1, n_classes=1,bilinear=False)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 加载模型参数
    checkpoint = torch.load('model_pth',map_location=device)
    net.load_state_dict(checkpoint['net'])
    # 测试模式
    net.eval()
    # 读取所有图片路径
    Test_Data_path = './rsdd/test'
    tests_path = glob.glob(os.path.join(Test_Data_path, '*.jpg'))
    # 遍历所有图片
    for test_path in tests_path:
        print(test_path.split('.')[1])
        # 保存结果地址
        file_name, file_extension = os.path.splitext(test_path)
        save_res_path = f"{file_name}_res.jpg"
        #save_res_path = test_path.split('.')[0] + '_res.jpg'
        # 读取图片
        img = cv2.imread(test_path)
        # 转为灰度图单通道
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # pytorch要求的格式(batch_size,c,w,h)
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # 转为tensor
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # 预测
        pred = net(img_tensor)
        # 提取结果
        pred = np.array(pred.data.cpu()[0])[0]
        # 处理结果
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        # 保存图片
        cv2.imwrite(save_res_path, pred)

其他工具函数

log.py 把终端输出的损失函数值记录下来,保存到本地

Python 复制代码
# log.py
import sys

class Logger():
    def __init__(self, filename='log.txt'):
        self.terminal = sys.stdout
        self.log = open(filename, 'w')

    def write(self,message):
        #输出到STDOUT终端
        #self.terminal.write(message)
        #重定向到在指定文件
        self.log.write(message)

    def flush(self):
        pass
    def close(self):
        self.log.close()

plot.py 绘图函数,将损失函数的值绘制成图像

Python 复制代码
# plot.py
import matplotlib.pyplot as plt


# writer = SummaryWriter(comment='_Unet')
# for i in range(10):
#     writer.add_scalar('var', i**2, global_step=i)
#
# writer.close()


def plot_picture(filename):
    with open(filename, 'r') as f:
        train_loss = f.readlines()
        train_loss = list(map(lambda x: float(x.strip()), train_loss))
    x = range(len(train_loss))
    y = train_loss
    plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)
    plt.xlabel('epoch')
    plt.ylabel('loss value')
    plt.legend()
    plt.show()

结果展示

图1 原图

图2 训练label

图3 测试结果

超参值

  • batch_size=1
  • epochs=40
  • lr=0.0001
相关推荐
Kalika0-01 小时前
猴子吃桃-C语言
c语言·开发语言·数据结构·算法
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
我是哈哈hh3 小时前
专题十_穷举vs暴搜vs深搜vs回溯vs剪枝_二叉树的深度优先搜索_算法专题详细总结
服务器·数据结构·c++·算法·机器学习·深度优先·剪枝
Tisfy3 小时前
LeetCode 2187.完成旅途的最少时间:二分查找
算法·leetcode·二分查找·题解·二分
Mephisto.java3 小时前
【力扣 | SQL题 | 每日四题】力扣2082, 2084, 2072, 2112, 180
sql·算法·leetcode
robin_suli3 小时前
滑动窗口->dd爱框框
算法
丶Darling.3 小时前
LeetCode Hot100 | Day1 | 二叉树:二叉树的直径
数据结构·c++·学习·算法·leetcode·二叉树
labuladuo5204 小时前
Codeforces Round 977 (Div. 2) C2 Adjust The Presentation (Hard Version)(思维,set)
数据结构·c++·算法
jiyisuifeng19914 小时前
代码随想录训练营第54天|单调栈+双指针
数据结构·算法
꧁༺❀氯ྀൢ躅ྀൢ❀༻꧂4 小时前
实验4 循环结构
c语言·算法·基础题