TensorFlow2实战-系列教程3:猫狗识别1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、项目介绍

基本流程:

  • 数据预处理:图像数据处理,准备训练和验证数据集
  • 卷积网络模型:构建网络架构
  • 过拟合问题:观察训练和验证效果,针对过拟合问题提出解决方法
  • 数据增强:图像数据增强方法与效果
  • 迁移学习:深度学习必备训练策略

在我们的数据中,有训练和验证,训练集中分别有猫狗两个类别,都有1000张图像,验证集则有500张

2、数据读取

python 复制代码
import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
  1. 导包
  2. 指定数据路径
  3. 训练数据路径
  4. 验证数据路径
  5. 训练数据猫类别路径
  6. 训练数据狗类别路径
  7. 验证数据猫类别路径
  8. 训练数据狗类别路径

3、构建卷积神经网络

python 复制代码
model = tf.keras.models.Sequential([
    #如果训练慢,可以把数据设置的更小一些
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    
    #为全连接层准备
    tf.keras.layers.Flatten(),
    
    tf.keras.layers.Dense(512, activation='relu'),
    # 二分类sigmoid就够了
    tf.keras.layers.Dense(1, activation='sigmoid')
])

3个3x3卷积,穿插3个2x2池化,拉平操作,两个全连接层

python 复制代码
model.summary()

打印一下模型架构:

配置训练器:

python 复制代码
model.compile(loss='binary_crossentropy', optimizer=Adam(lr=1e-4), metrics=['acc'])

4、数据预处理

  • 读进来的数据会被自动转换成tensor(float32)格式,分别准备训练和验证
  • 图像数据归一化(0-1)区间
python 复制代码
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        train_dir,  # 文件夹路径
        target_size=(64, 64),  # 指定resize成的大小
        batch_size=20,
        # 如果one-hot就是categorical,二分类用binary就可以
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(64, 64),
        batch_size=20,
        class_mode='binary')

打印结果:

Found 2000 images belonging to 2 classes.

Found 1000 images belonging to 2 classes.

5、模型训练

  • 直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
  • steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step
python 复制代码
history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,  # 2000 images = batch_size * steps
      epochs=20,
      validation_data=validation_generator,
      validation_steps=50,  # 1000 images = batch_size * steps
      verbose=2)

部分打印结果:

Epoch 1/20 100/100 - 9s - loss: 0.6909 - acc: 0.5240 - val_loss: 0.6952 - val_acc: 0.5000

Epoch 2/20 100/100 - 9s - loss: 0.6645 - acc: 0.5960 - val_loss: 0.6906 - val_acc: 0.5360

...

Epoch 19/20 100/100 - 9s - loss: 0.1750 - acc: 0.9460 - val_loss: 0.6277 - val_acc: 0.7390

Epoch 20/20 100/100 - 9s - loss: 0.1593 - acc: 0.9505 - val_loss: 0.5901 - val_acc: 0.7490

6、预测效果展示

python 复制代码
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()


将训练损失、准确率和对应的epoch分别画图展示

相关推荐
lindsayshuo2 分钟前
jetson orin系列开发版安装cuda的gpu版本的opencv
人工智能·opencv
向阳逐梦2 分钟前
ROS机器视觉入门:从基础到人脸识别与目标检测
人工智能·目标检测·计算机视觉
陈鋆27 分钟前
智慧城市初探与解决方案
人工智能·智慧城市
qdprobot28 分钟前
ESP32桌面天气摆件加文心一言AI大模型对话Mixly图形化编程STEAM创客教育
网络·人工智能·百度·文心一言·arduino
QQ395753323728 分钟前
金融量化交易模型的突破与前景分析
人工智能·金融
QQ395753323729 分钟前
金融量化交易:技术突破与模型优化
人工智能·金融
The_Ticker41 分钟前
CFD平台如何接入实时行情源
java·大数据·数据库·人工智能·算法·区块链·软件工程
Elastic 中国社区官方博客1 小时前
Elasticsearch 开放推理 API 增加了对 IBM watsonx.ai Slate 嵌入模型的支持
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
jwolf21 小时前
摸一下elasticsearch8的AI能力:语义搜索/vector向量搜索案例
人工智能·搜索引擎
有Li1 小时前
跨视角差异-依赖网络用于体积医学图像分割|文献速递-生成式模型与transformer在医学影像中的应用
人工智能·计算机视觉