Implementing U-Net from Scratch and Training It on the Carvana Dataset

1. U-Net

Semantic segmentation is an important branch of image processing and computer vision. Unlike classification, semantic segmentation has to predict the class of every pixel in an image, producing a precise segmentation. It is now widely applied in areas such as autonomous driving, automatic image matting, and medical imaging.

U-Net is arguably the most commonly used and simplest segmentation model: it is simple, efficient, easy to understand, easy to build, and can be trained on small datasets.

Paper: https://arxiv.org/abs/1505.04597

1.1 Motivation

  • U-Net was originally proposed to solve medical image segmentation problems;
  • it uses a U-shaped network structure to capture both contextual information and localization information;
  • it won several first places in the 2015 ISBI cell tracking challenge; it was initially designed for segmentation tasks at the cell level.

1.2 Network Architecture

The network first applies convolutions and pooling to the input image. In the U-Net paper there are four pooling steps: starting from a 224x224 image, this produces feature maps at four scales, 112x112, 56x56, 28x28, and 14x14. The 14x14 feature map is then upsampled (or passed through a transposed convolution) to 28x28 and concatenated channel-wise (concat) with the earlier 28x28 feature map. The concatenated features are convolved and upsampled to 56x56, concatenated with the earlier 56x56 features, convolved, and upsampled again; after four such upsampling steps the network outputs a 224x224 prediction with the same spatial size as the input image.
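
To make the size arithmetic concrete, here is a minimal sketch (illustrative only, not part of the final model) showing how 2x2 max pooling halves the spatial size four times and how a stride-2 transposed convolution doubles it back:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 224, 224)                          # a feature map at input resolution
pool = nn.MaxPool2d(kernel_size=2, stride=2)              # halves height and width
up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # doubles height and width

for _ in range(4):
    x = pool(x)
    print(x.shape)   # 112x112, 56x56, 28x28, 14x14
for _ in range(4):
    x = up(x)
    print(x.shape)   # 28x28, 56x56, 112x112, 224x224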

Viewed as a whole, this is an Encoder-Decoder structure: the first half extracts features and the second half upsamples. Some papers call this an encoder-decoder architecture, and because the overall network is shaped like a capital letter U, it is called U-Net.

  • Encoder: the left half; each downsampling block consists of two 3x3 convolutions (ReLU) followed by a 2x2 max-pooling layer (as the code below shows);
  • Decoder: the right half; built by repeating an upsampling (transposed-convolution) layer + skip-connection concatenation (concat) + two 3x3 convolutions (ReLU) (as the code below shows);

The deeper the U-Net layer, the larger the receptive field of its feature maps: shallow convolutions focus on texture features, while deep layers capture more abstract, semantic features, so both shallow and deep features carry meaning of their own. In addition, the larger feature maps recovered by transposed convolution lack edge information: every downsampling step distills features but inevitably discards some edge detail, and what is lost cannot be recovered by upsampling alone. Concatenating the encoder features into the decoder therefore serves to restore the missing edge information.
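
The fusion itself is an ordinary channel-wise concatenation; a minimal sketch (shapes picked purely for illustration):

import torch

enc = torch.randn(1, 128, 56, 56)      # encoder feature map saved before pooling
dec = torch.randn(1, 128, 56, 56)      # decoder feature map after 2x upsampling
merged = torch.cat([dec, enc], dim=1)  # concatenate along the channel dimension
print(merged.shape)                    # torch.Size([1, 256, 56, 56])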

2. The Carvana Dataset

Carvana dataset download page:
https://www.kaggle.com/competitions/carvana-image-masking-challenge

The Carvana dataset is commonly used for image segmentation and recognition tasks. In deep learning and computer vision it is often used to train and test network models such as U-Net. The dataset contains a large number of training images and corresponding masks that indicate the position and shape of the target region in each image.

When working with the Carvana dataset, the training images and masks are usually stored in separate folders, for example "carvana/train" and "carvana/train_masks". You can then write code that reads the data and processes it with a deep learning framework such as PyTorch.
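
As a quick sanity check of that layout, here is a small sketch; it assumes the archives are already extracted and relies on the "<id>.jpg" / "<id>_mask.gif" naming convention that the dataset class in section 3.2 also uses:

import os

train_dir = "carvana/train"
mask_dir = "carvana/train_masks"

first_image = sorted(os.listdir(train_dir))[0]             # e.g. "<car_id>_01.jpg"
first_mask = first_image.replace(".jpg", "_mask.gif")      # its paired mask
print(os.path.exists(os.path.join(mask_dir, first_mask)))  # True if the layout is right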

Note that networks pre-trained on large datasets such as ImageNet often perform better when fine-tuned on Carvana. This pretrain-then-finetune strategy helps the model adapt to the new dataset and task.

In short, the Carvana dataset is an important resource in computer vision and deep learning, useful for studying and applying all kinds of image segmentation and recognition techniques. It can be obtained from data-sharing platforms such as Kaggle; when using it, make sure to comply with the applicable terms of use.

2.1 Unzipping the Dataset

import os
print(os.listdir(r"D:\data\Carvana\carvana-image-masking-challenge"))

import zipfile
import shutil

DATASET_DIR = r'D:\data\Carvana\carvana-image-masking-challenge\\'
WORKING_DIR = r'D:\pycharm\Vit\Unet\dataset\\'

def unzip(DATASET_DIR, WORKING_DIR):
    # extract only if the working directory is still (nearly) empty
    if len(os.listdir(WORKING_DIR)) <= 1:

        with zipfile.ZipFile(DATASET_DIR + 'train.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)

        with zipfile.ZipFile(DATASET_DIR + 'train_masks.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)

        # report how many images and masks were extracted
        print(
            len(os.listdir(WORKING_DIR + 'train')),
            len(os.listdir(WORKING_DIR + 'train_masks'))
        )

unzip(DATASET_DIR, WORKING_DIR)

2.2 Splitting the Dataset into Training and Validation Sets

Continuing in the same script (with WORKING_DIR and the imports from section 2.1), move everything after the first 4600 sorted images, together with the matching masks, into separate validation folders:

train_dir = WORKING_DIR + 'train/'
val_dir = WORKING_DIR + 'val/'
os.mkdir(val_dir)
# keep the first 4600 images for training; move the rest to validation
for file in sorted(os.listdir(train_dir))[4600:]:
    shutil.move(train_dir + file, val_dir)

masks_dir = WORKING_DIR + 'train_masks/'
val_masks_dir = WORKING_DIR + 'val_masks/'
os.mkdir(val_masks_dir)
# sorting keeps images and masks aligned, so the same cars land in the same split
for file in sorted(os.listdir(masks_dir))[4600:]:
    shutil.move(masks_dir + file, val_masks_dir)

# os.mkdir(WORKING_DIR + 'saved_images')
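
A quick count confirms the split: the Carvana training archive contains 5088 images, so this leaves 4600 training and 488 validation pairs (matching the 575 and 61 batches of size 8 in the training log of section 3.5).

print(len(os.listdir(train_dir)), len(os.listdir(val_dir)))        # 4600 488
print(len(os.listdir(masks_dir)), len(os.listdir(val_masks_dir)))  # 4600 488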

3. U-Net Implementation

3.1 Building the U-Net Model

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU."""

    def __init__(self, in_channel, out_channel):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),  # bias is redundant before BatchNorm
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    def __init__(self, in_channel=3, out_channel=1, features=[64, 128, 256, 512]):
        super(Unet, self).__init__()

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per scale (64 -> 128 -> 256 -> 512 channels)
        for feature in features:
            self.downs.append(DoubleConv(in_channel, feature))
            in_channel = feature

        # Decoder: a transposed convolution (kernel_size=2, stride=2 doubles the
        # spatial size exactly, so no padding) followed by a DoubleConv applied
        # to the concatenated features
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature,
                                               kernel_size=2, stride=2))
            self.ups.append(DoubleConv(in_channel=feature * 2, out_channel=feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channel, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)  # save the feature map before pooling
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # deepest skip connection first
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # 2x upsampling
            skip_connection = skip_connections[idx // 2]
            if skip_connection.shape != x.shape:  # input size not divisible by 16
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat = torch.cat([x, skip_connection], dim=1)  # channel-wise concat
            x = self.ups[idx + 1](concat)

        return self.final_conv(x)


def testUnet():
    x = torch.randn(1, 3, 320, 320)
    model = Unet(in_channel=3, out_channel=1)
    preds = model(x)
    print(x.shape)      # torch.Size([1, 3, 320, 320])
    print(preds.shape)  # torch.Size([1, 1, 320, 320])

# testUnet()
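
With a 320x320 input (divisible by 2 four times), uncommenting testUnet() should print torch.Size([1, 3, 320, 320]) for the input and torch.Size([1, 1, 320, 320]) for the prediction: the output resolution matches the input, with a single channel of per-pixel logits.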

3.2 The Carvana Dataset Class

import numpy as np
from torch.utils.data import Dataset
import os
from PIL import Image


class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.images[index])
        # each image "<id>.jpg" has a matching mask "<id>_mask.gif"
        mask_path = os.path.join(self.mask_dir,
                                 self.images[index].replace(".jpg", "_mask.gif"))

        image = np.array(Image.open(image_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0  # binarize: foreground 255 -> 1.0, background stays 0.0

        if self.transform is not None:
            augmentation = self.transform(image=image, mask=mask)
            image = augmentation["image"]
            mask = augmentation["mask"]

        return image, mask
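
A hedged usage sketch (paths as laid out in section 2.2); without a transform, the image comes back as an HxWxC uint8 array and the mask as an HxW float array in {0.0, 1.0}:

ds = CarvanaDataset(image_dir="./dataset/train", mask_dir="./dataset/train_masks")
image, mask = ds[0]
print(image.shape, mask.shape)  # e.g. (1280, 1918, 3) (1280, 1918)
print(mask.min(), mask.max())   # 0.0 1.0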

3.3 Preparing the DataLoaders

from dataset import CarvanaDataset
from torch.utils.data import DataLoader


def get_dataloader(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir,
                   train_transform, val_transform, batch_size, num_workers,
                   pin_memory=True):

    train_set = CarvanaDataset(image_dir=train_img_dir, mask_dir=train_mask_dir, transform=train_transform)
    val_set = CarvanaDataset(image_dir=val_img_dir, mask_dir=val_mask_dir, transform=val_transform)

    # shuffle the training set each epoch; keep validation order fixed
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, val_loader
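
pin_memory=True keeps fetched batches in page-locked host memory, which speeds up the subsequent host-to-GPU copies; it only matters when training on a CUDA device.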

3.4 Training

import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from model import Unet
from dataset import CarvanaDataset
from utils import get_dataloader

import numpy as np
import random

# Hyperparameters
LEARNING_RATE = 1e-8
BATCH_SIZE = 8
NUM_EPOCHS = 6
LEARNING_RATE_DECAY = 0
PIN_MEMORY = True
# LOAD_MODEL = False
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# print('Using device:', DEVICE)

TRAIN_IMG_PATH = r'./dataset/train'
TRAIN_MASK_PATH = r'./dataset/train_masks'
VAL_IMG_PATH = r'./dataset/val'
VAL_MASK_PATH = r'./dataset/val_masks'

IMAGE_HEIGHT = 320
IMG_WIDTH = 480
NUM_WORKERS = 8


train_losses = []
val_acc = []
val_dice = []

# Set random seeds for reproducibility; note the seed itself is drawn at
# random, so record it if a particular run needs to be reproduced
seed = random.randint(1, 100)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # benchmark=True would pick non-deterministic kernels


def train_fn(loader, model, loss_fn, optimizer, scaler):
    loop = tqdm(loader)
    total_loss = 0.0
    for index, (data, target) in enumerate(loop):
        data = data.to(DEVICE)
        # (N, H, W) -> (N, 1, H, W) to match the model's single-channel output
        target = target.unsqueeze(1).float().to(DEVICE)

        # mixed-precision forward pass
        with torch.cuda.amp.autocast(enabled=True):
            predict = model(data)
            loss = loss_fn(predict, target)

        # scaled backward pass to avoid float16 gradient underflow
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        loop.set_postfix(loss=loss.item())

    return total_loss / len(loader)
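
train_fn uses automatic mixed precision: autocast runs eligible operations in float16 to save memory and time, while GradScaler multiplies the loss before backward so that small gradients do not underflow in float16, then unscales them before the optimizer step.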


def check_accuracy(loader, model, device=DEVICE):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.to(device)
            y = y.unsqueeze(1).float().to(device)
            predictions = torch.sigmoid(model(x))
            predictions = (predictions > 0.5).float()  # threshold probabilities to a binary mask
            num_correct += (predictions == y).sum()
            num_pixels += torch.numel(predictions)
            # Dice = 2*|P ∩ Y| / (|P| + |Y|); the epsilon guards against empty masks
            dice_score += (2 * (predictions * y).sum()) / ((predictions + y).sum() + 1e-8)
    accuracy = round(float(num_correct / num_pixels), 4)
    dice = round(float(dice_score / len(loader)), 4)

    print(f"Got {num_correct} / {num_pixels} with acc {num_correct/num_pixels * 100:.2f}")
    print(f"Dice Score: {dice_score} / {len(loader)}")

    model.train()
    return accuracy, dice
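
Pixel accuracy is dominated by the background, so the Dice score is the more informative metric here: it measures the overlap between the predicted mask P and the ground-truth mask Y as 2*|P ∩ Y| / (|P| + |Y|), ranging from 0 (no overlap) to 1 (perfect match).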


def main():
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMG_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=35, p=1.0),
        A.VerticalFlip(p=1.0),
        # mean 0 / std 1 with max_pixel_value=255.0 simply rescales pixels to [0, 1]
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2(),
    ],)

    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMG_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2(),
    ],)

    train_loader, val_loader = get_dataloader(TRAIN_IMG_PATH, TRAIN_MASK_PATH,
                                              VAL_IMG_PATH, VAL_MASK_PATH,
                                              train_transform, val_transform,
                                              BATCH_SIZE, num_workers=NUM_WORKERS,
                                              pin_memory=PIN_MEMORY)

    model = Unet(in_channel=3, out_channel=1).to(device=DEVICE)
    # binary segmentation with a single output channel: BCE on logits
    # matches the sigmoid + 0.5 threshold used in check_accuracy
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.cuda.amp.GradScaler()

    for index in range(NUM_EPOCHS):
        print("Current Epoch: ", index)
        train_loss = train_fn(train_loader, model, loss_fn, optimizer, scaler)
        train_losses.append(train_loss)

        accuracy, dice = check_accuracy(val_loader, model, device=DEVICE)
        val_acc.append(accuracy)
        val_dice.append(dice)
        print(f"accuracy:{accuracy}" )
        print(f"dice score:{dice}" )
if __name__ == "__main__":
    main()

3.5 Training Results

Current Epoch:  0
  0%|          | 0/575 [00:00<?, ?it/s]D:\ProgramData\anaconda3\envs\v8\lib\site-packages\torchvision\transforms\functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
  warnings.warn(
100%|██████████| 575/575 [12:21<00:00,  1.29s/it, loss=-]
100%|██████████| 61/61 [01:38<00:00,  1.61s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53332194 / 74956800 with acc 71.15
Dice Score: 52.5135612487793 / 61
accuracy:0.7115
dice score:0.8609
Current Epoch:  1
100%|██████████| 575/575 [14:40<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.26s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53329337 / 74956800 with acc 71.15
Dice Score: 52.515079498291016 / 61
accuracy:0.7115
dice score:0.8609
Current Epoch:  2
100%|██████████| 575/575 [14:40<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.26s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53376561 / 74956800 with acc 71.21
Dice Score: 52.535728454589844 / 61
accuracy:0.7121
dice score:0.8612
Current Epoch:  3
100%|██████████| 575/575 [14:40<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.25s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53358680 / 74956800 with acc 71.19
Dice Score: 52.515995025634766 / 61
accuracy:0.7119
dice score:0.8609
Current Epoch:  4
100%|██████████| 575/575 [14:41<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.26s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53363412 / 74956800 with acc 71.19
Dice Score: 52.50870895385742 / 61
accuracy:0.7119
dice score:0.8608
Current Epoch:  5
100%|██████████| 575/575 [14:50<00:00,  1.55s/it, loss=-]
100%|██████████| 61/61 [01:18<00:00,  1.28s/it]
Got 53369559 / 74956800 with acc 71.20
Dice Score: 52.52872085571289 / 61
accuracy:0.712
dice score:0.8611