- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
文章目录
- [1. 简介 & 数据集介绍](#1. 简介 & 数据集介绍)
- [2. 环境](#2. 环境)
- [3. 代码实现](#3. 代码实现)
-
- [3.1 前期准备](#3.1 前期准备)
-
- [3.1.1 设置GPU & 导入库](#3.1.1 设置GPU & 导入库)
- [3.2 数据预处理](#3.2 数据预处理)
-
- [3.2.1 数据集划分与预处理](#3.2.1 数据集划分与预处理)
- [3.2.2 类别识别](#3.2.2 类别识别)
- [3.2.3 可视化](#3.2.3 可视化)
- [3.2.4 批次检查](#3.2.4 批次检查)
- [3.3 数据增强](#3.3 数据增强)
-
- [3.3.1 图像变换](#3.3.1 图像变换)
- [3.3.2 增强方式一:将其嵌入 model 中](#3.3.2 增强方式一:将其嵌入 model 中)
- [3.3.3 增强方式二:在 Dataset 数据集中进行数据增强](#3.3.3 增强方式二:在 Dataset 数据集中进行数据增强)
- [3.4 模型建立与训练](#3.4 模型建立与训练)
-
- [3.4.1 构建 CNN 模型并编译](#3.4.1 构建 CNN 模型并编译)
- [3.4.2 模型训练](#3.4.2 模型训练)
- [3.4.2.1 增强方式一(将其嵌入 model 中)训练过程](#3.4.2.1 增强方式一(将其嵌入 model 中)训练过程)
- [3.4.2.1 增强方式二(在 Dataset 数据集中进行数据增强)训练过程](#3.4.2.1 增强方式二(在 Dataset 数据集中进行数据增强)训练过程)
- [4. 模型评估](#4. 模型评估)
-
- [4.1 增强方式一(将其嵌入 model 中)评估结果](#4.1 增强方式一(将其嵌入 model 中)评估结果)
- [4.2 增强方式二(在 Dataset 数据集中进行数据增强)评估结果](#4.2 增强方式二(在 Dataset 数据集中进行数据增强)评估结果)
- [5. 自定义增强函数](#5. 自定义增强函数)
1. 简介 & 数据集介绍
本文旨在解决深度学习中因数据量不足导致的过拟合问题,将利用 TensorFlow,通过构建 CNN 网络实现猫狗识别。数据集中有 dog 和 cat 2 类图片,每类图片数量各有 300 张图片。
对于数据增强,不仅将介绍 RandomFlip(随机翻转)和 RandomRotation(随机旋转)等基础图像变换,而且会阐述两种不同的数据增强嵌入工作流:
- Model 层嵌入:利用 GPU 加速,在 Model.fit 时自动生效。
- Dataset 层映射:在 CPU/数据流水线预处理阶段进行 map 操作。
此外,还将说明如何编写 aug_img 并嵌入 preprocess_image 的自定义数据增强进阶方式。
2. 环境
- 语言环境:Python 3.12.7
- 编译器:Jupyter Notebook
- 深度学习环境:TensorFlow 2.21.0
3. 代码实现
3.1 前期准备
3.1.1 设置GPU & 导入库
导入必要的库并配置 GPU 显存增长,以解决在 Windows 环境下可能出现的显存占用或驱动兼容性问题。
python
import matplotlib.pyplot as plt
import numpy as np
import warnings
from tensorflow.keras import layers
import tensorflow as tf
import random
warnings.filterwarnings('ignore')
gpus = tf.config.list_physical_devices("GPU")
if gpus:
tf.config.experimental.set_memory_growth(gpus[0], True)
tf.config.set_visible_devices([gpus[0]],"GPU")
print(gpus)
3.2 数据预处理
3.2.1 数据集划分与预处理
使用 Keras 提供的便捷接口构建了训练数据集,将图像统一缩放至标准的 224x224 尺寸,设置批次大小为 32,并按 7:3 的比例划出了 70% 的数据用于模型训练。
python
data_dir = "./Data/34-data/"
img_height = 224
img_width = 224
batch_size = 32
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir, validation_split=0.3, subset="training", seed=12, image_size=(img_height, img_width), batch_size=batch_size)

采用与构建训练集完全相同的参数和随机种子(seed=12),从同一个目录中划分出剩余的30% 作为验证数据集,以确保训练集和验证集互不重叠,用于评估模型性能。
python
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir, validation_split=0.3, subset="validation", seed=12, image_size=(img_height, img_width), batch_size=batch_size)

3.2.2 类别识别
从创建好的训练数据集中提取并打印了分类的类别名称,输出的列表为['cat', 'dog'],明确了当前实验是一个简单的图像二分类任务。
python
class_names = train_ds.class_names
print(class_names)

3.2.3 可视化
从训练集中抽取一个批次的数据(包含图像和标签),并打印它们的维度。输出结果 (8, 224, 224, 3) 验证了每批次包含 8 张长宽为 224 的 RGB 三通道彩色图片。
python
AUTOTUNE = tf.data.AUTOTUNE
def preprocess_image(image,label):
return (image/255.0,label)
# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
plt.figure(figsize=(15, 10)) # 图形的宽为15高为10
for images, labels in train_ds.take(1):
for i in range(8):
ax = plt.subplot(5, 8, i + 1)
plt.imshow(images[i])
plt.title(class_names[labels[i]])
plt.axis("off")

3.2.4 批次检查
检查之前验证集和测试集数据划分的批次是否正确。
python
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches // 5)
val_ds = val_ds.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))

3.3 数据增强
3.3.1 图像变换
● tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像。
● tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像
python
# 第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),])
将图片添加到批次中,并查看图像变换后的图像
python
image = tf.expand_dims(images[i], 0)
plt.figure(figsize=(8, 8))
for i in range(9):
augmented_image = data_augmentation(image)
ax = plt.subplot(3, 3, i + 1)
plt.imshow(augmented_image[0])
plt.axis("off")

3.3.2 增强方式一:将其嵌入 model 中
python
model = tf.keras.Sequential([
data_augmentation,
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
])
3.3.3 增强方式二:在 Dataset 数据集中进行数据增强
python
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE
def prepare(ds):
ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
return ds
3.4 模型建立与训练
3.4.1 构建 CNN 模型并编译
构建简单 CNN 模型用于测试不同增强方式。
python
model = tf.keras.Sequential([
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(len(class_names))
])
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
3.4.2 模型训练
在训练集上进行了 10 个周期的训练。
python
epochs=20
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
3.4.2.1 增强方式一(将其嵌入 model 中)训练过程
bash
Epoch 1/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 5s 231ms/step - accuracy: 0.5000 - loss: 1.3246 - val_accuracy: 0.5811 - val_loss: 0.6682
Epoch 2/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 212ms/step - accuracy: 0.7119 - loss: 0.6004 - val_accuracy: 0.7365 - val_loss: 0.4991
Epoch 3/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 202ms/step - accuracy: 0.8238 - loss: 0.4000 - val_accuracy: 0.8176 - val_loss: 0.4000
Epoch 4/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 0.8500 - loss: 0.3855 - val_accuracy: 0.7905 - val_loss: 0.4352
Epoch 5/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 0.9000 - loss: 0.2159 - val_accuracy: 0.8851 - val_loss: 0.2889
Epoch 6/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 0.9238 - loss: 0.2057 - val_accuracy: 0.8311 - val_loss: 0.4518
Epoch 7/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 198ms/step - accuracy: 0.9167 - loss: 0.1929 - val_accuracy: 0.8108 - val_loss: 0.6540
Epoch 8/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 0.9357 - loss: 0.1412 - val_accuracy: 0.8311 - val_loss: 0.6994
Epoch 9/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 197ms/step - accuracy: 0.9762 - loss: 0.0607 - val_accuracy: 0.9054 - val_loss: 0.3755
Epoch 10/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 199ms/step - accuracy: 0.9857 - loss: 0.0311 - val_accuracy: 0.8986 - val_loss: 0.2656
Epoch 11/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 199ms/step - accuracy: 1.0000 - loss: 0.0097 - val_accuracy: 0.9054 - val_loss: 0.2767
Epoch 12/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 203ms/step - accuracy: 1.0000 - loss: 0.0055 - val_accuracy: 0.9392 - val_loss: 0.2990
Epoch 13/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 208ms/step - accuracy: 1.0000 - loss: 0.0013 - val_accuracy: 0.9392 - val_loss: 0.2575
Epoch 14/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 206ms/step - accuracy: 1.0000 - loss: 9.5639e-04 - val_accuracy: 0.9527 - val_loss: 0.2756
Epoch 15/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 1.0000 - loss: 6.4617e-04 - val_accuracy: 0.9459 - val_loss: 0.2910
Epoch 16/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 213ms/step - accuracy: 1.0000 - loss: 5.0219e-04 - val_accuracy: 0.9392 - val_loss: 0.2940
Epoch 17/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 205ms/step - accuracy: 1.0000 - loss: 4.2109e-04 - val_accuracy: 0.9392 - val_loss: 0.2988
Epoch 18/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 211ms/step - accuracy: 1.0000 - loss: 3.6237e-04 - val_accuracy: 0.9392 - val_loss: 0.3050
Epoch 19/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 207ms/step - accuracy: 1.0000 - loss: 3.1684e-04 - val_accuracy: 0.9392 - val_loss: 0.3101
Epoch 20/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 220ms/step - accuracy: 1.0000 - loss: 2.8164e-04 - val_accuracy: 0.9392 - val_loss: 0.3141
3.4.2.1 增强方式二(在 Dataset 数据集中进行数据增强)训练过程
bash
Epoch 1/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 5s 221ms/step - accuracy: 0.5000 - loss: 1.2542 - val_accuracy: 0.6419 - val_loss: 0.6840
Epoch 2/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 205ms/step - accuracy: 0.5595 - loss: 0.6710 - val_accuracy: 0.5878 - val_loss: 0.6631
Epoch 3/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 0.7667 - loss: 0.5882 - val_accuracy: 0.7838 - val_loss: 0.5397
Epoch 4/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 206ms/step - accuracy: 0.8190 - loss: 0.4261 - val_accuracy: 0.8311 - val_loss: 0.4068
Epoch 5/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 0.9190 - loss: 0.2402 - val_accuracy: 0.8649 - val_loss: 0.3028
Epoch 6/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 203ms/step - accuracy: 0.9810 - loss: 0.1124 - val_accuracy: 0.8986 - val_loss: 0.2857
Epoch 7/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 202ms/step - accuracy: 0.9905 - loss: 0.0635 - val_accuracy: 0.9122 - val_loss: 0.2339
Epoch 8/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 0.9976 - loss: 0.0227 - val_accuracy: 0.9122 - val_loss: 0.2581
Epoch 9/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 201ms/step - accuracy: 1.0000 - loss: 0.0124 - val_accuracy: 0.9189 - val_loss: 0.2602
Epoch 10/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 193ms/step - accuracy: 1.0000 - loss: 0.0073 - val_accuracy: 0.9122 - val_loss: 0.2766
Epoch 11/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 197ms/step - accuracy: 1.0000 - loss: 0.0084 - val_accuracy: 0.9189 - val_loss: 0.3155
Epoch 12/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 192ms/step - accuracy: 1.0000 - loss: 0.0072 - val_accuracy: 0.9257 - val_loss: 0.3138
Epoch 13/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 192ms/step - accuracy: 1.0000 - loss: 0.0031 - val_accuracy: 0.9257 - val_loss: 0.3352
Epoch 14/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 1.0000 - loss: 0.0056 - val_accuracy: 0.9189 - val_loss: 0.5202
Epoch 15/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 212ms/step - accuracy: 0.9976 - loss: 0.0091 - val_accuracy: 0.9257 - val_loss: 0.4963
Epoch 16/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 200ms/step - accuracy: 1.0000 - loss: 0.0040 - val_accuracy: 0.9324 - val_loss: 0.3371
Epoch 17/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 207ms/step - accuracy: 1.0000 - loss: 8.9164e-04 - val_accuracy: 0.9392 - val_loss: 0.3209
Epoch 18/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 205ms/step - accuracy: 1.0000 - loss: 6.4076e-04 - val_accuracy: 0.9324 - val_loss: 0.3563
Epoch 19/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 204ms/step - accuracy: 1.0000 - loss: 2.9915e-04 - val_accuracy: 0.9324 - val_loss: 0.3783
Epoch 20/20
14/14 ━━━━━━━━━━━━━━━━━━━━ 3s 196ms/step - accuracy: 1.0000 - loss: 2.3748e-04 - val_accuracy: 0.9324 - val_loss: 0.3623
4. 模型评估
python
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
4.1 增强方式一(将其嵌入 model 中)评估结果

4.2 增强方式二(在 Dataset 数据集中进行数据增强)评估结果

5. 自定义增强函数
这个函数可以自己随意定义
python
def aug_img(image):
# 1. 随机微调亮度 (有状态随机,适合在 tf.data 管道中使用)
image = tf.image.random_brightness(image, max_delta=0.2)
# 2. 随机微调饱和度
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
# 3. 概率性触发:50% 的几率执行随机左右翻转
if tf.random.uniform([]) > 0.5:
image = tf.image.flip_left_right(image)
return image
image = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())
plt.figure(figsize=(8, 8))
for i in range(9):
augmented_image = aug_img(image)
ax = plt.subplot(3, 3, i + 1)
plt.imshow(augmented_image[0].numpy().astype("uint8"))
plt.axis("off")
需要将 aug_img 函数嵌入到 preprocess_image 函数中,可以这样写:
python
def preprocess_image(image, label):
# 1. 基础预处理(标准化、调整大小等)
image = tf.image.resize(image, [180, 180])
image = tf.cast(image, tf.float32) / 255.0
# 2. 调用你的自定义增强函数
image = aug_img(image)
return image, label
# 应用到数据集
train_ds = train_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)