目录
1、环境条件
- pycharm编译器
- python3.0环境
- tensorflow2.0依赖
- matplotlib依赖(用于画图)
2、代码实现
python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
# 加载并预处理 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
print(x_train)
print(x_test)
# 构建 LeNet-5 模型
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Conv2D(64, kernel_size=(5, 5), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(120, activation='relu'),
tf.keras.layers.Dense(84, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 重塑数据以适应模型
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)
# 训练模型
model.fit(x_train, y_train, epochs=5)
# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'测试准确率: {test_acc}')
# 保存模型
model.save('lenet-5_model.h5')
print('模型已保存至 lenet-5_model.h5')
# 加载模型
loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
print('模型已加载')
# 加载并预处理本地图片
def load_and_preprocess_image(image_path):
img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
img_array = image.img_to_array(img)
img_array = img_array / 255.0 # 归一化
img_array = np.expand_dims(img_array, axis=0) # 添加批次维度
return img_array
# 预测本地图片
image_path = '4.png' # 替换为你的本地图片路径
img_array = load_and_preprocess_image(image_path)
# 使用加载的模型进行预测
predictions = loaded_model.predict(img_array)
predicted_label = np.argmax(predictions)
# 打印预测结果
print(f'预测结果: {predicted_label}')
# 显示图片
plt.imshow(img_array[0, :, :, 0], cmap='gray')
plt.title(f'预测结果: {predicted_label}')
plt.show()
解释:image_path为本地图片路径,通过model.save()方法实现模型的保存功能,下次预测使用的时候直接使用训练好的模型即可。下面将给出可直接预测的代码:
python
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# 加载模型
loaded_model = tf.keras.models.load_model('lenet-5_model.h5')
print('模型已加载')
# 加载并预处理本地图片
def load_and_preprocess_image(image_path):
img = image.load_img(image_path, color_mode="grayscale", target_size=(28, 28))
img_array = image.img_to_array(img)
img_array = img_array / 255.0 # 归一化
img_array = np.expand_dims(img_array, axis=0) # 添加批次维度
return img_array
# 预测本地图片
image_path = '7.png' # 替换为你的本地图片路径
img_array = load_and_preprocess_image(image_path)
# 使用加载的模型进行预测
predictions = loaded_model.predict(img_array)
predicted_label = np.argmax(predictions)
# 打印预测结果
print(f'预测结果: {predicted_label}')
# 设置支持中文的字体
font_path = "C:/Windows/Fonts/simhei.ttf" # 替换为你的字体路径,例如 SimHei.ttf
font_prop = FontProperties(fname=font_path)
# 显示图片
plt.imshow(img_array[0, :, :, 0], cmap='gray')
plt.title(f'预测结果: {predicted_label}', fontproperties=font_prop)
plt.show()
3、总结
使用tensorflow完成手写体图片的识别功能,其主要难点在安装依赖环境,其他的都是比较简单的事情。
学习之所以会想睡觉,是因为那是梦开始的地方。
ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)
------不写代码不会凸的小刘