系统环境
- Python 3.6
- pytorch 1.8
- GPU 1050Ti
网络结构
本文给出 U-Net 的一个简单实现:网络首先经过两层卷积层,再经过四次下采样,最后经过四次上采样。代码结构如下
python
# unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv => BatchNorm => ReLU) stages.

    ``padding=1`` with a 3x3 kernel keeps the spatial size unchanged, so
    the block only changes the channel count.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            # First stage: change channel count from in_channels to out_channels.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),  # stabilizes/accelerates training
            nn.ReLU(inplace=True),
            # Second stage: refine features at the same channel count.
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name kept as 'double_conv' for checkpoint compatibility.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        # Run both conv stages in sequence.
        return self.double_conv(x)
class Down(nn.Module):
    """Downscaling stage: 2x2 max-pool halves H and W, then a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Pool first (picks the strongest activation in each 2x2 window),
        # then convolve. Attribute name kept for checkpoint compatibility.
        ops = (nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
        self.maxpool_conv = nn.Sequential(*ops)

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling then double conv.

    Args:
        in_channels: channels of the deeper feature map (x1) before upsampling.
        out_channels: channels produced by the fused double convolution.
        bilinear: if True, upsample with parameter-free bilinear interpolation
            (each new pixel is a weighted blend of its four neighbors);
            otherwise use a learned 2x2 stride-2 transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # Bilinear upsampling keeps the channel count, so after
            # concatenation with the skip connection (which carries
            # in_channels // 2 channels in this architecture) the conv must
            # accept in_channels + in_channels // 2 channels. The original
            # DoubleConv(in_channels, out_channels) here crashed with a
            # channel mismatch whenever bilinear=True was used.
            self.conv = DoubleConv(in_channels + in_channels // 2, out_channels)
        else:
            # Transposed conv learns the upsampling; stride=2 doubles H and W
            # while reducing channels to out_channels (= in_channels // 2 in
            # this architecture), so the concatenation restores in_channels.
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """x1: deeper feature map to upsample; x2: encoder skip connection."""
        x1 = self.up(x1)
        # Input is NCHW: pad x1 so its H/W match x2 (they can differ by one
        # pixel when an odd-sized map was pooled). Plain ints are what F.pad
        # expects; the original wrapped these in torch.tensor needlessly.
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # Concatenate skip and upsampled features along channels, then fuse.
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 conv is a per-pixel linear projection across channels.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    """U-Net: encoder-decoder segmentation network with skip connections.

    The encoder extracts progressively higher-level features while halving
    resolution; the decoder upsamples and fuses each level with the matching
    encoder feature map, producing per-pixel logits at input resolution.

    Args:
        n_channels: number of input image channels (e.g. 1 for grayscale).
        n_classes: number of output channels (one logit map per class).
        bilinear: forwarded to every Up stage (interpolation vs. learned).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Encoder: initial double conv, then four downscaling stages.
        # Attribute names are fixed by saved checkpoints - do not rename.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        # Decoder: four upscaling stages, then the 1x1 output projection.
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder path, keeping each resolution for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)
        # Decoder path, fusing with the matching encoder feature map.
        out = self.up1(bottom, skip4)
        out = self.up2(out, skip3)
        out = self.up3(out, skip2)
        out = self.up4(out, skip1)
        # Raw logits; apply sigmoid/softmax outside the network.
        return self.outc(out)
具体的网络结构如下:
整体上,U-Net 架构通过合并下采样路径和上采样路径的特征图,实现了对图像的端到端分割。下采样路径用于提取图像的高级特征,而上采样路径用于还原分辨率和细节。
txt
网络的具体结构如下(图片展示可能存在错误)
UNet(
(inc): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(down1): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(down2): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(down3): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(down4): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(up1): Up(
(up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(up2): Up(
(up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(up3): Up(
(up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(up4): Up(
(up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(outc): OutConv(
(conv): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
数据准备 Dataloader
用于加载图像数据,并在训练过程中进行数据增强,这样的数据集类可以被用于 PyTorch 的数据加载器(DataLoader
)中,用于批量加载和训练深度学习模型。在训练过程中,每个样本都经过图像增强,以增加模型对不同变换的鲁棒性。
Python
# dataloader.py
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random
class SelfDataSet(Dataset):
    """Image/label dataset with random flip augmentation.

    Collects every ``*.jpg`` under ``data_path``; each image's label is
    assumed to live at the same path with ``image`` replaced by ``label``
    (directory-layout assumption - TODO confirm against the dataset).
    Returns (image, label) as float-compatible arrays of shape (1, H, W),
    with labels binarized to {0, 1} when stored as 0/255 masks.
    """

    def __init__(self, data_path):
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, '*.jpg'))

    def augment(self, image, flipcode):
        """Flip a 2-D image: 1 = horizontal, 0 = vertical, -1 = both axes."""
        flip = cv2.flip(image, flipcode)
        return flip

    def __getitem__(self, index):
        # Read the image and its matching label mask.
        image_path = self.imgs_path[index]
        label_path = image_path.replace('image', 'label')
        image = cv2.imread(image_path)  # BGR, 3 channels
        label = cv2.imread(label_path)
        # Collapse both to single-channel grayscale.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        # Augment while the arrays are still 2-D (H, W). The original code
        # reshaped to (1, H, W) first, which cv2.flip misreads as a 1-row
        # image with W channels (vertical flip became a no-op), so flipping
        # must happen BEFORE the channel axis is added.
        flipcode = random.choice([-1, 0, 1, 2])  # 2 => keep unflipped
        if flipcode != 2:
            image = self.augment(image, flipcode)
            label = self.augment(label, flipcode)
        # Add the channel axis expected by the network: (1, H, W).
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # Binarize the label to {0, 1} when stored as a 0/255 mask.
        if label.max() > 1:
            label = label / 255
        return image, label

    def __len__(self):
        # Number of images discovered at construction time.
        return len(self.imgs_path)
训练网络
首先检查是否有可用的 GPU,并将模型移到 GPU 上。然后创建 U-Net 模型实例,指定数据集路径,调用 Train_Unet
函数进行训练,最后关闭日志文件并绘制训练过程中的损失曲线。
Python
import torch
import torch.optim
from dataloader import SelfDataSet
from log import Logger
from plot import plot_picture
import os
import torch.nn as nn
import sys
from unet import UNet
from torch.utils.data import Dataset
from torch import optim, utils
import time
Unet_train_txt = Logger('Unet_train.txt')
def Train_Unet(net, device, data_path, batch_size=3, epochs=40, lr=0.0001):
    """Train ``net`` on SelfDataSet images and checkpoint the best epoch.

    Args:
        net: U-Net model, already moved to ``device``.
        device: torch device to train on.
        data_path: image directory passed to SelfDataSet.
        batch_size: samples per batch.
        epochs: number of passes over the dataset.
        lr: Adam learning rate.

    Side effects: writes per-epoch average loss to the module-level
    ``Unet_train_txt`` logger and saves the best checkpoint to 'model_pth'.
    """
    # Load the training data.
    train_dataset = SelfDataSet(data_path)
    train_loader = utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # The original called optim.Adam((net.parameters())) without passing lr,
    # silently training at Adam's default 0.001 instead of the requested rate.
    opt = optim.Adam(net.parameters(), lr=lr)
    # Binary segmentation loss on raw logits (the network has no sigmoid).
    loss_fun = nn.BCEWithLogitsLoss()
    bes_los = float('inf')
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        i = 0
        # time.clock() was deprecated and removed in Python 3.8;
        # perf_counter is the portable monotonic replacement.
        begin = time.perf_counter()
        for image, label in train_loader:
            opt.zero_grad()
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = net(image)
            loss = loss_fun(pred, label)
            loss.backward()
            opt.step()
            i = i + 1
            running_loss = running_loss + loss.item()
        end = time.perf_counter()
        loss_avg_epoch = running_loss / i
        Unet_train_txt.write(str(format(loss_avg_epoch, '.4f')) + '\n')
        print('epoch: %d avg loss: %f time:%d s' % (epoch, loss_avg_epoch, end - begin))
        # Keep only the checkpoint with the best average training loss.
        if loss_avg_epoch < bes_los:
            bes_los = loss_avg_epoch
            state = {'net': net.state_dict(), 'opt': opt.state_dict(), 'epoch': epoch}
            torch.save(state, 'model_pth')
if __name__ == '__main__':
    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Single-channel input, single logit map out; learned (transposed-conv)
    # upsampling rather than bilinear interpolation.
    net = UNet(1, 1, bilinear=False)
    net.to(device=device)
    # Root directory of the training images.
    data_path = './rsdd/train/image'
    Train_Unet(net, device, data_path, epochs=40, batch_size=1)
    Unet_train_txt.close()
    # Plot the per-epoch losses recorded in the log file.
    plot_picture('Unet_train.txt')
训练的损失函数
测试过程
Python
# test.py
import glob
import numpy as np
import torch
import os
import cv2
from unet import UNet
if __name__ == "__main__":
# 选择设备,有cuda用cuda,没有就用cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载网络,图片单通道,分类为1。
net = UNet(n_channels=1, n_classes=1,bilinear=False)
# 将网络拷贝到deivce中
net.to(device=device)
# 加载模型参数
checkpoint = torch.load('model_pth',map_location=device)
net.load_state_dict(checkpoint['net'])
# 测试模式
net.eval()
# 读取所有图片路径
Test_Data_path = './rsdd/test'
tests_path = glob.glob(os.path.join(Test_Data_path, '*.jpg'))
# 遍历所有图片
for test_path in tests_path:
print(test_path.split('.')[1])
# 保存结果地址
file_name, file_extension = os.path.splitext(test_path)
save_res_path = f"{file_name}_res.jpg"
#save_res_path = test_path.split('.')[0] + '_res.jpg'
# 读取图片
img = cv2.imread(test_path)
# 转为灰度图单通道
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# pytorch要求的格式(batch_size,c,w,h)
img = img.reshape(1, 1, img.shape[0], img.shape[1])
# 转为tensor
img_tensor = torch.from_numpy(img)
img_tensor = img_tensor.to(device=device, dtype=torch.float32)
# 预测
pred = net(img_tensor)
# 提取结果
pred = np.array(pred.data.cpu()[0])[0]
# 处理结果
pred[pred >= 0.5] = 255
pred[pred < 0.5] = 0
# 保存图片
cv2.imwrite(save_res_path, pred)
其他工具函数
log.py 把终端输出的损失函数值记录下来,保存到本地
Python
# log.py
import sys
class Logger():
    """File-backed logger mirroring the stdout ``write``/``flush`` API.

    Messages are written to ``filename``. Because it exposes write/flush,
    an instance can also be assigned to ``sys.stdout`` to redirect prints.
    """

    def __init__(self, filename='log.txt'):
        # Keep a handle on the real stdout in case echoing is re-enabled.
        self.terminal = sys.stdout
        self.log = open(filename, 'w')

    def write(self, message):
        # Uncomment to also echo to the terminal:
        # self.terminal.write(message)
        # Redirect into the log file only.
        self.log.write(message)

    def flush(self):
        # The original no-op flush could lose buffered text on a crash;
        # push pending writes through to the file instead.
        self.log.flush()

    def close(self):
        self.log.close()
plot.py 绘图函数,将损失函数的值绘制成图像
Python
# plot.py
import matplotlib.pyplot as plt
# writer = SummaryWriter(comment='_Unet')
# for i in range(10):
# writer.add_scalar('var', i**2, global_step=i)
#
# writer.close()
def plot_picture(filename):
    """Read one loss value per line from ``filename`` and plot the curve."""
    with open(filename, 'r') as f:
        losses = [float(line.strip()) for line in f.readlines()]
    epochs = range(len(losses))
    # Red line with circular markers, one point per epoch.
    plt.plot(epochs, losses, label='train loss', linewidth=2, color='r',
             marker='o', markerfacecolor='r', markersize=5)
    plt.xlabel('epoch')
    plt.ylabel('loss value')
    plt.legend()
    plt.show()
结果展示
图1 原图
图2 训练label
图3 测试结果
超参值
- batch_size=1
- epochs=40
- lr=0.0001