Tensorflow2学习之MNIST数据训练和测试

前言

打算学习下Tensorflow2,当前的目标是自动识别出身份证中号码的位置,环境很久前本地搭建过,进度就是随机性了有时间了搞下。当然很久没碰就理所当然的忘了,还有个主要原因是当时找的demo没达到我想要的效果。目前还是发挥我得CV神功学的是API因为高数我还给了老师了,还有关于python和TF的API现学现卖了,这篇算入门记录了。

MNIST介绍

MNIST是一个由60000张训练样本和10000张测试样本组成的处理过的图像数据集。注意它不是图片,描述的是大小为28*28二值化的图片,二值化的意思是只有黑色和白色的单通道图片,关于颜色通道。

下载到本地很简单但是但是需要网。

ini 复制代码
# 加载mnist数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

然后文件会下载到这个地方。

网上找些demo训练结果注意事项

训练集训练过程看到的效果很好,但是想自己画个图加载模型后识别下基本识别不出来。网上找了一些说法。

  1. 自己画的图片大小控制不好在缩放的时候会失真。
  2. 灰度和二值化没区分控制好,训练集的数据都是二值化的。
  3. 国外写字跟国内习惯不一样。
  4. 还有提到泛化问题的我还没懂后面学会了再分享。

说下效果吧

训练

我是看的这本书《TensorFlow深度学习(带目录).pdf》。里面有最原始的API处理数据的,但是怎么保存模型我没搞懂放弃了。然后网上找了一些高阶的keras处理的说下理解。

ini 复制代码
# 加载mnist数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化主要把数据放到0-1之间 因为二值化的图片像素点是0,255
x_train, x_test = x_train / 255.0, x_test / 255.0

# 将标签转换成独热编码 注意这里跟下面模型配置的损失函数有关系
# 独热是结果只有一位表示 1000000000 表示0
#y_train_onehot = tf.keras.utils.to_categorical(y_train)
#y_test_onehot = tf.keras.utils.to_categorical(y_test)
y_train_onehot =y_train
y_test_onehot =y_test

model = tf.keras.Sequential()
# 28*28 输入的大小
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))

# 中间隐藏层激活函数用relu
model.add(tf.keras.layers.Dense(128, activation='relu'))

#随机丢弃一些神经元
model.add(tf.keras.layers.Dropout(0.2))

# 多分类输出一般用softmax分类器
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# 打印模型信息的
model.summary()

# optimizer优化器
# 方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
# loss函数使用交叉熵
# 顺序编码用sparse_categorical_crossentropy
# 独热编码用categorical_crossentropy
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#开始训练
history = model.fit(x_train, y_train_onehot, epochs=20)
test_loss,test_accuracy=model.evaluate(x_test,  y_test_onehot, verbose=2)
print('Accuracy on test_dataset',test_accuracy)

#保存模型
model.save('model.h5')
print('saved total model.')

#主要打印训练后显示训练图取参数对应的
history_dict = history.history
print(history_dict.keys())

模型加载

ini 复制代码
network = tf.keras.models.load_model('model.h5')
# 调用模型进行预测识别
im = Image.open(r"9.png")  # 读取图片路径
im = im.resize((28, 28))  # 调整大小和模型输入大小一致
im = np.array(im)
print(im.shape)
p3 = im.min(axis=-1)
print(p3.shape)
# # 将白底黑字变成黑底白字   由于训练模型是这种格式
for i in range(28):
    for j in range(28):
        p3[i][j] = 255 - p3[i][j]
# 模型输出结果是每个类别的概率,取最大的概率的类别就是预测的结果
ret = network.predict((p3 / 255.0).reshape((1, 28, 28)))
print(ret)
number = np.argmax(ret, axis=1)
print(number)

自己画图识别成功的一个

相关推荐
代码之光_19804 分钟前
SpringBoot校园资料分享平台:设计与实现
java·spring boot·后端
工业机器视觉设计和实现11 分钟前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
编程老船长17 分钟前
第26章 Java操作Mongodb实现数据持久化
数据库·后端·mongodb
IT果果日记38 分钟前
DataX+Crontab实现多任务顺序定时同步
后端
想要打 Acm 的小周同学呀1 小时前
实现mnist手写数字识别
深度学习·tensorflow·实现mnist手写数字识别
我算是程序猿1 小时前
用AI做电子萌宠,快速涨粉变现
人工智能·stable diffusion·aigc
萱仔学习自我记录1 小时前
微调大语言模型——超详细步骤
人工智能·深度学习·机器学习
湘大小菜鸡1 小时前
NLP进阶(一)
人工智能·自然语言处理
XiaoLiuLB2 小时前
最佳语音识别 Whisper-large-v3-turbo 上线,速度更快(本地安装 )
人工智能·whisper·语音识别
哪 吒2 小时前
吊打ChatGPT4o!大学生如何用上原版O1辅助论文写作(附论文教程)
人工智能·ai·自然语言处理·chatgpt·aigc