Python从0到100(八十五):神经网络-使用迁移学习完成猫狗分类

前言: 零基础学Python:Python从0到100最新最全教程 想做这件事情很久了,这次我更新了自己所写过的所有博客,汇集成了Python从0到100,共一百节课,帮助大家一个月时间里从零基础到学习Python基础语法、Python爬虫、Web开发、 计算机视觉、机器学习、神经网络以及人工智能相关知识,成为学习学习和学业的先行者!
欢迎大家订阅专栏:零基础学Python:Python从0到100最新最全教程!

今天来学习一下如何使用基于tensorflow和keras 的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:

  1. 实现基于tensorflow和keras的迁移学习

  2. 加载tensorflow提供的数据集(不得使用cifar10)

  3. 需要使用markdown单元格对数据集进行说明

  4. 加载tensorflow提供的预训练模型(不得使用vgg16)

  5. 需要使用markdown单元格对原始模型进行说明

  6. 网络末端连接任意结构的输出端网络

  7. 用图表显示准确率和损失函数

  8. 用cnn工具可视化一批数据的预测结果

  9. 用cnn工具可视化一个数据样本的各层输出

一、加载数据集

1.调用库函数

python 复制代码
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout

2.加载数据集

数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。

python 复制代码
path_to_zip = tf.keras.utils.get_file(
    'data.zip',
    origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',
    extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

3.数据集管理

使用image_dataset_from_director 进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。

python 复制代码
train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)

validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)

二、猫狗数据集介绍

1.猫狗数据集介绍:

猫狗数据集包括25000张训练图片,12500张测试图片 ,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 "train" ,一个是 "test" ,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。

在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。

2.图片展示

下面是数据集中的图片展示:

python 复制代码
class_names = ['cats', 'dogs']

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

🌟🌟🌟 这里是输出的结果:✨✨✨

三、MobileNetV2网络介绍

1.加载tensorflow提供的预训练模型

python 复制代码
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

2.轻量级网络------MobileNetV2

使用轻量级网络------MobileNetV2 进行数据预处理 说明: MobileNetV2是基于倒置的残差结构 ,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩

3.MobileNetV2的网络模块

MobileNetV2的网络模块样子是这样的:

MobileNetV2是基于深度级可分离卷积构建的网络 ,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道 ,所以说深度卷积是depth级别 的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1 的卷积核。

MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接 (如下图左侧);stride=2时,无short cut连接输入和输出特征 (下图右侧):

四、搭建迁移学习

1.训练

python 复制代码
inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
python 复制代码
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
python 复制代码
base_model.trainable = False
base_model.summary()

🌟🌟🌟 这里是输出的结果:✨✨✨

2.训练结果可视化

用图表显示准确率和损失函数

python 复制代码
# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")
 
plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

🌟🌟🌟 这里是输出的结果:✨✨✨

3.输出训练的准确率

python 复制代码
# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))

🌟🌟🌟 这里是输出的结果:✨✨✨

4.用cnn工具可视化一批数据的预测结果

python 复制代码
label_dict = {
    0: 'cat',
    1: 'dog'
}

test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')

cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)

🌟🌟🌟 这里是输出的结果:✨✨✨

5.数据输出

python 复制代码
# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()

img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]

cnn_utils.get_activations(base_model, image_activation[0])

🌟🌟🌟 这里是输出的结果:✨✨✨

6.用cnn工具可视化一个数据样本的各层输出

python 复制代码
cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))

🌟🌟🌟 这里是输出的结果:✨✨✨

7.输出结果图像

🌟🌟🌟 这里是输出的结果:✨✨✨

文末送书

本期推荐1:

《Java面向对象程序设计:AI大模型给程序员插上翅膀》

AI工具助力Java编程:故事引领思政,AI助力学习;任务驱动实践,项目提升能力。

京东:https://item.jd.com/14850722.html

从AI助力角度出发,轻松学习编程

故事引入思政,引发读者动手实践

引出目标任务,明确学习目的和方向

AI学习问答与同步训练,提升学习效率

丰富的学习资源,助力实际项目开发
内容简介

随着云计算、物联网、大数据、人工智能等新一代信息技术的发展,Java 作为一种高性能、跨平台的编程语言,有着广泛的应用。本书从应用的角度详尽介绍了Java开发的核心技术。

全书分为12章,主要介绍了Java开发环境、Java编程基础、类和对象、继承和多态、抽象类和接口、Java常用类、内部类和泛型、集合容器、JDBC编程、图形用户界面设计、多线程,最后通过企业项目管理的方式进行实践,实现一个完整案例。

本书每章都通过故事的方式引入思政,并且从故事中引出目标任务。针对目标任务,辅以人工智能工具(ChatGPT、文心一言、讯飞星火)的帮助,得到行之有效的示例。之后对其进行知识解析,并完成上机练习。通过相关的练习巩固知识,并在合适的阶段引入一些常见的算法,加强学生的逻辑思维能力。在每章末尾有AI学习问答,让读者自行探索,同时加入同步训练,加强学习效果。

本期推荐2:

《Python金融大数据分析》

掌握Python,从零到一速成金融分析高手!实战案例深剖,让数字说话,让决策更精准!深入了解金融数据分析的具体过程和方法,提高实操能力。附赠书中案例源代码。

京东:https://item.jd.com/14827368.html

系统:全面构建Python金融大数据分析框架,从零到一,系统掌握核心技能,让学习之路有条不紊。

经典:凝聚笔者多年智慧,解读大数据在金融领域的应用,确保学习内容前沿且可靠。

深入:深度剖析Python在金融大数据分析中的关键技术,直击核心难点,助您深入理解数据背后的价值。

案例:精选实战案例,让您在真实场景中磨炼技能,实现从理论到实践的完美跨越。
内容简介

本书共分为11 章,全面介绍了以Python为工具的金融大数据的理论和实践,特别是量化投资和交易领域的相关应用,并配有项目实战案例。书中涵盖的内容主要有Python概览,结合金融场景演示Python的基本操作,金融数据的获取及实战,MySQL数据库详解及应用,Python在金融大数据分析方面的核心模块详解,金融分析及量化投资,Python量化交易,数据可视化Matplotlib,基于NumPy的股价统计分析实战,基于Matplotlib的股票技术分析实战,以及量化交易策略实战案例等。

本书内容通俗易懂,案例丰富,实用性强,特别适合以下人群阅读:金融行业的从业者、数据分析师、量化投资者、希望提高数据分析能力的投资者,以及对大数据分析感兴趣的编程人员。另外,本书也适合作为相关培训机构的教材。

相关推荐
请为小H留灯1 小时前
Python中很常用的100个函数整理
开发语言·python
七月初七772 小时前
Excel多级联动下拉菜单设置
python·excel·pandas
Serendipity_Carl2 小时前
Pandas数据清洗实战之清洗猫眼电影
python·pycharm·数据分析·pandas
.昕..2 小时前
(二)seacmsv9注入管理员账号密码+orderby+limit
python·网络安全
HerrFu3 小时前
可狱可囚的爬虫系列课程 17:lxml模块的使用
爬虫·python
码叔义3 小时前
X509TrustManager信任SSL证书
python·网络协议·ssl
阿波拉4 小时前
AttributeError: module ‘backend_interagg‘ has no attribute ‘FigureCanvas’问题解决
开发语言·python
m0_748247804 小时前
Python连接SQL SEVER数据库全流程
数据库·python·sql
BigBookX5 小时前
使用OpenCV来获取视频的帧率
python·opencv
蹦蹦跳跳真可爱5895 小时前
Python----计算机视觉处理(opencv:像素,RGB颜色,图像的存储,opencv安装,代码展示)
人工智能·python·opencv·计算机视觉