TensorFlow2实战-系列教程3:猫狗识别1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、项目介绍

基本流程:

  • 数据预处理:图像数据处理,准备训练和验证数据集
  • 卷积网络模型:构建网络架构
  • 过拟合问题:观察训练和验证效果,针对过拟合问题提出解决方法
  • 数据增强:图像数据增强方法与效果
  • 迁移学习:深度学习必备训练策略

在我们的数据中,有训练和验证,训练集中分别有猫狗两个类别,都有1000张图像,验证集则有500张

2、数据读取

python 复制代码
import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
  1. 导包
  2. 指定数据路径
  3. 训练数据路径
  4. 验证数据路径
  5. 训练数据猫类别路径
  6. 训练数据狗类别路径
  7. 验证数据猫类别路径
  8. 训练数据狗类别路径

3、构建卷积神经网络

python 复制代码
model = tf.keras.models.Sequential([
    #如果训练慢,可以把数据设置的更小一些
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),

    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),

    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    
    #为全连接层准备
    tf.keras.layers.Flatten(),
    
    tf.keras.layers.Dense(512, activation='relu'),
    # 二分类sigmoid就够了
    tf.keras.layers.Dense(1, activation='sigmoid')
])

3个3x3卷积,穿插3个2x2池化,拉平操作,两个全连接层

python 复制代码
model.summary()

打印一下模型架构:

配置训练器:

python 复制代码
model.compile(loss='binary_crossentropy', optimizer=Adam(lr=1e-4), metrics=['acc'])

4、数据预处理

  • 读进来的数据会被自动转换成tensor(float32)格式,分别准备训练和验证
  • 图像数据归一化(0-1)区间
python 复制代码
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        train_dir,  # 文件夹路径
        target_size=(64, 64),  # 指定resize成的大小
        batch_size=20,
        # 如果one-hot就是categorical,二分类用binary就可以
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(64, 64),
        batch_size=20,
        class_mode='binary')

打印结果:

Found 2000 images belonging to 2 classes.

Found 1000 images belonging to 2 classes.

5、模型训练

  • 直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
  • steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step
python 复制代码
history = model.fit_generator(
      train_generator,
      steps_per_epoch=100,  # 2000 images = batch_size * steps
      epochs=20,
      validation_data=validation_generator,
      validation_steps=50,  # 1000 images = batch_size * steps
      verbose=2)

部分打印结果:

Epoch 1/20 100/100 - 9s - loss: 0.6909 - acc: 0.5240 - val_loss: 0.6952 - val_acc: 0.5000

Epoch 2/20 100/100 - 9s - loss: 0.6645 - acc: 0.5960 - val_loss: 0.6906 - val_acc: 0.5360

...

Epoch 19/20 100/100 - 9s - loss: 0.1750 - acc: 0.9460 - val_loss: 0.6277 - val_acc: 0.7390

Epoch 20/20 100/100 - 9s - loss: 0.1593 - acc: 0.9505 - val_loss: 0.5901 - val_acc: 0.7490

6、预测效果展示

python 复制代码
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()


将训练损失、准确率和对应的epoch分别画图展示

相关推荐
小王爱学人工智能4 小时前
OpenCV的阈值处理
人工智能·opencv·计算机视觉
新智元5 小时前
刚刚,光刻机巨头 ASML 杀入 AI!豪掷 15 亿押注「欧版 OpenAI」,成最大股东
人工智能·openai
机器之心5 小时前
全球图生视频榜单第一,爱诗科技PixVerse V5如何改变一亿用户的视频创作
人工智能·openai
新智元5 小时前
2025年了,AI还看不懂时钟!90%人都能答对,顶尖AI全军覆没
人工智能·openai
湫兮之风5 小时前
OpenCV: Mat存储方式全解析-单通道、多通道内存布局详解
人工智能·opencv·计算机视觉
机器之心5 小时前
Claude不让我们用!国产平替能顶上吗?
人工智能·openai
程序员柳5 小时前
基于YOLOv8的车辆轨迹识别与目标检测研究分析软件源代码+详细文档
人工智能·yolo·目标检测
算家计算5 小时前
一站式高质量数字人动画框架——EchoMimic-V3本地部署教程: 13 亿参数实现统一多模态、多任务人体动画生成
人工智能·开源
API流转日记5 小时前
Gemini-2.5-Flash-Image-Preview 与 GPT-4o 图像生成能力技术差异解析
人工智能·gpt·ai·chatgpt·ai作画·googlecloud
martinzh5 小时前
切块、清洗、烹饪:RAG知识库构建的三步曲
人工智能