T10 tensorflow数据增强

T10 使用 TensorFlow 实现数据增强

在深度学习的图像分类任务中,数据增强是一种常用的技术,它通过对现有训练样本进行随机变换(例如翻转、旋转、缩放等),以生成更多的训练数据,帮助模型更好地泛化,提升模型在未知数据上的表现。Pytorch框架数据增强方式较为方便,但对于tensorflow还不熟悉,这周主要学习tensorflow框架下数据增强方法。

1. 环境设置和数据加载

将数据集分为训练集、验证集和测试集。

python 复制代码
# 设置 GPU 显存按需使用
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    tf.config.set_visible_devices([gpus[0]], "GPU")

# 数据路径和参数设定
data_dir = "./34-data/"
img_height, img_width = 224, 224
batch_size = 32

# 加载训练和验证数据集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

# 分割验证集为验证集和测试集
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches // 5)
val_ds = val_ds.skip(val_batches // 5)

print(f'Number of validation batches: {tf.data.experimental.cardinality(val_ds)}')
print(f'Number of test batches: {tf.data.experimental.cardinality(test_ds)}')
2. 数据增强

数据增强 是提高模型泛化能力的重要技术。通过随机改变图像的属性,如水平/垂直翻转、旋转等,模型可以学会更好地处理不同的图像变体。

在 Keras 中,我们可以使用 tf.keras.layers.RandomFliptf.keras.layers.RandomRotation 等层来实现数据增强。在这里,我们对图像进行随机水平和垂直翻转,并进行一定程度的旋转:

python 复制代码
# 定义数据增强操作
data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip("horizontal_and_vertical"),
  tf.keras.layers.RandomRotation(0.2),
])

# 数据增强效果展示
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
for images, labels in train_ds.take(1):
    image = tf.expand_dims(images[0], 0)
    for i in range(9):
        augmented_image = data_augmentation(image)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_image[0])
        plt.axis("off")
plt.show()

在上面的代码中,data_augmentation 是一个 Keras 序列模型,它会对输入的图像进行随机的翻转和旋转。通过上面的可视化代码,我们可以直观地看到数据增强后的图像效果。

3. 数据预处理和模型训练

数据增强可以集成到训练管道中。在训练集上,我们通过 map 函数应用数据增强,同时将图像归一化为 [0, 1] 区间以便模型训练。

python 复制代码
AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image, label):
    image = image / 255.0  # 图像归一化
    return image, label

# 预处理和缓存数据集
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

# 将数据增强应用到训练集
def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds

train_ds = prepare(train_ds)

prepare 函数中,我们将数据增强操作应用于训练集,并通过 AUTOTUNE 进行多线程处理,加快数据读取速度。

4. 模型搭建和训练

定义了一个简单的卷积神经网络,包含三层卷积层,最后通过全连接层进行分类。模型使用 Adam 优化器,并通过交叉熵损失函数来优化。

python 复制代码
model = tf.keras.Sequential([
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(len(class_names))
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 训练模型
epochs = 20
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

结果如下:

5. 模型评估

在模型训练完成后,我们可以通过测试集评估模型的表现。

python 复制代码
# 评估模型
loss, acc = model.evaluate(test_ds)
print(f"Test Accuracy: {acc}")

Accuracy 0.90625

6. 总结

这周学习了如何使用 TensorFlow 和 Keras 实现一个包含数据增强的图像分类任务。数据增强在提升模型泛化能力上有显著作用,尤其在训练样本有限的情况下,随机翻转、旋转等操作能够帮助模型学习到更多的图像变体,从而在测试集上取得更好的表现。

相关推荐
IT_陈寒21 分钟前
Redis实战:5个高频应用场景下的性能优化技巧,让你的QPS提升50%
前端·人工智能·后端
AI小云27 分钟前
【数据操作与可视化】Pandas数据处理-其他操作
python·pandas
龙智DevSecOps解决方案32 分钟前
Perforce《2025游戏技术现状报告》Part 1:游戏引擎技术的广泛影响以及生成式AI的成熟之路
人工智能·unity·游戏引擎·游戏开发·perforce
大佬,救命!!!35 分钟前
更换适配python版本直接进行机器学习深度学习等相关环境配置(非仿真环境)
人工智能·python·深度学习·机器学习·学习笔记·详细配置
星空的资源小屋42 分钟前
VNote:程序员必备Markdown笔记神器
javascript·人工智能·笔记·django
梵得儿SHI1 小时前
(第七篇)Spring AI 基础入门总结:四层技术栈全景图 + 三大坑根治方案 + RAG 进阶预告
java·人工智能·spring·springai的四大核心能力·向量维度·prompt模板化·向量存储检索
亚马逊云开发者1 小时前
Amazon Bedrock助力飞书深诺电商广告分类
人工智能
2301_823438021 小时前
解析论文《复杂海上救援环境中无人机群的双阶段协作路径规划与任务分配》
人工智能·算法·无人机
无心水1 小时前
【Python实战进阶】4、Python字典与集合深度解析
开发语言·人工智能·python·python字典·python集合·python实战进阶·python工业化实战进阶
上班职业摸鱼人1 小时前
python文件中导入另外一个模块这个模块
python