手动实现Unet并在Carvana数据集上进行训练

1. Unet

语义分割(Semantic Segmentation)是图像处理和机器视觉一个重要分支。与分类任务不同,语义分割需要判断图像每个像素点的类别,进行精确分割。语义分割目前在自动驾驶、自动抠图、医疗影像等领域有着比较广泛的应用。

Unet可以说是最常用、最简单的一种分割模型了,它简单、高效、易懂、容易构建、可以从小数据集中训练。

论文地址:https://arxiv.org/abs/1505.04597

1.1 提出初衷

  • Unet提出的初衷是为了解决医学图像分割的问题;
  • 一种U型的网络结构来获取上下文的信息和位置信息;
  • 在2015年的ISBI cell tracking比赛中获得了多个第一,一开始这是为了解决细胞层面的分割的任务的

1.2 网络结构

这个结构就是先对图片进行卷积和池化,在Unet论文中是池化4次,比方说一开始的图片是224x224的,那么就会变成112x112,56x56,28x28,14x14四个不同尺寸的特征。然后我们对14x14的特征图做上采样或者反卷积,得到28x28的特征图,这个28x28的特征图与之前的28x28的特征图进行通道伤的拼接concat,然后再对拼接之后的特征图做卷积和上采样,得到56x56的特征图,再与之前的56x56的特征拼接,卷积,再上采样,经过四次上采样可以得到一个与输入图像尺寸相同的224x224的预测结果。

其实整体来看,这个也是一个Encoder-Decoder的结构:前半部分就是特征提取,后半部分是上采样。在一些文献中把这种结构叫做编码器-解码器结构,由于网络的整体结构是一个大些的英文字母U,所以叫做U-net。

  • Encoder:左半部分,由两个3x3的卷积层(RELU)再加上一个2x2的maxpooling层组成一个下采样的模块(后面代码可以看出);
  • Decoder:右半部分,由一个上采样的卷积层(去卷积层)+特征拼接concat+两个3x3的卷积层(ReLU)反复构成(代码中可以看出来);

Unet网络层越深得到的特征图,有着更大的视野域,浅层卷积关注纹理特征,深层网络关注本质的那种特征,所以深层浅层特征都是有各自的意义的;另外一点是通过反卷积得到的更大的尺寸的特征图的边缘,是缺少信息的,毕竟每一次下采样提炼特征的同时,也必然会损失一些边缘特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接,来实现边缘特征补充。

2. Carvana数据集

Carvana数据集获取地址:
https://www.kaggle.com/competitions/carvana-image-masking-challenge

Carvana数据集通常用于图像分割和识别任务。在深度学习和计算机视觉领域,它常被用来训练和测试各种网络模型,如U-Net。这个数据集包含大量的训练图像和相应的掩码(mask),这些掩码用于指示图像中特定区域的位置和形状。

在使用Carvana数据集时,通常需要将训练图像和蒙版分别保存在不同的文件夹中,例如"carvana/train"和"carvana/train_masks"。然后,可以通过编写代码来读取这些数据,并利用PyTorch等深度学习框架进行处理。

值得注意的是,在大数据集(如ImageNet)上预先训练的网络,在使用Carvana数据集进行微调时,往往能够表现出更好的性能。这种预训练加微调的策略有助于模型更好地适应新的数据集和任务。

总之,Carvana数据集是计算机视觉和深度学习领域中一个非常重要的资源,它对于研究和应用各种图像分割和识别技术具有重要意义。如需获取Carvana数据集,可访问Kaggle等数据共享平台。在使用数据集时,请确保遵守相关的使用条款和规定。

2.1 数据集解压

import os
print(os.listdir(r"D:\data\Carvana\carvana-image-masking-challenge"))

import zipfile
import shutil

DATASET_DIR = r'D:\data\Carvana\carvana-image-masking-challenge\\'
WORKING_DIR = r'D:\pycharm\Vit\Unet\dataset\\'

def unzip(DATASET_DIR, WORKING_DIR):
    if len(os.listdir(WORKING_DIR)) <= 1:

        with zipfile.ZipFile(DATASET_DIR + 'train.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)

        with zipfile.ZipFile(DATASET_DIR + 'train_masks.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)

        print(
            len(os.listdir(WORKING_DIR + 'train')),
            len(os.listdir(WORKING_DIR + 'train_masks'))
        )

2.2 数据集划分为训练集和验证集

import os
print(os.listdir(r"D:\data\Carvana\carvana-image-masking-challenge"))

import zipfile
import shutil

DATASET_DIR = r'D:\data\Carvana\carvana-image-masking-challenge\\'
WORKING_DIR = r'D:\pycharm\Vit\Unet\dataset\\'

def unzip(DATASET_DIR, WORKING_DIR):
    if len(os.listdir(WORKING_DIR)) <= 1:

        with zipfile.ZipFile(DATASET_DIR + 'train.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)

        with zipfile.ZipFile(DATASET_DIR + 'train_masks.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)

        print(
            len(os.listdir(WORKING_DIR + 'train')),
            len(os.listdir(WORKING_DIR + 'train_masks'))
        )

train_dir = WORKING_DIR + 'train/'
val_dir = WORKING_DIR + 'val/'
os.mkdir(val_dir)
for file in sorted(os.listdir(train_dir))[4600:]:
    shutil.move(train_dir + file, val_dir)

masks_dir = WORKING_DIR + 'train_masks/'
val_masks_dir = WORKING_DIR + 'val_masks/'
os.mkdir(val_masks_dir)
for file in sorted(os.listdir(masks_dir))[4600:]:
    shutil.move(masks_dir + file, val_masks_dir)

# os.mkdir(WORKING_DIR + 'saved_images')

3. Unet

3.1 创建Unet网路模型

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class DoubleConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Unet(nn.Module):
    def __init__(self, in_channel=3, out_channel=1, features=[64, 128, 256, 512]):
        super(Unet, self).__init__()

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv(in_channel, feature))
            in_channel = feature

        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2,
                                               stride=2, padding=1))
            self.ups.append(DoubleConv(in_channel=feature*2, out_channel=feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channel, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]
            if skip_connection.shape != x.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            concat = torch.cat([x, skip_connection], dim=1)
            x = self.ups[idx+1](concat)

        return self.final_conv(x)



def testUnet():
    x = torch.randn(1, 3, 320, 320)
    model = Unet(in_channel=3, out_channel=1)
    preds = model(x)
    print(x.shape)
    print(preds.shape)

# testUnet()

3.2 准备Carvana数据集

import numpy as np
from torch.utils.data import Dataset
import os
from PIL import Image


class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].
                                 replace(".jpg","_mask.gif"))

        image = np.array(Image.open(image_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"))

        if self.transform is not None:
            augmentation = self.transform(image=image, mask=mask)
            image = augmentation["image"]
            mask = augmentation["mask"]

        return image, mask

3.3 准备dataload

import torch
import torchvision
from dataset import CarvanaDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


def get_dataloader(train_img_dir, train_mask_dir, val_img_dir, val_mask_dir,
                   train_transform, val_transform, batch_size, num_workers,
                   pin_memory=True):

    train_set = CarvanaDataset(image_dir=train_img_dir, mask_dir=train_mask_dir, transform=train_transform)
    val_set = CarvanaDataset(image_dir=val_img_dir, mask_dir=val_mask_dir, transform=val_transform)

    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, val_loader

3.4 训练

import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from model import Unet
from dataset import CarvanaDataset
from utils import get_dataloader

import numpy as np
import random

# Hyper Parameter
LEARNING_RATE = 1e-8
BATCH_SIZE = 8
NUM_EPOCHS = 6
LEARNING_RATE_DECAY = 0
PIN_MEMORY = True
# LOAD_MODEL = False
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# print('Using device:', DEVICE)

TRAIN_IMG_PATH = r'./dataset/train'
TRAIN_MASK_PATH = r'./dataset/train_masks'
VAL_IMG_PATH = r'./dataset/val'
VAL_MASK_PATH = r'./dataset/val_masks'

IMAGE_HEIGHT = 320
IMG_WIDTH = 480
NUM_WORKERS = 8


train_losses = []
val_acc = []
val_dice = []

# 设置随机种子
seed = random.randint(1, 100)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True


def train_fn(loader, model, loss_fn, optimizer, scaler):
    loop = tqdm(loader)
    total_loss = 0.0
    for index, (data, target) in enumerate(loop):
        data = data.to(DEVICE)
        target = target.unsqueeze(1).float().to(DEVICE)

        with torch.cuda.amp.autocast(enabled=True):
            predict = model(data)
            loss = loss_fn(predict, target)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        loop.set_postfix(loss=loss.item())

    return total_loss / len(loader)


def check_accuracy(loader, model, device='cuda'):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.to(DEVICE)
            y = y.unsqueeze(1).float().to(DEVICE)
            predictions = torch.sigmoid(model(x))
            predictions = (predictions > 0.5).float()
            num_correct += (predictions == y).sum()
            num_pixels += torch.numel(predictions)
            dice_score += (2 * (predictions * y).sum()) / (2 * (predictions * y).sum()
                                                           + ((predictions * y)<1).sum())
    accuracy = round(float(num_correct / num_pixels), 4)
    dice = round(float(dice_score / len(loader)), 4)

    print(f"Got {num_correct} / {num_pixels} with acc {num_correct/num_pixels * 100:.2f}")
    print(f"Dice Score: {dice_score} / {len(loader)}")

    model.train()
    return accuracy, dice


def main():
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMG_WIDTH),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=35, p=1.0),
        A.VerticalFlip(p=1.0),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2(),
    ],)

    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMG_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2(),
    ],)

    train_loader, val_loader = get_dataloader(TRAIN_IMG_PATH, TRAIN_MASK_PATH,
                                              VAL_IMG_PATH, VAL_MASK_PATH,
                                              train_transform, val_transform,
                                              BATCH_SIZE, num_workers=NUM_WORKERS,
                                              pin_memory=PIN_MEMORY)

    model = Unet(in_channel=3, out_channel=1).to(device=DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.cuda.amp.GradScaler()

    for index in range(NUM_EPOCHS):
        print("Current Epoch: ", index)
        train_loss = train_fn(train_loader, model, loss_fn, optimizer, scaler)
        train_losses.append(train_loss)

        accuracy, dice = check_accuracy(val_loader, model, device=DEVICE)
        val_acc.append(accuracy)
        val_dice.append(dice)
        print(f"accuracy:{accuracy}" )
        print(f"dice score:{dice}" )
if __name__ == "__main__":
    main()

3.5 训练结果

Current Epoch:  0
  0%|          | 0/575 [00:00<?, ?it/s]D:\ProgramData\anaconda3\envs\v8\lib\site-packages\torchvision\transforms\functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
  warnings.warn(
100%|██████████| 575/575 [12:21<00:00,  1.29s/it, loss=-]
100%|██████████| 61/61 [01:38<00:00,  1.61s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53332194 / 74956800 with acc 71.15
Dice Score: 52.5135612487793 / 61
accuracy:0.7115
dice score:0.8609
Current Epoch:  1
100%|██████████| 575/575 [14:40<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.26s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53329337 / 74956800 with acc 71.15
Dice Score: 52.515079498291016 / 61
accuracy:0.7115
dice score:0.8609
Current Epoch:  2
100%|██████████| 575/575 [14:40<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.26s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53376561 / 74956800 with acc 71.21
Dice Score: 52.535728454589844 / 61
accuracy:0.7121
dice score:0.8612
Current Epoch:  3
100%|██████████| 575/575 [14:40<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.25s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53358680 / 74956800 with acc 71.19
Dice Score: 52.515995025634766 / 61
accuracy:0.7119
dice score:0.8609
Current Epoch:  4
100%|██████████| 575/575 [14:41<00:00,  1.53s/it, loss=-]
100%|██████████| 61/61 [01:16<00:00,  1.26s/it]
  0%|          | 0/575 [00:00<?, ?it/s]Got 53363412 / 74956800 with acc 71.19
Dice Score: 52.50870895385742 / 61
accuracy:0.7119
dice score:0.8608
Current Epoch:  5
100%|██████████| 575/575 [14:50<00:00,  1.55s/it, loss=-]
100%|██████████| 61/61 [01:18<00:00,  1.28s/it]
Got 53369559 / 74956800 with acc 71.20
Dice Score: 52.52872085571289 / 61
accuracy:0.712
dice score:0.8611
相关推荐
中杯可乐多加冰39 分钟前
【玩转OCR | 腾讯云智能结构化OCR应用探索和场景实践】
人工智能·深度学习·信息可视化·云计算·ocr·腾讯云·玩转腾讯云ocr
ROBOT玲玉1 小时前
PaddleOCROCR关键信息抽取训练过程
人工智能·ocr
feifeikon4 小时前
机器学习DAY3续:逻辑回归、极大似然、梯度下降 (逻辑回归完)
人工智能·机器学习·逻辑回归
贝多财经4 小时前
高频生活场景带动低频金融服务,美团企业版点燃场景金融建设引擎
人工智能·金融·生活
百家方案5 小时前
「下载」“一机游”智慧旅游平台解决方案:智慧文旅4大应用8大特色,实现旅游监管、营销与服务的全面升级
大数据·人工智能·智慧文旅·智慧旅游
deephub6 小时前
SCOPE:面向大语言模型长序列生成的双阶段KV缓存优化框架
人工智能·深度学习·transformer·大语言模型·kv缓存
AidLux6 小时前
智能边缘计算×软硬件一体化:开启全场景效能革命新征程(高校开发者作品)
人工智能·边缘计算
程序猿阿伟6 小时前
《迁移学习与联邦学习:推动人工智能发展的关键力量》
人工智能·机器学习·迁移学习
mt4481396 小时前
突发!刚刚,OpenAI裂变成了两块
人工智能·yolo·语言模型·chatgpt·gpt-3·bard·文心一言
CES_Asia7 小时前
数据资产试点开启,CES Asia 2025聚焦智慧城市新发展
人工智能·科技·数码相机·智能手机·智慧城市·智能手表