深度学习笔记10-数据增强(Tensorflow)

前言

在深度学习中,数据增强(Data Augmentation)是一种通过对现有数据进行各种转换和变换,从而生成更多训练样本的方法。在计算机视觉中,常见的数据增强方法包括随机裁剪、旋转、翻转、缩放、平移、亮度调整、对比度调整、添加噪声等。其主要目的是通过增加数据量和多样性,帮助模型学习到更加泛化的特征,提高模型的鲁棒性,并减少过拟合现象。

一、前期工作

1.加载数据

python 复制代码
import matplotlib.pyplot as plt
import numpy as np
#隐藏警告
import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.keras import layers
python 复制代码
data_dir   = "./T10/"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
python 复制代码
class_names = train_ds.class_names
print(class_names)

2.创建测试集

由于原始数据集不包含测试集,因此需要创建一个。使用tf.data.experimental.cardinality确定验证集中有多少批次的数据,然后将其中的20%移至测试集。

python 复制代码
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))

3.配置数据集

python 复制代码
AUTOTUNE=tf.data.AUTOTUNE
def preprocess_image(image,label):
    return(image/255.0,label)
#归一化处理
train_ds=train_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)
val_ds=val_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)
test_ds=test_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)

train_ds=train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds=train_ds.cache().prefetch(buffer_size=AUTOTUNE)

4.可视化

python 复制代码
plt.figure(figsize=(15,10))
for image,labels in train_ds.take(1):
    for i in range(8):
        ax=plt.subplot(5,8,i+1)
        plt.imshow(image[i])
        plt.title(class_name[labels[i]])

        plt.axis('off')

二、数据增强

我们可以使用 tf.keras.layers.RandomFliptf.keras.layers.RandomRotation 进行数据增强

  • tf.keras.layers.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.RandomRotation:随机旋转每个图像
python 复制代码
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip("horizontal_and_vertical"),#添加一个随机翻转层,该层以一定的概率对输入图像进行水平和垂直翻转
  tf.keras.layers.RandomRotation(0.2),#添加一个随机旋转层,该层以一定的概率对输入图像进行旋转,旋转角度在-0.2到0.2弧度之间
])
python 复制代码
# Add the image to a batch.
image = tf.expand_dims(images[i], 0) #0表示在数组的最前面增加一个维度,这样原本的单个图像就变成了一个批次。
python 复制代码
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

三、增强方式

1.方式一:嵌入model

注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。

python 复制代码
model = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
])

2.方式二:在dataset中进行

python 复制代码
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds

四、训练模型

python 复制代码
model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])
python 复制代码
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
python 复制代码
epochs=20
history = model.fit(train_ds , validation_data=val_ds , epochs=epochs)

python 复制代码
acc,loss=model.evaluate(test_ds)
print("Accuracy", acc)

五、自定义增强函数

随机亮度、对比度、色度、饱和度的设置

python 复制代码
import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):
    seed = (random.randint(0, 9), 0)
    # 随机改变图像对比度
    stateless_random_contrast = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)
    # 随机改变图像的亮度
    stateless_random_brightness = tf.image.stateless_random_brightness(stateless_random_contrast, max_delta=0.3,seed=seed)
    # 随机改变图像的色度
    stateless_random_hue = tf.image.stateless_random_hue(stateless_random_brightness, max_delta=0.3,seed=seed)
    # 随机改变图像的饱和度
    stateless_random_saturation = tf.image.stateless_random_saturation(stateless_random_hue, lower=0.1, upper=1.0, seed=seed)
    
    return stateless_random_saturation
python 复制代码
image = tf.expand_dims(images[7]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
python 复制代码
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

六、总结

1.基础数据增强方式

  • 几何数据增强:包括旋转、平移、错切等操作,这些技术通过改变图像中像素值的位置来增强数据。
  • 非几何数据增强:侧重于图像的视觉外观,如噪声注入、翻转、裁剪、调整大小和色彩空间操作。
  • 翻转:水平或垂直翻转图像,是一种常用的数据增强技术。
  • 裁剪和调整大小:通过随机裁剪或中心裁剪作为数据增强,减小图像大小后再调整回原始大小。
  • 注入噪声:向图像中注入噪声,帮助模型学习稳健的特征。
  • 光度增强:通过改变RGB通道值来控制亮度,避免模型偏向特定光照条件。
  • 扰动:随机改变图像的亮度、对比度、饱和度和色调。
  • 核过滤:使用核或高斯模糊过滤器来锐化或模糊图像。

2.tf.data.experimental.cardinality

++tf.data.experimental.cardinality++ 是 TensorFlow 的一个函数,用于估计一个++tf.data.Dataset++ 数据集的元素数量。这个函数返回一个整数或None。如果返回整数,它代表数据集中元素的估计数量;如果返回None,则表示数据集的元素数量未知或无法确定。

3.翻转和旋转

  • tf.keras.layers.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.RandomRotation:随机旋转每个图像

4.有状态随机变换和无状态随机变换

相关推荐
1101 11014 小时前
STM32-笔记36-ADC(模拟/数字转换器)
笔记·stm32·嵌入式硬件
cxr8284 小时前
五类推理(逻辑推理、概率推理、图推理、基于深度学习的推理)的开源库 (二)
人工智能·深度学习
MediaTea5 小时前
Ae 效果详解:放大
图像处理·人工智能·深度学习·计算机视觉
未完成的歌~5 小时前
Kali 离线安装 ipmitool 笔记
linux·运维·笔记
qq_273900236 小时前
pytorch张量列表索引和多维度张量索引比较
人工智能·pytorch·深度学习
深蓝海拓6 小时前
基于深度学习的视觉检测小项目(六) 项目的信号和变量的规划
pytorch·深度学习·yolo·视觉检测·pyqt
夏天是冰红茶7 小时前
Vision Transformer模型详解(附pytorch实现)
人工智能·深度学习·transformer
qq_273900238 小时前
torch.reciprocal介绍
人工智能·pytorch·python·深度学习
青松@FasterAI9 小时前
【NLP高频面题 - 分布式训练篇】ZeRO主要为了解决什么问题?
人工智能·深度学习·自然语言处理·分布式训练·nlp面试
玩具工匠9 小时前
字玩FontPlayer开发笔记3 性能优化 大量canvas渲染卡顿问题
前端·javascript·vue.js·笔记·elementui·typescript