深度学习笔记2:数据增强

上一节由于训练数据集样本量较小,模型过早拟合最终我们在测试数据集的分类精度只达到了70%,本章节我们通过使用数据增强降低过拟合的方法。使用数据增强之后,模型的分类精度将提高到 80%~85%。数据增强是指从现有的训练样本中生成更多的训练数据,做法是利用一些能够生成可信图像的随机变换来增强(augment)样本。数据增强的目标是,模型在训练时不会两次查看完全相同的图片。这有助于模型观察到数据的更多内容,从而具有更强的泛化能力。

数据准备

定义数据增强代码

复制代码
from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),        
    layers.RandomRotation(0.1),        
    layers.RandomZoom(0.2),    
    ])

加载数据集

复制代码
import  pathlib
new_base_dir = pathlib.Path('C:/Users/wuchh/.keras/datasets/dogs-vs-cats-small')

batch_size = 32
img_height = 180
img_width = 180

train_dataset = keras.preprocessing.image_dataset_from_directory(
    new_base_dir / 'train' ,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

validation_dataset = keras.preprocessing.image_dataset_from_directory(
    new_base_dir / 'train' ,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
显示几张增强后的训练图像
复制代码
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, _ in train_dataset.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)       
        ax = plt.subplot(3, 3, i + 1)        
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")        
    plt.show()

数据增强神经网络模型

复制代码
inputs = keras.Input(shape=(180, 180, 3))
x = data_augmentation(inputs) #数字增强
x = layers.Rescaling(1./255)(x)
x = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=64, kernel_size
=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=128, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss="binary_crossentropy",optimizer="rmsprop",metrics=["accuracy"])

训练卷积神经网络

复制代码
callbacks = [keras.callbacks.ModelCheckpoint(filepath="convnet_from_scratch_with_augmentation.model",        
                                             save_best_only=True,monitor="val_loss")]
history = model.fit(    train_dataset,    epochs=100,    validation_data=validation_dataset,    callbacks=callbacks)

在测试集上评估模型

复制代码
test_dataset = keras.preprocessing.image_dataset_from_directory(
new_base_dir / 'test' ,
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)

test_model = keras.models.load_model("convnet_from_scratch_with_augmentation.model")
test_loss, test_acc = test_model.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.3f}")
复制代码
>>> test_model = keras.models.load_model("convnet_from_scratch_with_augmentation.model")
>>> test_loss, test_acc = test_model.evaluate(test_dataset)
 3/32 [=>............................] - ETA: 2s - loss: 0.4401 - accuracy: 0.7604Corrupt JPEG da 9/32 [=======>......................] - ETA: 1s - loss: 0.5328 - accuracy: 0.7535Corrupt JPEG da32/32 [==============================] - 2s 69ms/step - loss: 0.4752 - accuracy: 0.7890
>>> print(f"Test accuracy: {test_acc:.3f}")
Test accuracy: 0.789
>>>

这次我的测试精度达到了 78.9%,这个进步不少!

相关推荐
Full Stack Developme7 小时前
Python Redis 教程
开发语言·redis·python
码界筑梦坊7 小时前
267-基于Django的携程酒店数据分析推荐系统
python·数据分析·django·毕业设计·echarts
Cherry Zack7 小时前
Django视图进阶:快捷函数、装饰器与请求响应
后端·python·django
qq_4924484468 小时前
Jmeter设置负载阶梯式压测场景(详解教程)
开发语言·python·jmeter
天一生水water8 小时前
什么是时间序列互相关分析(CCF)
机器学习·时间序列
lianyinghhh8 小时前
瓦力机器人-舵机控制(基于树莓派5)
人工智能·python·自然语言处理·硬件工程
Mike_Zhang9 小时前
python3.14版本的free-threading功能体验
python
StarPrayers.9 小时前
旅行商问题(TSP)(2)(heuristics.py)(TSP 的两种贪心启发式算法实现)
前端·人工智能·python·算法·pycharm·启发式算法
koo3649 小时前
李宏毅机器学习笔记21
人工智能·笔记·机器学习
木头左9 小时前
波动率聚类现象对ETF网格密度配置的启示与应对策略
python