【深度学习_TensorFlow】调用keras高层API重写手写数字识别项目

写在前面

上一阶段我们完成了手写数字识别项目的构建,了解了网络构建、训练、测试的基本流程,但是对于一些常见的操作,因其使用过于频繁,实际上并无必要手动实现,而早已被封装为函数了。

这篇文章我们将了解keras高层API,将手写数字识别项目用高层API重写一遍。


写在中间

学习之前,来探讨之前一直有的疑问,运行时提示没有相应模块,原因肯定是引入包的原因,如果不懂kerastf.keras的区别,还是很有必要看看这篇文章

  • 其实 keras 可以理解为一套搭建与训练神经网络的高层 API 协议,Keras 本身已经实现了此协议,安装标准的 Keras 库就可以方便地调用TensorFlow、CNTK 等后端完成加速计算;在 TensorFlow 中,也实现了一套 keras 协议,即 tf.keras,它与 TensorFlow 深度融合,且只能基于 TensorFlow 后端运算,并对TensorFlow 的支持更完美。对于使用 TensorFlow 的开发者来说,tf.keras 可以理解为一个普通的子模块,与其他子模块,如 tf.math,tf.data 等并没有什么差别。

但是为了方便我们操作,为避免混淆,我们就选择tf.keras来完成代码中的相关操作。

注意:tensorflow版本和keras版本一定要相兼容,不兼容的话,引入tensorflow.keras就会报错。


1. 引包

Python 复制代码
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers, models, losses
# pycharm中会出现红色波浪线,但不影响运行

2. 数据集的读取与处理

这一步就老生常谈了,直接复制粘贴过来

Python 复制代码
def preprocess(x, y):

    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [-1, 28*28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


(x, y), (x_test, y_test) = datasets.mnist.load_data()


# 数据集的处理,由于返回的数据集是numpy类型的,若要使用GPU加速,需转换为张量
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000).batch(128).map(preprocess).repeat(5)
# 对测试集的简单处理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000).batch(128).map(preprocess)

3. 网络层的构建

封装创建


对于常见的网络,原来需要手动调用每一层的类实例完成前向传播运算,当网络层数变得较深时,这一部分代码显得非常臃肿。可以通过 tf.keras 提供的网络容器 Sequential 将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序传播运算。

在完成网络创建时,网络层类并没有创建内部权值张量等成员变量,此时通过调用类的 build 方法并指定输入大小,即可自动创建所有层的内部张量。通过 summary()函数可以方便打印出网络结构和参数量

Python 复制代码
# 创建网络
network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))  # None代表batch不定
network.summary()

就如我们会打印出以下信息:

  • Layer (type):层名称、层类型

  • Output Shape:输出形状

  • Param #:层的参数个数

  • Trainable params、Non-trainable params:可优化的参数、不可优化的参数

Plain 复制代码
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 256)               200960    
                                                                 
 dense_1 (Dense)             (None, 128)               32896     
                                                                 
 dense_2 (Dense)             (None, 64)                8256      
                                                                 
 dense_3 (Dense)             (None, 32)                2080      
                                                                 
 dense_4 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________

4. 模型装配、训练、测试

在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过 Keras 提供的模型装配与训练等高层接口实现,简洁清晰。

tf.keras 中提供了 compile()fit()函数方便实现上述逻辑。首先通过compile 函数指定网络使用的优化器对象、损失函数类型,评价指标等设定,这一步称为装配。

Python 复制代码
#  模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

模型装配完成后,即可通过 fit()函数送入待训练的数据集和验证用的数据集,这一步称为模型训练。

Python 复制代码
# 指定训练集和测试集,训练5个epochs,每2个epoch验证一次
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=2)

如果只是简单的测试模型的性能,可以通过 Model.evaluate(test_db)循环测试完 test_db数据集上所有样本,并打印出性能指标

Python 复制代码
network.evaluate(test_db)

5. 模型保存与加载

在训练时间隔性地保存模型状态也是非常好的习惯,这一点对于训练大规模的网络尤其重要。一般大规模的网络需要训练数天乃至数周的时长,一旦训练过程被中断或者发生宕机等意外,之前训练的进度将全部丢失。如果能够间断地保存模型状态到文件系统,即使发生宕机等意外,也可以从最近一次的网络状态文件中恢复,从而避免浪费大量的训练时间和计算资源。因此模型的保存与加载非常重要。

仅保存网络参数


这种保存与加载网络的方式最为轻量级,文件中保存的仅仅是张量参数的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能够正确恢复网络状态,因此一般在拥有网络源文件的情况下使用。

Python 复制代码
print('模型参数自动保存...')
network.save_weights('weights.ckpt')

print('模拟意外情况,网络删除...')
del network

print('重新加载模型的参数...')

# 重新创建相同的网络结构 
network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])

network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')

保存模型及参数


这是一种不需要网络源文件,仅仅需要模型参数文件即可恢复出网络模型的方法。通过 Model.save(path)函数可以将模型的结构以及模型的参数保存到path文件上,在不需要网络源文件的条件下,通过tf.keras.models.load_model(path)即可恢复网络结构和网络参数。

Python 复制代码
network.save('model.h5')
print('模型已自动保存...')

print('模拟意外情况,网络删除...')
del network

print('重新加载模型中...')
network = tf.keras.models.load_model('model.h5', compile=False)

学到这里就基本将手写数字识别的重点变化列举了出来,下面我们就去摩拳擦掌地试试吧!

写在最后

👍🏻点赞,你的认可是我创作的动力!

⭐收藏,你的青睐是我努力的方向!

✏️评论,你的意见是我进步的财富!

相关推荐
YRr YRr13 分钟前
深度学习:Transformer Decoder详解
人工智能·深度学习·transformer
Shy96041822 分钟前
Bert完形填空
python·深度学习·bert
老艾的AI世界40 分钟前
新一代AI换脸更自然,DeepLiveCam下载介绍(可直播)
图像处理·人工智能·深度学习·神经网络·目标检测·机器学习·ai换脸·视频换脸·直播换脸·图片换脸
浊酒南街1 小时前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)4.9-4.10
人工智能·深度学习·神经网络·cnn
懒惰才能让科技进步2 小时前
从零学习大模型(十二)-----基于梯度的重要性剪枝(Gradient-based Pruning)
人工智能·深度学习·学习·算法·chatgpt·transformer·剪枝
没有不重的名么2 小时前
门控循环单元GRU
人工智能·深度学习·gru
love_and_hope2 小时前
Pytorch学习--神经网络--搭建小实战(手撕CIFAR 10 model structure)和 Sequential 的使用
人工智能·pytorch·python·深度学习·学习
学术头条2 小时前
AI 的「phone use」竟是这样练成的,清华、智谱团队发布 AutoGLM 技术报告
人工智能·科技·深度学习·语言模型
孙同学要努力3 小时前
《深度学习》——深度学习基础知识(全连接神经网络)
人工智能·深度学习·神经网络
喵~来学编程啦3 小时前
【论文精读】LPT: Long-tailed prompt tuning for image classification
人工智能·深度学习·机器学习·计算机视觉·论文笔记