实现mnist手写数字识别

基础知识

tensorflow

TensorFlow是一个开源的机器学习框架,致力于各种数据流图的自动微分和深度神经网络的计算。简而言之,** TensorFlow帮助我们轻松地构建、训练和部署机器学习模型** 。它可以在各种平台上运行,包括桌面计算机、服务器、移动设备和嵌入式设备。

在conda中安装tensorflow

java 复制代码
conda install tensorflow

训练集和测试集

![](https://img-blog.csdnimg.cn/img_convert/11da4e88fdac0a5e0a1f1adfa852d2e7.png)

![](https://img-blog.csdnimg.cn/img_convert/98319b3f90031d4cda444faf34538ca6.png)

```java

![](https://cdn.nlark.com/yuque/0/2024/png/38629240/1728196864148-ae60ac21-fe9c-46b6-b91a-02432c6a2045.png)

<h3 id="nIzm3">models</h3>
```java
models:这是 Keras 中用于构建和管理模型的模块。它提供了两种主要的模型构建方式:
Sequential 模型:按层顺序构建的模型。
Model API:允许创建更复杂的、具备灵活拓扑结构的模型。

models的一些方法包括:

java 复制代码
Sequential:这是 Keras 中最简单的模型类型,它允许按层顺序构建神经网络。可以使用 add() 方法添加不同类型的层。
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units=64, activation='relu', input_shape=(input_dim,)))
model.add(tf.keras.layers.Dense(units=10, activation='softmax'))
java 复制代码
Model:功能性 API 允许用户构建更复杂的模型,包括多个输入和输出的模型。通过定义输入层和输出层,可以创建多分支网络。
inputs = tf.keras.Input(shape=(input_dim,))
x = tf.keras.layers.Dense(units=64, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(units=10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

                                                                
compile():在训练模型之前,需要调用 compile() 方法配置模型的学习过程,包括指定损失函数、优化器和评价指标。
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
java 复制代码
fit():用于训练模型,通过将训练数据输入模型进行学习。
history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
java 复制代码
模型的评估
evaluate():在训练完成后,使用该方法在测试数据上评估模型的性能,返回损失值和其他指标。
test_loss, test_accuracy = model.evaluate(test_images, test_labels)
java 复制代码
模型的预测
predictions = model.predict(new_data)
java 复制代码
save():将整个模型(结构和权重)保存到磁盘,通常以 .h5 格式。

model.save('my_model.h5')

load_model():从磁盘加载保存的模型。

from tensorflow.keras.models import load_model
model = load_model('my_model.h5')
java 复制代码
summary():输出模型的概要信息,包括各层的名称、输出形状和参数数量。

model.summary()

调参

java 复制代码
trainable:用于设置某一层的可训练性,以便进行微调或迁移学习。
model.layers[0].trainable = False  # 冻结第一层

CNN

CNN是卷积神经网络

卷积神经网络 (CNN) 基本原理和公式_cnn公式-CSDN博客

机器学习算法之------卷积神经网络(CNN)原理讲解 - 知乎 (zhihu.com)

代码

导入数据集

```java import tensorflow as tf from tensorflow.keras import datasets, layers, models import matplotlib.pyplot as plt

导入mnist数据,依次分别为训练集图片、训练集标签、测试集图片、测试集标签

(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

解析:

import tensorflow as tf 导入tensorfolw库,重命名为tf

from ... import ...:这种语法用于从某个库或模块中导入特定的部分,以减少命名冲突或不必要的全局命名。这里表示从 tensorflow.keras 模块中导入 datasets、layers 和 models 三个子模块。

tensorflow.keras:Keras 是 TensorFlow 中的一个高级API,用于快速构建和训练神经网络模型。Keras 原本是一个独立的深度学习库,但后来被集成到 TensorFlow 中,成为其一部分存储

datasets:这是 Keras 提供的一个模块,包含了一些常用的数据集加载函数。例如,mnist.load_data() 就是来自这个模块,用于加载 MNIST 手写数字数据集。

layers:这是 Keras 中的核心模块之一,包含了用于构建神经网络模型的各种层(layer)。每一层是神经网络的基本组成部分,比如全连接层(Dense)、卷积层(Conv2D)、池化层(MaxPooling2D)等。你可以用 layers.Dense() 来创建一个全连接层。

models:这是 Keras 中用于构建和管理模型的模块。它提供了两种主要的模型构建方式:

Sequential 模型:按层顺序构建的模型。

Model API:允许创建更复杂的、具备灵活拓扑结构的模型。

<h3 id="kqum6">处理数据集</h3>
```java
# 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。)
train_images, test_images = train_images / 255.0, test_images / 255.0
# 查看数据维数信息
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

机器学习和深度学习中,对输入数据进行归一化是一个常见且重要的预处理步骤。归一化将数据缩放到一个标准范围(通常是 0 到 1 之间)

对于灰度图像,每个像素的取值范围通常是 0 到 255(对于 8 位图像)。当我们将像素值除以 255 时,像素值将被缩放到 0 到 1 的范围内,这样做的好处包括:

易于理解的数值范围:0 到 1 的范围更直观且容易处理,尤其在使用某些激活函数时(如 sigmoid),它们在 0 到 1 范围内工作良好。

与模型的初始化和学习相匹配:归一化后的数据与模型权重的初始化范围相匹配,尤其是在使用带有随机初始化的深度学习模型时。

展示数据集

```java # 将数据集前20个图片数据可视化显示 # 进行图像大小为20宽、10长的绘图(单位为英寸inch) plt.figure(figsize=(20,10)) # 遍历MNIST数据集下标数值0~49 for i in range(20): # 将整个figure分成5行10列,绘制第i+1个子图。 plt.subplot(2,10,i+1) # 设置不显示x轴刻度 plt.xticks([]) # 设置不显示y轴刻度 plt.yticks([]) # 设置不显示子图网格线 plt.grid(False) # 图像展示,cmap为颜色图谱,"plt.cm.binary"为matplotlib.cm中的色表 plt.imshow(train_images[i], cmap=plt.cm.binary) # 设置x轴标签显示为图片对应的数字 plt.xlabel(train_labels[i]) # 显示图片 plt.show() ```

创建卷积神经网络

```java # 创建并设置卷积神经网络 # 卷积层:通过卷积操作对输入图像进行降维和特征抽取 # 池化层:是一种非线性形式的下采样。主要用于特征降维,压缩数据和参数的数量,减小过拟合,同时提高模型的鲁棒性。 # 全连接层:在经过几个卷积和池化层之后,神经网络中的高级推理通过全连接层来完成。 model = models.Sequential([ # 设置二维卷积层1,设置32个3*3卷积核,activation参数将激活函数设置为ReLu函数,input_shape参数将图层的输入形状设置为(28, 28, 1) # ReLu函数作为激活励函数可以增强判定函数和整个神经网络的非线性特性,而本身并不会改变卷积层 # 相比其它函数来说,ReLU函数更受青睐,这是因为它可以将神经网络的训练速度提升数倍,而并不会对模型的泛化准确度造成显著影响。 layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), #池化层1,2*2采样 layers.MaxPooling2D((2, 2)), # 设置二维卷积层2,设置64个3*3卷积核,activation参数将激活函数设置为ReLu函数 layers.Conv2D(64, (3, 3), activation='relu'), #池化层2,2*2采样 layers.MaxPooling2D((2, 2)),

layers.Flatten(),                    #Flatten层,连接卷积层与全连接层
layers.Dense(64, activation='relu'), #全连接层,特征进一步提取,64为输出空间的维数,activation参数将激活函数设置为ReLu函数
layers.Dense(10)                     #输出层,输出预期结果,10为输出空间的维数

])

打印网络结构

model.summary()

<h3 id="DDhAH">编译模型</h3>
```java
"""
这里设置优化器、损失函数以及metrics
"""
# model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
model.compile(
	# 设置优化器为Adam优化器
    optimizer='adam',
	# 设置损失函数为交叉熵损失函数(tf.keras.losses.SparseCategoricalCrossentropy())
    # from_logits为True时,会将y_pred转化为概率(用softmax),否则不进行转换,通常情况下用True结果更稳定
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 设置性能指标列表,将在模型训练时监控列表中的指标
    metrics=['accuracy'])

训练模型

```java history = model.fit( # 输入训练集图片 train_images, # 输入训练集标签 train_labels, # 设置10个epoch,每一个epoch都将会把所有的数据输入模型完成一次训练。 epochs=10, # 设置验证集 validation_data=(test_images, test_labels)) ```

预测

```java pre = model.predict(test_images) # 对所有测试图片进行预测 pre[1] # 输出第一张图片的预测结果 ```

识别出来应该是对应的1。

之前不懂训练的模型怎么就保存下来了的,刚刚已经保存好的模型在另一个文件中调用。

通过 load_model 加载已经训练好的模型,已经训练好的模型是h5保存在文件中的。

java 复制代码
from tensorflow.keras.models import load_model
model = load_model('mnist.h5')

参考

pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练_经典数据集-手写数字识别pytorch-CSDN博客

深度学习--TensorFlow(项目)识别自己的手写数字(基于CNN卷积神经网络)_cnn卷积神经网络基于te数据集-CSDN博客

相关推荐
陈苏同学几秒前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
FL162386312929 分钟前
[深度学习][python]yolov11+bytetrack+pyqt5实现目标追踪
深度学习·qt·yolo
羊小猪~~36 分钟前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
龙的爹23331 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
工业机器视觉设计和实现1 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
醒了就刷牙1 小时前
58 深层循环神经网络_by《李沐:动手学深度学习v2》pytorch版
pytorch·rnn·深度学习
985小水博一枚呀1 小时前
【对于Python爬虫的理解】数据挖掘、信息聚合、价格监控、新闻爬取等,附代码。
爬虫·python·深度学习·数据挖掘
萱仔学习自我记录2 小时前
微调大语言模型——超详细步骤
人工智能·深度学习·机器学习
Eric.Lee20213 小时前
音频文件重采样 - python 实现
人工智能·python·深度学习·算法·audio·音频重采样