Tensorflow2学习之MNIST数据训练和测试

前言

打算学习下Tensorflow2,当前的目标是自动识别出身份证中号码的位置,环境很久前本地搭建过,进度就是随机性了有时间了搞下。当然很久没碰就理所当然的忘了,还有个主要原因是当时找的demo没达到我想要的效果。目前还是发挥我得CV神功学的是API因为高数我还给了老师了,还有关于python和TF的API现学现卖了,这篇算入门记录了。

MNIST介绍

MNIST是一个由60000张训练样本和10000张测试样本组成的处理过的图像数据集。注意它不是图片,描述的是大小为28*28二值化的图片,二值化的意思是只有黑色和白色的单通道图片,关于颜色通道。

下载到本地很简单但是但是需要网。

ini 复制代码
# 加载mnist数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

然后文件会下载到这个地方。

网上找些demo训练结果注意事项

训练集训练过程看到的效果很好,但是想自己画个图加载模型后识别下基本识别不出来。网上找了一些说法。

  1. 自己画的图片大小控制不好在缩放的时候会失真。
  2. 灰度和二值化没区分控制好,训练集的数据都是二值化的。
  3. 国外写字跟国内习惯不一样。
  4. 还有提到泛化问题的我还没懂后面学会了再分享。

说下效果吧

训练

我是看的这本书《TensorFlow深度学习(带目录).pdf》。里面有最原始的API处理数据的,但是怎么保存模型我没搞懂放弃了。然后网上找了一些高阶的keras处理的说下理解。

ini 复制代码
# 加载mnist数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化主要把数据放到0-1之间 因为二值化的图片像素点是0,255
x_train, x_test = x_train / 255.0, x_test / 255.0

# 将标签转换成独热编码 注意这里跟下面模型配置的损失函数有关系
# 独热是结果只有一位表示 1000000000 表示0
#y_train_onehot = tf.keras.utils.to_categorical(y_train)
#y_test_onehot = tf.keras.utils.to_categorical(y_test)
y_train_onehot =y_train
y_test_onehot =y_test

model = tf.keras.Sequential()
# 28*28 输入的大小
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))

# 中间隐藏层激活函数用relu
model.add(tf.keras.layers.Dense(128, activation='relu'))

#随机丢弃一些神经元
model.add(tf.keras.layers.Dropout(0.2))

# 多分类输出一般用softmax分类器
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# 打印模型信息的
model.summary()

# optimizer优化器
# 方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
# loss函数使用交叉熵
# 顺序编码用sparse_categorical_crossentropy
# 独热编码用categorical_crossentropy
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#开始训练
history = model.fit(x_train, y_train_onehot, epochs=20)
test_loss,test_accuracy=model.evaluate(x_test,  y_test_onehot, verbose=2)
print('Accuracy on test_dataset',test_accuracy)

#保存模型
model.save('model.h5')
print('saved total model.')

#主要打印训练后显示训练图取参数对应的
history_dict = history.history
print(history_dict.keys())

模型加载

ini 复制代码
network = tf.keras.models.load_model('model.h5')
# 调用模型进行预测识别
im = Image.open(r"9.png")  # 读取图片路径
im = im.resize((28, 28))  # 调整大小和模型输入大小一致
im = np.array(im)
print(im.shape)
p3 = im.min(axis=-1)
print(p3.shape)
# # 将白底黑字变成黑底白字   由于训练模型是这种格式
for i in range(28):
    for j in range(28):
        p3[i][j] = 255 - p3[i][j]
# 模型输出结果是每个类别的概率,取最大的概率的类别就是预测的结果
ret = network.predict((p3 / 255.0).reshape((1, 28, 28)))
print(ret)
number = np.argmax(ret, axis=1)
print(number)

自己画图识别成功的一个

相关推荐
bastgia34 分钟前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
菜狗woc42 分钟前
opencv-python的简单练习
人工智能·python·opencv
15年网络推广青哥1 小时前
国际抖音TikTok矩阵运营的关键要素有哪些?
大数据·人工智能·矩阵
isolusion1 小时前
Springboot的创建方式
java·spring boot·后端
weixin_387545641 小时前
探索 AnythingLLM:借助开源 AI 打造私有化智能知识库
人工智能
zjw_rp1 小时前
Spring-AOP
java·后端·spring·spring-aop
TodoCoder2 小时前
【编程思想】CopyOnWrite是如何解决高并发场景中的读写瓶颈?
java·后端·面试
engchina2 小时前
如何在 Python 中忽略烦人的警告?
开发语言·人工智能·python
paixiaoxin2 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
凌虚2 小时前
Kubernetes APF(API 优先级和公平调度)简介
后端·程序员·kubernetes