tensorflow跑手写体实验

目录

1、环境条件

2、代码实现

3、总结


1、环境条件

  1. pycharm编译器
  2. python3.0环境
  3. tensorflow2.0依赖
  4. matplotlib依赖(用于画图)

2、代码实现

python 复制代码
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt

# 加载并预处理 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
print(x_train)
print(x_test)

# 构建 LeNet-5 模型
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(64, kernel_size=(5, 5), activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 重塑数据以适应模型
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# 训练模型
model.fit(x_train, y_train, epochs=5)

# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'测试准确率: {test_acc}')

# 保存模型
model.save('lenet-5_model.h5')
print('模型已保存至 lenet-5_model.h5')

# 加载模型
loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
print('模型已加载')

# 加载并预处理本地图片
def load_and_preprocess_image(image_path):
    img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
    img_array = image.img_to_array(img)
    img_array = img_array / 255.0  # 归一化
    img_array = np.expand_dims(img_array, axis=0)  # 添加批次维度
    return img_array

# 预测本地图片
image_path = '4.png'  # 替换为你的本地图片路径
img_array = load_and_preprocess_image(image_path)

# 使用加载的模型进行预测
predictions = loaded_model.predict(img_array)
predicted_label = np.argmax(predictions)

# 打印预测结果
print(f'预测结果: {predicted_label}')

# 显示图片
plt.imshow(img_array[0, :, :, 0], cmap='gray')
plt.title(f'预测结果: {predicted_label}')
plt.show()

解释:image_path为本地图片路径,通过model.save()方法实现模型的保存功能,下次预测使用的时候直接使用训练好的模型即可。下面将给出可直接预测的代码:

python 复制代码
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.font_manager import FontProperties

# 加载模型
loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
print('模型已加载')


# 加载并预处理本地图片
def load_and_preprocess_image(image_path):
    img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
    img_array = image.img_to_array(img)
    img_array = img_array / 255.0  # 归一化
    img_array = np.expand_dims(img_array, axis=0)  # 添加批次维度
    return img_array


# 预测本地图片
image_path = '7.png'  # 替换为你的本地图片路径
img_array = load_and_preprocess_image(image_path)

# 使用加载的模型进行预测
predictions = loaded_model.predict(img_array)
predicted_label = np.argmax(predictions)

# 打印预测结果
print(f'预测结果: {predicted_label}')

# 设置支持中文的字体
font_path = "C:/Windows/Fonts/simhei.ttf"  # 替换为你的字体路径,例如 SimHei.ttf
font_prop = FontProperties(fname=font_path)

# 显示图片
plt.imshow(img_array[0, :, :, 0], cmap='gray')
plt.title(f'预测结果: {predicted_label}', fontproperties=font_prop)
plt.show()

3、总结

使用tensorflow完成手写体图片的识别功能,其主要难点在安装依赖环境,其他的都是比较简单的事情。

学习之所以会想睡觉,是因为那是梦开始的地方。

ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)

------不写代码不会凸的小刘

相关推荐
love_summer1 分钟前
深入理解Python基础:数据类型、运算符与内存机制初探
python
沙漠的浪人1 分钟前
基于多 Agent 的 Planning-Executor 架构设计
人工智能·agent
小雪_Snow2 分钟前
Python 安装教程【使用 Python install manager】
python
光锥智能4 分钟前
高通推出全套机器人技术组合,含 Dragonwing IQ10 系列处理器
人工智能
云卓SKYDROID6 分钟前
工业吊舱图像采集与增强模块解析
人工智能·数码相机·计算机视觉·无人机·高科技·云卓科技
狮子座明仔6 分钟前
EXPLAIN:用实体摘要为RAG“开外挂“,让文档问答又快又准
人工智能
星月前端6 分钟前
基于DeepSeek API的Telegram机器人
python·机器人
狮子座明仔6 分钟前
CiteFix: 通过后处理引用校正提升RAG系统准确率
人工智能·深度学习·ai·语言模型·自然语言处理
希艾席帝恩6 分钟前
数字孪生赋能水利行业转型升级的关键路径
大数据·人工智能·数字孪生·数据可视化·数字化转型
AI 智能服务9 分钟前
第2课___结构化输出与 Prompt 设计
人工智能·机器学习·prompt