模型训练识别手写数字(二)

模型训练识别手写数字(一)使用手写数字图像进行模型测试

一、生成手写数字图像

1. 导入所需库

python 复制代码
import cv2
import numpy as np
import os

cv2用于计算机视觉操作。

numpy用于处理数组和图像数据。

os用于文件和目录操作。

2. 初始化画布

python 复制代码
canvas = np.zeros((280, 280), dtype="uint8")

创建一个280x280的黑色画布(值为0表示黑色)。

3. 鼠标回调函数

python 复制代码
def draw(event, x, y, flags, param):
    if event == cv2.EVENT_MOUSEMOVE and flags == cv2.EVENT_FLAG_LBUTTON:
        cv2.circle(canvas, (x, y), 5, 255, -1)

draw函数在鼠标移动时绘制白色圆点(值为255)到画布上。圆点的半径为5像素。

4. 创建窗口并设置回调

python 复制代码
cv2.namedWindow("Canvas")
cv2.setMouseCallback("Canvas", draw)

创建一个名为"Canvas"的窗口,并设置鼠标回调函数。

5. 主循环

python 复制代码
while True:
    cv2.imshow("Canvas", canvas)
    key = cv2.waitKey(1) & 0xFF

不断显示画布,等待用户输入。

6. 处理用户输入

python 复制代码
if key == ord('c'):
    canvas = np.zeros((280, 280), dtype="uint8")
elif key == ord('q'):
    break

按 'c' 键清空画布,按 'q' 键退出循环。

7. 保存图像目录

python 复制代码
save_dir = "Data"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

检查并创建保存图像的目录。

8. 保存图像文件

python 复制代码
save_path = os.path.join(save_dir, "handwritten_digit.png")
if cv2.imwrite(save_path, canvas):
    print(f"Image saved successfully at {save_path}")
else:
    print("Failed to save image.")

将画布保存为PNG文件,并输出保存状态。

9. 关闭窗口

python 复制代码
cv2.destroyAllWindows()

关闭所有OpenCV窗口。

二、调用训练的模型进行测试

1. 导入所需库

python 复制代码
import cv2
import matplotlib.pyplot as plt
import numpy as np
from keras.api.models import load_model

cv2用于图像处理。

matplotlib.pyplot用于可视化结果。

numpy用于数值计算。

load_model用于加载训练好的Keras模型。

2. 加载训练的模型

python 复制代码
model = load_model("my_model.h5")

从文件中加载训练好的模型。

3. 加载手写数字图像

python 复制代码
original_img = cv2.imread("Data/handwritten_digit.png", cv2.IMREAD_GRAYSCALE)

读取手写数字图像,并以灰度模式加载。

4. 处理图像用于预测

python 复制代码
img = cv2.resize(original_img, (28, 28))  # 调整为28x28大小
img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)[1]  # 二值化
img = img.astype('float32') / 255  # 归一化

将图像调整为28x28像素,这是模型所需的输入尺寸。

使用阈值处理将图像二值化(黑白),并将背景设置为白色,手写数字为黑色。

将图像数据归一化到[0, 1]范围。

5. 调整图像形状以便于预测

python 复制代码
img = img.flatten()  # 展平为一维数组
img = img.reshape(1, 784)  # 调整形状为 (1, 784)

将28x28的图像展平为784个像素值的单行数组,以适应模型的输入格式。

6. 进行预测

python 复制代码
predictions = model.predict(img)
predicted_class = np.argmax(predictions, axis=1)

7. 可视化预测结果

python 复制代码
plt.figure(figsize=(6, 6))

# 显示原图
plt.imshow(original_img, cmap='gray', aspect='equal')  # 使用原始图像
plt.title(f'Predicted: {predicted_class[0]}', fontsize=14)
plt.axis('off')

plt.tight_layout()
plt.show()

创建一个图形窗口,并显示原始图像。

在标题中显示模型预测的类别。

使用tight_layout()优化图形布局,并显示图形。

手写8,预测却是2;说明模型在训练集上表现良好,但在测试却表现差。

目前使用的是一个全连接神经网络(Feedforward Neural Network)。这个网络的结构通常包括以下几个部分:

  1. 输入层:接受输入数据,例如在你的例子中是手写数字的像素值。
  2. 隐藏层:通过全连接的方式进行计算,使用激活函数(如 ReLU)引入非线性。
  3. 输出层:生成预测结果,通常使用 softmax 激活函数进行分类。

全连接神经网络在处理图像时通常需要将输入图像展平(flatten),这可能导致对空间特征的捕捉不够有效,因此卷积神经网络(CNN)更适合图像数据,因为它们能够利用卷积层自动提取空间特征,从而提高分类性能。

相关推荐
newxtc7 分钟前
【昆明市不动产登记中心-注册安全分析报告】
人工智能·安全
techdashen8 分钟前
圆桌讨论:Coding Agent or AI IDE 的现状和未来发展
ide·人工智能
CV实验室1 小时前
TIP 2025 | 哈工大&哈佛等提出 TripleMixer:攻克雨雪雾干扰的3D点云去噪网络!
人工智能·计算机视觉·3d·论文
余俊晖2 小时前
一套针对金融领域多模态问答的自适应多层级RAG框架-VeritasFi
人工智能·金融·rag
码农阿树2 小时前
视频解析转换耗时—OpenCV优化摸索路
人工智能·opencv·音视频
伏小白白白3 小时前
【论文精度-2】求解车辆路径问题的神经组合优化算法:综合展望(Yubin Xiao,2025)
人工智能·算法·机器学习
应用市场3 小时前
OpenCV编程入门:从零开始的计算机视觉之旅
人工智能·opencv·计算机视觉
星域智链4 小时前
宠物智能用品:当毛孩子遇上 AI,是便利还是过度?
人工智能·科技·学习·宠物
taxunjishu4 小时前
DeviceNet 转 MODBUS TCP罗克韦尔 ControlLogix PLC 与上位机在汽车零部件涂装生产线漆膜厚度精准控制的通讯配置案例
人工智能·区块链·工业物联网·工业自动化·总线协议
说私域4 小时前
基于多模态AI技术的传统行业智能化升级路径研究——以开源AI大模型、AI智能名片与S2B2C商城小程序为例
人工智能·小程序·开源