TensorFlow——Keras 框架

摘要:本文介绍了使用Keras框架在TensorFlow上构建卷积神经网络(CNN)处理MNIST手写数字识别的完整流程。首先加载并预处理数据,包括维度调整、归一化和独热编码;然后构建包含两个卷积层、池化层、Dropout层和全连接层的序贯模型;接着使用交叉熵损失和Adam优化器编译模型;经过10轮训练后,模型在测试集上达到99.1%的准确率。整个过程展示了Keras简化深度学习模型开发的优势,包括直观的API设计、灵活的层配置和高效的训练流程。

目录

[TensorFlow------Keras 框架](#TensorFlow——Keras 框架)

[利用 Keras 构建深度学习模型的八大步骤](#利用 Keras 构建深度学习模型的八大步骤)

步骤一:加载并预处理数据

步骤二:定义模型架构

步骤三:编译模型

步骤四:训练模型

术语备注


TensorFlow------Keras 框架

Keras 是一款轻量易用的高级 Python 库,运行在 TensorFlow 框架之上。该库的设计核心是帮助开发者理解深度学习相关技术,比如为神经网络搭建网络层,同时兼顾维度形态与数学细节的相关概念。

Keras 可搭建的模型框架主要分为以下两种类型:

  • 序贯式 API(Sequential API)
  • 函数式 API(Functional API)

利用 Keras 构建深度学习模型的八大步骤

  1. 加载数据
  2. 对加载的数据进行预处理
  3. 定义模型结构
  4. 编译模型
  5. 训练模型
  6. 评估模型性能
  7. 执行所需的预测任务
  8. 保存模型

本文将使用 Jupyter 笔记本完成代码运行与结果输出,具体操作步骤如下:

步骤一:加载并预处理数据

这是运行深度学习模型的首要步骤,先导入相关库和模块,再完成数据的加载与预处理。

python 复制代码
import warnings
warnings.filterwarnings('ignore')
import numpy as np
np.random.seed(123)  # 固定随机种子,保证实验可复现
from keras.models import Sequential
from keras.layers import Flatten, MaxPool2D, Conv2D, Dense, Reshape, Dropout
from keras.utils import np_utils
# 后端使用TensorFlow
from keras.datasets import mnist

# 加载已打乱的MNIST手写数字数据集,划分为训练集和测试集
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# 重塑训练集数据维度,适配卷积层输入
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
# 重塑测试集数据维度,适配卷积层输入
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)

# 将数据类型转换为32位浮点型
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# 数据归一化,将像素值缩放到0-1区间
X_train /= 255
X_test /= 255

# 将标签进行独热编码,适配多分类任务
Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)
步骤二:定义模型架构

采用序贯式模型搭建卷积神经网络结构:

python 复制代码
model = Sequential()
# 添加卷积层,32个3×3卷积核,激活函数为ReLU,指定输入维度为28×28×1
model.add(Conv2D(32, 3, 3, activation ='relu', input_shape = (28,28,1)))
# 再次添加卷积层,提取更深层特征
model.add(Conv2D(32, 3, 3, activation ='relu'))
# 添加最大池化层,2×2池化窗口,降维并保留关键特征
model.add(MaxPool2D(pool_size = (2,2)))
# 添加Dropout层,随机丢弃25%的神经元,防止过拟合
model.add(Dropout(0.25))
# 展平层,将多维特征映射为一维,连接卷积层与全连接层
model.add(Flatten())
# 全连接层,128个神经元,激活函数为ReLU
model.add(Dense(128,  activation = 'relu'))
# 再次添加Dropout层,随机丢弃50%的神经元,进一步防止过拟合
model.add(Dropout(0.5))
# 输出层,10个神经元,softmax激活函数,输出各分类的概率
model.add(Dense(10,  activation = 'softmax'))
步骤三:编译模型

配置模型的损失函数、优化器和评估指标,为训练做准备:

python 复制代码
# 损失函数选用交叉熵损失,优化器为Adam,评估指标为准确率
model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
步骤四:训练模型

使用训练集数据对模型进行训练,设置训练参数:

python 复制代码
# 批次大小32,训练轮数10,显示训练过程
model.fit(X_train, Y_train, batch_size = 32, epochs = 10, verbose = 1)

训练过程的迭代输出结果如下:

plaintext

python 复制代码
第1轮/共10轮 60000/60000 [==============================] - 65s - 损失值:0.2124 - 准确率:0.9345
第2轮/共10轮 60000/60000 [==============================] - 62s - 损失值:0.0893 - 准确率:0.9740
第3轮/共10轮 60000/60000 [==============================] - 58s - 损失值:0.0665 - 准确率:0.9802
第4轮/共10轮 60000/60000 [==============================] - 62s - 损失值:0.0571 - 准确率:0.9830
第5轮/共10轮 60000/60000 [==============================] - 62s - 损失值:0.0474 - 准确率:0.9855
第6轮/共10轮 60000/60000 [==============================] - 59s - 损失值:0.0416 - 准确率:0.9871
第7轮/共10轮 60000/60000 [==============================] - 61s - 损失值:0.0380 - 准确率:0.9877
第8轮/共10轮 60000/60000 [==============================] - 63s - 损失值:0.0333 - 准确率:0.9895
第9轮/共10轮 60000/60000 [==============================] - 64s - 损失值:0.0325 - 准确率:0.9898
第10轮/共10轮 60000/60000 [==============================] - 60s - 损失值:0.0284 - 准确率:0.9910

术语备注

  1. Sequential API:序贯式 API,是 Keras 中最简单的模型构建方式,适用于层与层之间依次连接的线性模型
  2. Functional API:函数式 API,更灵活的模型构建方式,可搭建多输入、多输出、带残差连接的复杂网络
  3. one-hot encoding:独热编码,将离散型标签转换为二进制向量,避免标签间的数值大小干扰模型训练
  4. Dropout:随机失活,深度学习中常用的正则化方法,通过随机丢弃部分神经元,解决模型过拟合问题
  5. Adam:一种自适应学习率优化器,结合了动量法和 RMSprop 的优点,收敛速度快且稳定性好
  6. softmax:归一化指数函数,将神经网络的输出转换为 0-1 之间的概率值,且所有类别概率之和为 1,适用于多分类任务
相关推荐
小陈Coding2 小时前
AI编程助手如何提升开发效率
人工智能·ai·软件开发·代码生成·编程助手·效率提升·技术文章
小王毕业啦2 小时前
2011-2024年 省、市北京大学数字普惠金融指数(xlsx)
大数据·人工智能·金融·数据挖掘·数据分析·社科数据·经管数据
Bruce_Liuxiaowei2 小时前
面对AI时代,关于“动手能力”的思索
人工智能
说私域2 小时前
流量思维向长效思维转型:开源链动2+1模式AI智能名片小程序赋能私域电商品牌建设
人工智能·小程序·开源·产品运营·私域运营
懒惰的bit2 小时前
Python入门学习记录
python·学习
weixin_446260852 小时前
[特殊字符]提升强化学习效率的开源框架——slime
人工智能
MaoziShan2 小时前
[WACV‘26] 不用给每一帧“打关键点”,也能做出可动画的3D狗:4D-Animal 把成本从“人工标注”转移到“密集线索 + 工具链”
人工智能·3d
米羊1212 小时前
Spring 框架漏洞
开发语言·python
二十雨辰2 小时前
[python]-闭包和装饰器
python