Tensorflow2学习之MNIST数据训练和测试

前言

打算学习下Tensorflow2,当前的目标是自动识别出身份证中号码的位置,环境很久前本地搭建过,进度就是随机性了有时间了搞下。当然很久没碰就理所当然的忘了,还有个主要原因是当时找的demo没达到我想要的效果。目前还是发挥我得CV神功学的是API因为高数我还给了老师了,还有关于python和TF的API现学现卖了,这篇算入门记录了。

MNIST介绍

MNIST是一个由60000张训练样本和10000张测试样本组成的处理过的图像数据集。注意它不是图片,描述的是大小为28*28二值化的图片,二值化的意思是只有黑色和白色的单通道图片,关于颜色通道。

下载到本地很简单但是但是需要网。

ini 复制代码
# 加载mnist数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

然后文件会下载到这个地方。

网上找些demo训练结果注意事项

训练集训练过程看到的效果很好,但是想自己画个图加载模型后识别下基本识别不出来。网上找了一些说法。

  1. 自己画的图片大小控制不好在缩放的时候会失真。
  2. 灰度和二值化没区分控制好,训练集的数据都是二值化的。
  3. 国外写字跟国内习惯不一样。
  4. 还有提到泛化问题的我还没懂后面学会了再分享。

说下效果吧

训练

我是看的这本书《TensorFlow深度学习(带目录).pdf》。里面有最原始的API处理数据的,但是怎么保存模型我没搞懂放弃了。然后网上找了一些高阶的keras处理的说下理解。

ini 复制代码
# 加载mnist数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化主要把数据放到0-1之间 因为二值化的图片像素点是0,255
x_train, x_test = x_train / 255.0, x_test / 255.0

# 将标签转换成独热编码 注意这里跟下面模型配置的损失函数有关系
# 独热是结果只有一位表示 1000000000 表示0
#y_train_onehot = tf.keras.utils.to_categorical(y_train)
#y_test_onehot = tf.keras.utils.to_categorical(y_test)
y_train_onehot =y_train
y_test_onehot =y_test

model = tf.keras.Sequential()
# 28*28 输入的大小
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))

# 中间隐藏层激活函数用relu
model.add(tf.keras.layers.Dense(128, activation='relu'))

#随机丢弃一些神经元
model.add(tf.keras.layers.Dropout(0.2))

# 多分类输出一般用softmax分类器
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# 打印模型信息的
model.summary()

# optimizer优化器
# 方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
# loss函数使用交叉熵
# 顺序编码用sparse_categorical_crossentropy
# 独热编码用categorical_crossentropy
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#开始训练
history = model.fit(x_train, y_train_onehot, epochs=20)
test_loss,test_accuracy=model.evaluate(x_test,  y_test_onehot, verbose=2)
print('Accuracy on test_dataset',test_accuracy)

#保存模型
model.save('model.h5')
print('saved total model.')

#主要打印训练后显示训练图取参数对应的
history_dict = history.history
print(history_dict.keys())

模型加载

ini 复制代码
network = tf.keras.models.load_model('model.h5')
# 调用模型进行预测识别
im = Image.open(r"9.png")  # 读取图片路径
im = im.resize((28, 28))  # 调整大小和模型输入大小一致
im = np.array(im)
print(im.shape)
p3 = im.min(axis=-1)
print(p3.shape)
# # 将白底黑字变成黑底白字   由于训练模型是这种格式
for i in range(28):
    for j in range(28):
        p3[i][j] = 255 - p3[i][j]
# 模型输出结果是每个类别的概率,取最大的概率的类别就是预测的结果
ret = network.predict((p3 / 255.0).reshape((1, 28, 28)))
print(ret)
number = np.argmax(ret, axis=1)
print(number)

自己画图识别成功的一个

相关推荐
CareyWYR1 分钟前
每周AI论文速递(250407-250411)
人工智能
顾琬清14 分钟前
Linux系统Docker部署开源在线协作笔记Trilium Notes与远程访问详细教程
开发语言·后端·golang
李白的粉14 分钟前
基于springboot的个人博客系统
java·spring boot·后端·毕业设计·课程设计·源代码·个人博客系统
Charlie__ZS16 分钟前
Spring其它知识点
java·后端·spring
Aska_Lv21 分钟前
mysql---主从延时问题
后端
机器之心1 小时前
魔改AlphaZero后,《我的世界》AI老玩家问世,干活不用下指令
人工智能
Romantic Rose1 小时前
你所拨打的电话是空号?手机状态查询API
大数据·人工智能
细心的莽夫1 小时前
Docker学习笔记
运维·笔记·后端·学习·docker·容器
羊小猪~~1 小时前
深度学习基础--CNN经典网络之InceptionV1研究与复现(pytorch)
网络·人工智能·pytorch·深度学习·神经网络·机器学习·cnn
省长1 小时前
Sa-Token v1.42.0 发布 🚀,新增 API Key、TOTP 验证码、RefreshToken 反查等能力
java·后端·开源