Keras深度学习框架实战(4):使用U-Net架构进行图像分割

1、图像分割概述

1.1 图像分割的概念

人工智能图像分割(AI Image Segmentation)是计算机视觉领域中的一个关键任务,它涉及使用人工智能算法和技术将图像细分为多个具有相似视觉特性或语义意义的区域或对象。这些区域或对象可以是图像中的不同物体、物体的不同部分,或者是具有某种共同特征(如颜色、纹理、形状等)的图像区域。

在图像分割过程中,算法会自动识别图像中的边界和轮廓,并将它们划分成不同的部分。这些部分可以是具有实际意义的对象(如汽车、建筑物、动物等),也可以是图像中的抽象区域(如天空、草地、水面等)。

  • 基于阈值的分割:通过设定一个或多个阈值,将图像的像素值分为不同的类别或区域。这种方法简单直观,但对于复杂图像或具有多个阈值的图像可能效果不佳。

  • 基于区域的分割:根据像素的相似性(如颜色、纹理、亮度等)将图像划分为不同的区域。这种方法可以捕获图像中的局部特征,但可能受到噪声和光照条件的影响。

  • 基于边缘的分割:利用图像中的边缘信息(即像素值发生显著变化的位置)来分割图像。边缘检测算法常用于检测图像中的对象边界。

  • 基于模型的分割:使用预先定义的模型(如活动轮廓模型、水平集方法等)来拟合图像中的对象或区域。这种方法需要一定的先验知识,但能够处理复杂的图像结构。

  • 深度学习分割:利用深度学习算法(如卷积神经网络、U-Net等)对图像进行分割。深度学习模型能够自动学习图像中的特征,并在大量数据上进行训练,以实现高精度的图像分割。

在人工智能图像分割中,深度学习模型因其强大的特征学习能力和处理复杂图像的能力而广受欢迎。这些模型可以通过训练来识别图像中的不同对象和区域,并生成高精度的分割结果。这些结果可以进一步用于各种计算机视觉任务,如目标检测、场景理解、图像编辑等。

1.2 图像分割任务中的U-Net架构

图像分割的U-Net架构是一种专为医学图像分割而设计的卷积神经网络(CNN)架构,由Olaf Ronneberger、Philipp Fischer和Thomas Brox于2015年提出。该架构在图像分割领域,特别是医学图像分割方面,表现出了出色的性能。以下是关于U-Net架构的详细解释:

1.2.1 架构概述

  • U型结构:U-Net的主要特点是其U型对称结构,由一个"编码器"(收缩路径)和一个"解码器"(扩张路径)组成。
  • 编码器(收缩路径):由多个卷积层和最大池化层组成,用于逐渐降低图像的空间分辨率,同时增加特征通道的数量。每个卷积层通常包括两个卷积操作,后接非线性激活函数(如ReLU)。
  • 解码器(扩张路径):目的是通过上采样过程逐步恢复图像的空间分辨率和细节。这些上采样层通常由转置卷积层实现。在每个上采样步骤之后,将特征图与编码器相对应层的特征图合并(通过跳跃连接),以恢复丢失的空间信息。

1.2.2 主要组成部分

  • 跳跃连接(Skip Connections):U-Net的一个关键特性是其跳跃连接,它将编码器中的特征图与解码器中对应层的特征图连接起来。这有助于在上采样过程中恢复细节信息,并允许网络学习更加精确的输出。
  • 最后的层:网络的最后通常是一个1x1的卷积层,用于将特征图映射到所需的输出类别数。

1.2.3 工作原理

  • 特征提取:在编码阶段,网络通过卷积层和池化层提取图像的特征。每次下采样都增加了特征的数量,同时减少了空间维度,这使得网络能够学习到从局部到全局的特征。
  • 特征恢复与精确定位:在解码阶段,网络通过上采样和跳跃连接逐渐恢复图像的空间分辨率。跳跃连接帮助保留了编码器中丢失的细节信息,使得网络能够进行更精确的像素级预测。
  • 端到端训练:U-Net可以端到端地训练,即从原始图像直接输出分割图,这使得模型在医学图像分割等任务中表现出色。

1.2.4 应用

  • U-Net因其出色的性能和有效的利用数据的能力,在医学图像分割领域被广泛采用。此外,它也被用于其他图像分割任务,如卫星图像分析、生物图像处理等。

1.2.5 优势和特点

  • 有效处理小数据集:U-Net的优势在于其能够有效处理小数据集并产生高精度的分割结果。
  • 结合低层次和高层次的特征:其对称结构和跳跃连接有效地结合了低层次和高层次的特征,使其在精确定位方面特别有效。

U-Net架构在图像分割领域,特别是医学图像分割方面,具有显著的优势和广泛的应用前景。

1.3 U-Net的主要特点和应用场景

1.3.1主要特点

  • U型结构:U-Net的架构呈U型,由收缩路径(编码器)和扩展路径(解码器)组成。这种结构能够捕获上下文信息,并在上采样过程中逐步恢复图像的细节。

  • 跳跃连接:U-Net在编码器与解码器之间使用了跳跃连接(Skip Connections),将编码器中的特征图与解码器中相应层级的特征图进行拼接。这种设计有助于在解码过程中保留更多的空间信息,从而生成更精确的分割结果。

  • 高效利用数据:U-Net在医学图像分割任务中表现出色,特别是在小数据集上。其结构设计和参数选择使得模型能够充分利用有限的训练数据,学习到有效的特征表示。

  • 易于训练和优化:U-Net采用了标准的卷积神经网络(CNN)组件,如卷积层、池化层和激活函数等,使得模型易于训练和优化。同时,其结构也相对简单,减少了过拟合的风险。

  • 支持多类别分割:U-Net的输出层可以根据任务需求设置多个通道,从而支持多类别分割。每个通道对应于一个类别,模型可以为每个像素预测一个类别标签。

1.3.2 应用场景

  • 医学图像分割:U-Net在医学图像分割领域具有广泛的应用,如细胞分割、器官分割、病变区域检测等。通过U-Net模型,可以自动提取出图像中的感兴趣区域,为医生提供辅助诊断信息。

  • 卫星图像分析:在卫星图像分析中,U-Net可用于城市规划、环境监测等领域。通过分割出不同的地表覆盖类型(如建筑、植被、水体等),可以为相关部门提供决策支持。

  • 自动驾驶:在自动驾驶系统中,U-Net可用于道路分割、车辆检测等任务。通过识别出道路边界、车辆位置等信息,可以帮助自动驾驶系统做出更准确的决策。

  • 安防监控:在安防监控领域,U-Net可用于人员检测、行为识别等任务。通过实时分割出图像中的人物、车辆等目标,可以为监控系统提供更丰富的信息。

  • 科研实验:在科研实验中,U-Net可用于处理和分析各种实验图像,如细胞培养、化学反应等过程的可视化研究。通过自动分割出实验中的关键区域,可以帮助研究人员更深入地理解实验现象。

U-Net因其高效、准确和易于训练的特点,在多个领域都有着广泛的应用场景。随着深度学习技术的不断发展,U-Net架构的性能也将得到进一步提升,为更多领域的应用提供有力支持。

2、使用U-Net架构进行图像分割

2.1 数据下载和设置

下载和解压数据

可以在python中按照如下操作下载和解压缩数据,也可以只使用网址进行下载。

python 复制代码
!!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
!
!curl -O https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
!curl -O https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz
!
!tar -xf images.tar.gz
!tar -xf annotations.tar.gz

准备输入图像和对应的目标分割掩码(masks)的路径

python 复制代码
import os

input_dir = "images/"
target_dir = "annotations/trimaps/"
img_size = (160, 160)
num_classes = 3
batch_size = 32

input_img_paths = sorted(
    [
        os.path.join(input_dir, fname)
        for fname in os.listdir(input_dir)
        if fname.endswith(".jpg")
    ]
)
target_img_paths = sorted(
    [
        os.path.join(target_dir, fname)
        for fname in os.listdir(target_dir)
        if fname.endswith(".png") and not fname.startswith(".")
    ]
)

print("Number of samples:", len(input_img_paths))

for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)

2.2 初步认识输入图像及其对应的分割掩码

python 复制代码
from IPython.display import Image, display
from keras.utils import load_img
from PIL import ImageOps

# Display input image #7
display(Image(filename=input_img_paths[9]))

# Display auto-contrast version of corresponding target (per-pixel categories)
img = ImageOps.autocontrast(load_img(target_img_paths[9]))
display(img)

上述代码打开一张猫咪的图片,并线束出气分割掩码。

2.3 准备数据集以加载并矢量化批次数据

准备数据集,以便能够加载它,并将数据按批次矢量化,以便后续用于机器学习或深度学习模型的训练或评估。

python 复制代码
import keras
import numpy as np
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io


def get_dataset(
    batch_size,
    img_size,
    input_img_paths,
    target_img_paths,
    max_dataset_len=None,
):
    """Returns a TF Dataset."""

    def load_img_masks(input_img_path, target_img_path):
        input_img = tf_io.read_file(input_img_path)
        input_img = tf_io.decode_png(input_img, channels=3)
        input_img = tf_image.resize(input_img, img_size)
        input_img = tf_image.convert_image_dtype(input_img, "float32")

        target_img = tf_io.read_file(target_img_path)
        target_img = tf_io.decode_png(target_img, channels=1)
        target_img = tf_image.resize(target_img, img_size, method="nearest")
        target_img = tf_image.convert_image_dtype(target_img, "uint8")

        # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
        target_img -= 1
        return input_img, target_img

    # For faster debugging, limit the size of data
    if max_dataset_len:
        input_img_paths = input_img_paths[:max_dataset_len]
        target_img_paths = target_img_paths[:max_dataset_len]
    dataset = tf_data.Dataset.from_tensor_slices((input_img_paths, target_img_paths))
    dataset = dataset.map(load_img_masks, num_parallel_calls=tf_data.AUTOTUNE)
    return dataset.batch(batch_size)

2.4 准备基于Xception风格的U-Net模型

准备或构建一个基于Xception架构特点的U-Net模型,以用于图像分割或其他相关的计算机视觉任务。

python 复制代码
from keras import layers


def get_model(img_size, num_classes):
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


# Build model
model = get_model(img_size, num_classes)
model.summary()

通过上述代码,我们建立了如下表格所示的模型,表格展示深度学习模型的架构,其中列出了各个层的类型、输出形状、参数数量以及它们之间的连接关系。

层类型 (类型) 输出形状 (Output Shape) 参数数量 (Param #) 连接到 (Connected to)
输入层 (input_layer) (无, 160, 160, 3) 0 -
卷积层 (conv2d) (无, 80, 80, 32) 896 input_layer[0][0]
批量归一化 (batch_normalization) (无, 80, 80, 32) 128 conv2d[0][0]
激活函数 (activation) (无, 80, 80, 32) 0 batch_normalization[0][0]
activation_1 (无, 80, 80, 32) 0 activation[0][0]
分离卷积 (separable_conv2d) (无, 80, 80, 64) 2,400 activation_1[0][0]
分离卷积的批量归一化 (batch_normalization) (无, 80, 80, 64) 256 separable_conv2d[0][0]
activation_2 (无, 80, 80, 64) 0 batch_normalization[0][0]
separable_conv2d_1 (无, 80, 80, 64) 4,736 activation_2[0][0]
separable_conv2d_1的批量归一化 (batch_normalization) (无, 80, 80, 64) 256 separable_conv2d_1[0][0]
最大池化 (max_pooling2d) (无, 40, 40, 64) 0 batch_normalization[0][0]
conv2d_1 (无, 40, 40, 64) 2,112 activation[0][0]
add (无, 40, 40, 64) 0 max_pooling2d[0][0], conv2d_1[0][0]
activation_3 (无, 40, 40, 64) 0 add[0][0]
separable_conv2d_2 (无, 40, 40, 128) 8,896 activation_3[0][0]
separable_conv2d_2的批量归一化 (batch_normalization) (无, 40, 40, 128) 512 separable_conv2d_2[0][0]
activation_4 (无, 40, 40, 128) 0 batch_normalization[0][0]
separable_conv2d_3 (无, 40, 40, 128) 17,664 activation_4[0][0]
separable_conv2d_3的批量归一化 (batch_normalization) (无, 40, 40, 128) 512 separable_conv2d_3[0][0]
max_pooling2d_1 (无, 20, 20, 128) 0 batch_normalization[0][0]
conv2d_2 (无, 20, 20, 128) 8,320 add[0][0]
add_1 (无, 20, 20, 128) 0 max_pooling2d_1[0][0], conv2d_2[0][0]
activation_5 (无, 20, 20, 128) 0 add_1[0][0]
separable_conv2d_4 (无, 20, 20, 256) 34,176 activation_5[0][0]
separable_conv2d_4的批量归一化 (batch_normalization) (无, 20, 20, 256) 1,024 separable_conv2d_4[0][0]
activation_6 (无, 20, 20, 256) 0 batch_normalization[0][0]
separable_conv2d_5 (无, 20, 20, 256) 68,096 activation_6[0][0]
separable_conv2d_5的批量归一化 (batch_normalization) (无, 20, 20, 256) 1,024 separable_conv2d_5[0][0]
max_pooling2d_2 (无, 10, 10, 256) 0 batch_normalization[0][0]
conv2d_3 (无, 10, 10, 256) 33,024 add_1[0][0]
add_2 (无, 10, 10, 256) 0 max_pooling2d_2[0][0], conv2d_3[0][0]
activation_7 (无, 10, 10, 256) 0 add_2[0][0]
转置卷积 (conv2d_transpose) (无, 10, 10, 256) 590,080 activation_7[0][0]
转置卷积的批量归一化 (batch_normalization) (无, 10, 10, 256) 1,024 conv2d_transpose[0][0]
activation_8 (无, 10, 10, 256) 0 batch_normalization[0][0]
conv2d_transpose_1 (无, 10, 10, 256) 590,080 activation_8[0][0]
conv2d_transpose_1的批量归一化 (batch_normalization) (无, 10, 10, 256) 1,024 conv2d_transpose_1[0][0]
up_sampling2d_1 (无, 20, 20, 256) 0 add_2[0][0]
up_sampling2d (无, 20, 20, 256) 0 batch_normalization[0][0]
conv2d_4 (无, 20, 20, 256) 65,792 up_sampling2d_1[0][0]
add_3 (无, 20, 20, 256) 0 up_sampling2d[0][0], conv2d_4[0][0]
activation_9 (无, 20, 20, 256) 0 add_3[0][0]
conv2d_transpose_2 (无, 20, 20, 128) 295,040 activation_9[0][0]
conv2d_transpose_2的批量归一化 (batch_normalization) (无, 20, 20, 128) 512 conv2d_transpose_2[0][0]
activation_10 (无, 20, 20, 128) 0 batch_normalization[0][0]
conv2d_transpose_3 (无, 20, 20, 128) 147,584 activation_10[0][0]
conv2d_transpose_3的批量归一化 (batch_normalization) (无, 20, 20, 128)

2.5 设定验证数据集

从原始数据集中留出一部分作为验证集,用于在模型训练过程中进行性能评估。

python 复制代码
import random

# Split our img paths into a training and a validation set
val_samples = 1000
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)
train_input_img_paths = input_img_paths[:-val_samples]
train_target_img_paths = target_img_paths[:-val_samples]
val_input_img_paths = input_img_paths[-val_samples:]
val_target_img_paths = target_img_paths[-val_samples:]

# Instantiate dataset for each split
# Limit input files in `max_dataset_len` for faster epoch training time.
# Remove the `max_dataset_len` arg when running with full dataset.
train_dataset = get_dataset(
    batch_size,
    img_size,
    train_input_img_paths,
    train_target_img_paths,
    max_dataset_len=1000,
)
valid_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)

2.6 训练模型

训练模型通常涉及多次迭代(也称为epochs),每次迭代都会遍历整个训练集一次或多次。在每次迭代中,模型都会根据其在训练数据上的表现来调整其参数,以逐步提高其预测性能。这个过程是由一个称为优化器(optimizer)的算法来指导的,它决定了模型如何更新其参数以最小化预测误差。

python 复制代码
# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.
model.compile(
    optimizer=keras.optimizers.Adam(1e-4), loss="sparse_categorical_crossentropy"
)

callbacks = [
    keras.callbacks.ModelCheckpoint("oxford_segmentation.keras", save_best_only=True)
]

# Train the model, doing validation at the end of each epoch.
epochs = 50
model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=valid_dataset,
    callbacks=callbacks,
    verbose=2,
)

2.7 可视化预测

在机器学习和深度学习的上下文中,可视化预测通常指的是将模型的预测结果以图形或图像的形式展示出来,以便人们能够直观地理解模型的输出和性能。这有助于发现模型可能存在的问题、理解模型的行为以及进行模型调试和改进。

例如,在图像分割任务中,可视化预测可能意味着将模型预测的分割掩码(segmentation mask)叠加在原始图像上,以便人们可以看到模型如何将图像中的不同区域分割开来。在目标检测任务中,可视化预测可能涉及在图像上绘制边界框(bounding boxes)以标记检测到的目标。

python 复制代码
# Generate predictions for all images in the validation set

val_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)
val_preds = model.predict(val_dataset)


def display_mask(i):
    """Quick utility to display a model's prediction."""
    mask = np.argmax(val_preds[i], axis=-1)
    mask = np.expand_dims(mask, axis=-1)
    img = ImageOps.autocontrast(keras.utils.array_to_img(mask))
    display(img)


# Display results for validation image #10
i = 10

# Display input image
display(Image(filename=val_input_img_paths[i]))

# Display ground-truth target mask
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)

# Display mask predicted by our model
display_mask(i)  # Note that the model only sees inputs at 150x150.



3、完整的实验源代码

python 复制代码
"""
## Download the data
"""

"""shell
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz

curl -O https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
curl -O https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz

tar -xf images.tar.gz
tar -xf annotations.tar.gz
"""

"""
## Prepare paths of input images and target segmentation masks
"""

import os

input_dir = "images/"
target_dir = "annotations/trimaps/"
img_size = (160, 160)
num_classes = 3
batch_size = 32

input_img_paths = sorted(
    [
        os.path.join(input_dir, fname)
        for fname in os.listdir(input_dir)
        if fname.endswith(".jpg")
    ]
)
target_img_paths = sorted(
    [
        os.path.join(target_dir, fname)
        for fname in os.listdir(target_dir)
        if fname.endswith(".png") and not fname.startswith(".")
    ]
)

print("Number of samples:", len(input_img_paths))

for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)

"""
## What does one input image and corresponding segmentation mask look like?
"""

from IPython.display import Image, display
from keras.utils import load_img
from PIL import ImageOps

# Display input image #7
display(Image(filename=input_img_paths[9]))

# Display auto-contrast version of corresponding target (per-pixel categories)
img = ImageOps.autocontrast(load_img(target_img_paths[9]))
display(img)

"""
## Prepare dataset to load & vectorize batches of data
"""

import keras
import numpy as np
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io


def get_dataset(
    batch_size,
    img_size,
    input_img_paths,
    target_img_paths,
    max_dataset_len=None,
):
    """Returns a TF Dataset."""

    def load_img_masks(input_img_path, target_img_path):
        input_img = tf_io.read_file(input_img_path)
        input_img = tf_io.decode_png(input_img, channels=3)
        input_img = tf_image.resize(input_img, img_size)
        input_img = tf_image.convert_image_dtype(input_img, "float32")

        target_img = tf_io.read_file(target_img_path)
        target_img = tf_io.decode_png(target_img, channels=1)
        target_img = tf_image.resize(target_img, img_size, method="nearest")
        target_img = tf_image.convert_image_dtype(target_img, "uint8")

        # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
        target_img -= 1
        return input_img, target_img

    # For faster debugging, limit the size of data
    if max_dataset_len:
        input_img_paths = input_img_paths[:max_dataset_len]
        target_img_paths = target_img_paths[:max_dataset_len]
    dataset = tf_data.Dataset.from_tensor_slices((input_img_paths, target_img_paths))
    dataset = dataset.map(load_img_masks, num_parallel_calls=tf_data.AUTOTUNE)
    return dataset.batch(batch_size)


"""
## Prepare U-Net Xception-style model
"""

from keras import layers


def get_model(img_size, num_classes):
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


# Build model
model = get_model(img_size, num_classes)
model.summary()

"""
## Set aside a validation split
"""

import random

# Split our img paths into a training and a validation set
val_samples = 1000
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)
train_input_img_paths = input_img_paths[:-val_samples]
train_target_img_paths = target_img_paths[:-val_samples]
val_input_img_paths = input_img_paths[-val_samples:]
val_target_img_paths = target_img_paths[-val_samples:]

# Instantiate dataset for each split
# Limit input files in `max_dataset_len` for faster epoch training time.
# Remove the `max_dataset_len` arg when running with full dataset.
train_dataset = get_dataset(
    batch_size,
    img_size,
    train_input_img_paths,
    train_target_img_paths,
    max_dataset_len=1000,
)
valid_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)

"""
## Train the model
"""

# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.
model.compile(
    optimizer=keras.optimizers.Adam(1e-4), loss="sparse_categorical_crossentropy"
)

callbacks = [
    keras.callbacks.ModelCheckpoint("oxford_segmentation.keras", save_best_only=True)
]

# Train the model, doing validation at the end of each epoch.
epochs = 50
model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=valid_dataset,
    callbacks=callbacks,
    verbose=2,
)

"""
## Visualize predictions
"""

# Generate predictions for all images in the validation set

val_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)
val_preds = model.predict(val_dataset)


def display_mask(i):
    """Quick utility to display a model's prediction."""
    mask = np.argmax(val_preds[i], axis=-1)
    mask = np.expand_dims(mask, axis=-1)
    img = ImageOps.autocontrast(keras.utils.array_to_img(mask))
    display(img)


# Display results for validation image #10
i = 10

# Display input image
display(Image(filename=val_input_img_paths[i]))

# Display ground-truth target mask
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)

# Display mask predicted by our model
display_mask(i)  # Note that the model only sees inputs at 150x150.

4、总结

今天关于使用U-Net架构进行图像分割的讨论涉及了多个关键方面,以下是针对这些讨论的总结:

4.1 U-Net架构概述

U-Net是一种广泛使用的深度学习架构,特别适用于图像分割任务。它以其独特的U型结构和跳跃连接(skip connections)而著称,这些特点使得U-Net能够有效地学习图像的上下文信息和局部细节,从而提高分割精度。

4.2 数据准备

  • 数据集:为了训练U-Net模型,需要准备一个包含大量标记图像的数据集。这些标记图像通常使用像素级的标注来指示不同对象或区域的边界。
  • 数据增强:为了增加模型的泛化能力,可以使用数据增强技术(如旋转、缩放、翻转等)来扩展数据集。
  • 验证集:留出一部分数据作为验证集,用于在训练过程中评估模型的性能并进行超参数调整。

4.3 模型训练

  • 损失函数:选择适当的损失函数(如交叉熵损失、Dice损失等)来指导模型的训练过程。损失函数应该能够准确地衡量模型的预测结果与真实标注之间的差异。
  • 优化器:使用优化器(如Adam、SGD等)来更新模型的参数,以最小化损失函数。优化器的选择应考虑到训练速度、收敛性和泛化能力等因素。
  • 训练策略:采用合适的训练策略,如学习率衰减、早停法(early stopping)等,以防止模型过拟合并提高泛化性能。

4.4 模型评估

  • 评估指标:使用合适的评估指标(如像素精度、IoU、Dice系数等)来评估模型的性能。这些指标能够全面地反映模型在分割任务上的表现。
  • 可视化预测:将模型的预测结果进行可视化,以便直观地理解模型的输出和性能。可视化有助于发现模型可能存在的问题并进行改进。

4.5 模型优化与改进

  • 调整网络结构:根据任务需求和数据特点,可以尝试调整U-Net的网络结构(如增加或减少层数、改变卷积核大小等)以优化性能。
  • 集成其他技术:结合其他技术(如注意力机制、多尺度特征融合等)来进一步提高U-Net的分割精度和鲁棒性。
  • 迁移学习:利用预训练的模型或特征进行迁移学习,以加速训练过程并提高模型的性能。

4.6 应用场景

U-Net架构已广泛应用于各种图像分割任务中,包括医学图像分割(如细胞检测、器官分割等)、自然图像分割(如街景理解、自动驾驶等)以及遥感图像分割等。通过不断的研究和改进,U-Net及其变种在图像分割领域取得了显著的成果。

相关推荐
想进大厂的小王1 小时前
项目架构介绍以及Spring cloud、redis、mq 等组件的基本认识
redis·分布式·后端·spring cloud·微服务·架构
阿伟*rui2 小时前
认识微服务,微服务的拆分,服务治理(nacos注册中心,远程调用)
微服务·架构·firefox
ZHOU西口3 小时前
微服务实战系列之玩转Docker(十八)
分布式·docker·云原生·架构·数据安全·etcd·rbac
羊小猪~~3 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
软工菜鸡4 小时前
预训练语言模型BERT——PaddleNLP中的预训练模型
大数据·人工智能·深度学习·算法·语言模型·自然语言处理·bert
哔哩哔哩技术5 小时前
B站S赛直播中的关键事件识别与应用
深度学习
deephub5 小时前
Tokenformer:基于参数标记化的高效可扩展Transformer架构
人工智能·python·深度学习·架构·transformer
___Dream5 小时前
【CTFN】基于耦合翻译融合网络的多模态情感分析的层次学习
人工智能·深度学习·机器学习·transformer·人机交互
极客代码5 小时前
【Python TensorFlow】入门到精通
开发语言·人工智能·python·深度学习·tensorflow
架构师那点事儿6 小时前
golang 用unsafe 无所畏惧,但使用不得到会panic
架构·go·掘金技术征文