【深度学习】 —— 轻松玩转数据增强

文章目录

  • [1. 简介 & 数据集介绍](#1. 简介 & 数据集介绍)
  • [2. 环境](#2. 环境)
  • [3. 代码实现](#3. 代码实现)
    • [3.1 前期准备](#3.1 前期准备)
      • [3.1.1 设置GPU & 导入库](#3.1.1 设置GPU & 导入库)
    • [3.2 数据预处理](#3.2 数据预处理)
      • [3.2.1 数据集划分与预处理](#3.2.1 数据集划分与预处理)
      • [3.2.2 类别识别](#3.2.2 类别识别)
      • [3.2.3 可视化](#3.2.3 可视化)
      • [3.2.4 批次检查](#3.2.4 批次检查)
    • [3.3 数据增强](#3.3 数据增强)
      • [3.3.1 图像变换](#3.3.1 图像变换)
      • [3.3.2 增强方式一:将其嵌入 model 中](#3.3.2 增强方式一:将其嵌入 model 中)
      • [3.3.3 增强方式二:在 Dataset 数据集中进行数据增强](#3.3.3 增强方式二:在 Dataset 数据集中进行数据增强)
    • [3.4 模型建立与训练](#3.4 模型建立与训练)
      • [3.4.1 构建 CNN 模型并编译](#3.4.1 构建 CNN 模型并编译)
      • [3.4.2 模型训练](#3.4.2 模型训练)
      • [3.4.2.1 增强方式一(将其嵌入 model 中)训练过程](#3.4.2.1 增强方式一(将其嵌入 model 中)训练过程)
      • [3.4.2.1 增强方式二(在 Dataset 数据集中进行数据增强)训练过程](#3.4.2.1 增强方式二(在 Dataset 数据集中进行数据增强)训练过程)
  • [4. 模型评估](#4. 模型评估)
    • [4.1 增强方式一(将其嵌入 model 中)评估结果](#4.1 增强方式一(将其嵌入 model 中)评估结果)
    • [4.2 增强方式二(在 Dataset 数据集中进行数据增强)评估结果](#4.2 增强方式二(在 Dataset 数据集中进行数据增强)评估结果)
  • [5. 自定义增强函数](#5. 自定义增强函数)

1. 简介 & 数据集介绍

本文旨在解决深度学习中因数据量不足导致的过拟合问题,将利用 TensorFlow,通过构建 CNN 网络实现猫狗识别。数据集中有 dog 和 cat 2 类图片,每类图片数量各有 300 张图片。

对于数据增强,不仅将介绍 RandomFlip(随机翻转)和 RandomRotation(随机旋转)等基础图像变换,而且会阐述两种不同的数据增强嵌入工作流:

  • Model 层嵌入:利用 GPU 加速,在 Model.fit 时自动生效。
  • Dataset 层映射:在 CPU/数据流水线预处理阶段进行 map 操作。

此外,还将说明如何编写 aug_img 并嵌入 preprocess_image 的自定义数据增强进阶方式。

2. 环境

  • 语言环境:Python 3.12.7
  • 编译器:Jupyter Notebook
  • 深度学习环境:TensorFlow 2.21.0

3. 代码实现

3.1 前期准备

3.1.1 设置GPU & 导入库

导入必要的库并配置 GPU 显存增长,以解决在 Windows 环境下可能出现的显存占用或驱动兼容性问题。

python 复制代码
import matplotlib.pyplot as plt
import numpy as np
import warnings
from tensorflow.keras import layers
import tensorflow as tf
import random

warnings.filterwarnings('ignore')

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    tf.config.set_visible_devices([gpus[0]],"GPU")

print(gpus)

3.2 数据预处理

3.2.1 数据集划分与预处理

使用 Keras 提供的便捷接口构建了训练数据集,将图像统一缩放至标准的 224x224 尺寸,设置批次大小为 32,并按 7:3 的比例划出了 70% 的数据用于模型训练。

python 复制代码
data_dir   = "./Data/34-data/"
img_height = 224
img_width  = 224
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir, validation_split=0.3, subset="training", seed=12, image_size=(img_height, img_width), batch_size=batch_size)

采用与构建训练集完全相同的参数和随机种子(seed=12),从同一个目录中划分出剩余的30% 作为验证数据集,以确保训练集和验证集互不重叠,用于评估模型性能。

python 复制代码
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir, validation_split=0.3, subset="validation", seed=12, image_size=(img_height, img_width), batch_size=batch_size)

3.2.2 类别识别

从创建好的训练数据集中提取并打印了分类的类别名称,输出的列表为['cat', 'dog'],明确了当前实验是一个简单的图像二分类任务。

python 复制代码
class_names = train_ds.class_names
print(class_names)

3.2.3 可视化

从训练集中抽取一个批次的数据(包含图像和标签),并打印它们的维度。输出结果 (8, 224, 224, 3) 验证了每批次包含 8 张长宽为 224 的 RGB 三通道彩色图片。

python 复制代码
AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image,label):
    return (image/255.0,label)

# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

plt.figure(figsize=(15, 10))  # 图形的宽为15高为10

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(5, 8, i + 1) 
        plt.imshow(images[i])
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

3.2.4 批次检查

检查之前验证集和测试集数据划分的批次是否正确。

python 复制代码
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches // 5)
val_ds = val_ds.skip(val_batches // 5)

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))

3.3 数据增强

3.3.1 图像变换

tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像。

tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像

python 复制代码
# 第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),])

将图片添加到批次中,并查看图像变换后的图像

python 复制代码
image = tf.expand_dims(images[i], 0)
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])
    plt.axis("off")

3.3.2 增强方式一:将其嵌入 model 中

python 复制代码
model = tf.keras.Sequential([
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
])

3.3.3 增强方式二:在 Dataset 数据集中进行数据增强

python 复制代码
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds

3.4 模型建立与训练

3.4.1 构建 CNN 模型并编译

构建简单 CNN 模型用于测试不同增强方式。

python 复制代码
model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(len(class_names))
])

model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

3.4.2 模型训练

在训练集上进行了 10 个周期的训练。

python 复制代码
epochs=20
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

3.4.2.1 增强方式一(将其嵌入 model 中)训练过程

bash 复制代码
Epoch 1/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 5s 231ms/step - accuracy: 0.5000 - loss: 1.3246 - val_accuracy: 0.5811 - val_loss: 0.6682
Epoch 2/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 212ms/step - accuracy: 0.7119 - loss: 0.6004 - val_accuracy: 0.7365 - val_loss: 0.4991
Epoch 3/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 202ms/step - accuracy: 0.8238 - loss: 0.4000 - val_accuracy: 0.8176 - val_loss: 0.4000
Epoch 4/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 0.8500 - loss: 0.3855 - val_accuracy: 0.7905 - val_loss: 0.4352
Epoch 5/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 0.9000 - loss: 0.2159 - val_accuracy: 0.8851 - val_loss: 0.2889
Epoch 6/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 0.9238 - loss: 0.2057 - val_accuracy: 0.8311 - val_loss: 0.4518
Epoch 7/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 198ms/step - accuracy: 0.9167 - loss: 0.1929 - val_accuracy: 0.8108 - val_loss: 0.6540
Epoch 8/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 0.9357 - loss: 0.1412 - val_accuracy: 0.8311 - val_loss: 0.6994
Epoch 9/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 197ms/step - accuracy: 0.9762 - loss: 0.0607 - val_accuracy: 0.9054 - val_loss: 0.3755
Epoch 10/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 199ms/step - accuracy: 0.9857 - loss: 0.0311 - val_accuracy: 0.8986 - val_loss: 0.2656
Epoch 11/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 199ms/step - accuracy: 1.0000 - loss: 0.0097 - val_accuracy: 0.9054 - val_loss: 0.2767
Epoch 12/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 203ms/step - accuracy: 1.0000 - loss: 0.0055 - val_accuracy: 0.9392 - val_loss: 0.2990
Epoch 13/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 208ms/step - accuracy: 1.0000 - loss: 0.0013 - val_accuracy: 0.9392 - val_loss: 0.2575
Epoch 14/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 206ms/step - accuracy: 1.0000 - loss: 9.5639e-04 - val_accuracy: 0.9527 - val_loss: 0.2756
Epoch 15/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 1.0000 - loss: 6.4617e-04 - val_accuracy: 0.9459 - val_loss: 0.2910
Epoch 16/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 213ms/step - accuracy: 1.0000 - loss: 5.0219e-04 - val_accuracy: 0.9392 - val_loss: 0.2940
Epoch 17/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 205ms/step - accuracy: 1.0000 - loss: 4.2109e-04 - val_accuracy: 0.9392 - val_loss: 0.2988
Epoch 18/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 211ms/step - accuracy: 1.0000 - loss: 3.6237e-04 - val_accuracy: 0.9392 - val_loss: 0.3050
Epoch 19/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 207ms/step - accuracy: 1.0000 - loss: 3.1684e-04 - val_accuracy: 0.9392 - val_loss: 0.3101
Epoch 20/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 220ms/step - accuracy: 1.0000 - loss: 2.8164e-04 - val_accuracy: 0.9392 - val_loss: 0.3141

3.4.2.1 增强方式二(在 Dataset 数据集中进行数据增强)训练过程

bash 复制代码
Epoch 1/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 5s 221ms/step - accuracy: 0.5000 - loss: 1.2542 - val_accuracy: 0.6419 - val_loss: 0.6840
Epoch 2/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 205ms/step - accuracy: 0.5595 - loss: 0.6710 - val_accuracy: 0.5878 - val_loss: 0.6631
Epoch 3/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 0.7667 - loss: 0.5882 - val_accuracy: 0.7838 - val_loss: 0.5397
Epoch 4/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 206ms/step - accuracy: 0.8190 - loss: 0.4261 - val_accuracy: 0.8311 - val_loss: 0.4068
Epoch 5/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 0.9190 - loss: 0.2402 - val_accuracy: 0.8649 - val_loss: 0.3028
Epoch 6/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 203ms/step - accuracy: 0.9810 - loss: 0.1124 - val_accuracy: 0.8986 - val_loss: 0.2857
Epoch 7/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 202ms/step - accuracy: 0.9905 - loss: 0.0635 - val_accuracy: 0.9122 - val_loss: 0.2339
Epoch 8/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 0.9976 - loss: 0.0227 - val_accuracy: 0.9122 - val_loss: 0.2581
Epoch 9/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 201ms/step - accuracy: 1.0000 - loss: 0.0124 - val_accuracy: 0.9189 - val_loss: 0.2602
Epoch 10/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 193ms/step - accuracy: 1.0000 - loss: 0.0073 - val_accuracy: 0.9122 - val_loss: 0.2766
Epoch 11/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 197ms/step - accuracy: 1.0000 - loss: 0.0084 - val_accuracy: 0.9189 - val_loss: 0.3155
Epoch 12/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 192ms/step - accuracy: 1.0000 - loss: 0.0072 - val_accuracy: 0.9257 - val_loss: 0.3138
Epoch 13/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 192ms/step - accuracy: 1.0000 - loss: 0.0031 - val_accuracy: 0.9257 - val_loss: 0.3352
Epoch 14/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 1.0000 - loss: 0.0056 - val_accuracy: 0.9189 - val_loss: 0.5202
Epoch 15/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 212ms/step - accuracy: 0.9976 - loss: 0.0091 - val_accuracy: 0.9257 - val_loss: 0.4963
Epoch 16/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 1.0000 - loss: 0.0040 - val_accuracy: 0.9324 - val_loss: 0.3371
Epoch 17/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 207ms/step - accuracy: 1.0000 - loss: 8.9164e-04 - val_accuracy: 0.9392 - val_loss: 0.3209
Epoch 18/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 205ms/step - accuracy: 1.0000 - loss: 6.4076e-04 - val_accuracy: 0.9324 - val_loss: 0.3563
Epoch 19/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 1.0000 - loss: 2.9915e-04 - val_accuracy: 0.9324 - val_loss: 0.3783
Epoch 20/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 1.0000 - loss: 2.3748e-04 - val_accuracy: 0.9324 - val_loss: 0.3623

4. 模型评估

python 复制代码
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

4.1 增强方式一(将其嵌入 model 中)评估结果

4.2 增强方式二(在 Dataset 数据集中进行数据增强)评估结果

5. 自定义增强函数

这个函数可以自己随意定义

python 复制代码
def aug_img(image):
    # 1. 随机微调亮度 (有状态随机,适合在 tf.data 管道中使用)
    image = tf.image.random_brightness(image, max_delta=0.2)
    
    # 2. 随机微调饱和度
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    
    # 3. 概率性触发:50% 的几率执行随机左右翻转
    if tf.random.uniform([]) > 0.5:
        image = tf.image.flip_left_right(image)
        
    return image

image = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
plt.figure(figsize=(8, 8))
for i in range(9):
    augmented_image = aug_img(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype("uint8"))

    plt.axis("off")

需要将 aug_img 函数嵌入到 preprocess_image 函数中,可以这样写:

python 复制代码
def preprocess_image(image, label):
    # 1. 基础预处理(标准化、调整大小等)
    image = tf.image.resize(image, [180, 180])
    image = tf.cast(image, tf.float32) / 255.0
    
    # 2. 调用你的自定义增强函数
    image = aug_img(image) 
    
    return image, label

# 应用到数据集
train_ds = train_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
相关推荐
NQBJT3 分钟前
青鸾云步:基于 Cordova 的 AI 导盲机器人 APP 全栈开发实战
人工智能·app·导盲·轮足机器人·青鸾云步
深兰科技32 分钟前
韩国KAIST AI半导体高管项目代表团到访深兰科技,聚焦AI算力与智能产业合作机会
人工智能·机器人·symfony·ai算力·深兰科技·韩国科学技术院·kaist
快乐on9仔39 分钟前
NLP学习(一)transformers之pipeline体验
人工智能·深度学习
冬奇Lab1 小时前
Agent系列(六):记忆管理——让 Agent 记住重要的事
人工智能·agent
冬奇Lab1 小时前
一天一个开源项目(第113篇):notebooklm-py - 把 Google NotebookLM 变成可编程 API,还能接入 Claude Code
人工智能·google·开源
字节跳动开源2 小时前
Viking AI 搜索 CLI 正式发布:会说话,就能做搜索推荐
数据库·人工智能·开源
阿杰技术2 小时前
AI 编程助手落地实战:从提效到重构的全场景指南
人工智能·重构
Agent手记2 小时前
制造业生产流程自动化,Agent需要具备哪些能力?深度拆解2026工业级智能体落地范式与核心架构
大数据·人工智能·ai·架构·自动化
道里2 小时前
花了 5 万刀用 AI 写代码之后,这是我的全部经验
前端·人工智能