tensorflow + pygame 手写数字识别的小游戏

起因, 目的:

很久之前,一个客户的作业,我帮忙写的。

今天删项目,觉得比较简洁,发出来给大家看看。

效果图:

1. 训练模型的代码
python 复制代码
import sys
import tensorflow as tf

# Use MNIST handwriting dataset
mnist = tf.keras.datasets.mnist

# Prepare data for training
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)
x_train = x_train.reshape(
    x_train.shape[0], x_train.shape[1], x_train.shape[2], 1
)
x_test = x_test.reshape(
    x_test.shape[0], x_test.shape[1], x_test.shape[2], 1
)

"""
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

"""

# Create a convolutional neural network
model = tf.keras.models.Sequential([

    # 1.  Convolutional layer. Learn 32 filters using a 3x3 kernel, activation function is relu, input shape (28,28,1)
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),

    #2. Max-pooling layer, using 2x2 pool size
    tf.keras.layers.MaxPooling2D((2, 2)),

    #3.  Flatten units
    tf.keras.layers.Flatten(),

    #4. Add a hidden layer with dropout,
    tf.keras.layers.Dropout(0.2),

    #5. Add an output layer with output units for all 10 digits, activation function is softmax
    tf.keras.layers.Dense(10, activation='softmax')
    
])


# Train neural network
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)
model.fit(x_train, y_train, epochs=10)

# Evaluate neural network performance
model.evaluate(x_test,  y_test, verbose=2)

# Save model to file
if len(sys.argv) == 2:
    filename = sys.argv[1]
    model.save(filename)
    print(f"Model saved to {filename}.")

"""
Run this code:  python handwriting.py model_1.pth

output:

1875/1875 [==============================] - 10s 5ms/step - loss: 0.0413 - accuracy: 0.9873
Epoch 8/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0385 - accuracy: 0.9877
Epoch 9/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0338 - accuracy: 0.9898
Epoch 10/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0319 - accuracy: 0.9900
313/313 - 1s - loss: 0.0511 - accuracy: 0.9845 - 718ms/epoch - 2ms/step
Model saved to model_1.pth.
"""
2. 运行小游戏, 进行识别

从命令行运行:

python recognition.py model.h5

python 复制代码
import numpy as np
import pygame
import sys
import tensorflow as tf
import time

"""
run this code:
python recognition.py  model_1.pth 
 
or,  
 
python recognition.py  model.h5 

output:
"""


print("len(sys.argv): ", len(sys.argv))

# Check command-line arguments
if len(sys.argv) != 2:
    print("Usage: python recognition.py model")
    sys.exit()


model = tf.keras.models.load_model(sys.argv[1])


# Colors
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)

# Start pygame
pygame.init()
size = width, height = 600, 400
screen = pygame.display.set_mode(size)

# Fonts
OPEN_SANS = "assets/fonts/OpenSans-Regular.ttf"
smallFont = pygame.font.Font(OPEN_SANS, 20)
largeFont = pygame.font.Font(OPEN_SANS, 40)

ROWS, COLS = 28, 28

OFFSET = 20
CELL_SIZE = 10

handwriting = [[0] * COLS for _ in range(ROWS)]
classification = None

while True:

    # Check if game quit
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            sys.exit()

    screen.fill(BLACK)

    # Check for mouse press
    click, _, _ = pygame.mouse.get_pressed()
    if click == 1:
        mouse = pygame.mouse.get_pos()
    else:
        mouse = None

    # Draw each grid cell
    cells = []
    for i in range(ROWS):
        row = []
        for j in range(COLS):
            rect = pygame.Rect(
                OFFSET + j * CELL_SIZE,
                OFFSET + i * CELL_SIZE,
                CELL_SIZE, CELL_SIZE
            )

            # If cell has been written on, darken cell
            if handwriting[i][j]:
                channel = 255 - (handwriting[i][j] * 255)
                pygame.draw.rect(screen, (channel, channel, channel), rect)

            # Draw blank cell
            else:
                pygame.draw.rect(screen, WHITE, rect)
            pygame.draw.rect(screen, BLACK, rect, 1)

            # If writing on this cell, fill in current cell and neighbors
            if mouse and rect.collidepoint(mouse):
                handwriting[i][j] = 250 / 255
                if i + 1 < ROWS:
                    handwriting[i + 1][j] = 220 / 255
                if j + 1 < COLS:
                    handwriting[i][j + 1] = 220 / 255
                if i + 1 < ROWS and j + 1 < COLS:
                    handwriting[i + 1][j + 1] = 190 / 255

    # Reset button
    resetButton = pygame.Rect(
        30, OFFSET + ROWS * CELL_SIZE + 30,
        100, 30
    )
    resetText = smallFont.render("Reset", True, BLACK)
    resetTextRect = resetText.get_rect()
    resetTextRect.center = resetButton.center
    pygame.draw.rect(screen, WHITE, resetButton)
    screen.blit(resetText, resetTextRect)

    # Classify button
    classifyButton = pygame.Rect(
        150, OFFSET + ROWS * CELL_SIZE + 30,
        100, 30
    )
    classifyText = smallFont.render("Classify", True, BLACK)
    classifyTextRect = classifyText.get_rect()
    classifyTextRect.center = classifyButton.center
    pygame.draw.rect(screen, WHITE, classifyButton)
    screen.blit(classifyText, classifyTextRect)

    # Reset drawing
    if mouse and resetButton.collidepoint(mouse):
        handwriting = [[0] * COLS for _ in range(ROWS)]
        classification = None

    # Generate classification
    if mouse and classifyButton.collidepoint(mouse):
        classification = model.predict(
            [np.array(handwriting).reshape(1, 28, 28, 1)]
        ).argmax()

    # Show classification if one exists
    if classification is not None:
        classificationText = largeFont.render(str(classification), True, WHITE)
        classificationRect = classificationText.get_rect()
        grid_size = OFFSET * 2 + CELL_SIZE * COLS
        classificationRect.center = (
            grid_size + ((width - grid_size) / 2),
            100
        )
        screen.blit(classificationText, classificationRect)

    pygame.display.flip()

完整项目,我已经上传了。 0积分下载。

完整项目链接

https://download.csdn.net/download/waterHBO/89881853

老哥留步,支持一下。

相关推荐
雪兽软件1 小时前
人工智能和大数据如何改变企业?
大数据·人工智能
UMS攸信技术3 小时前
汽车电子行业数字化转型的实践与探索——以盈趣汽车电子为例
人工智能·汽车
ws2019073 小时前
聚焦汽车智能化与电动化︱AUTO TECH 2025 华南展,以展带会,已全面启动,与您相约11月广州!
大数据·人工智能·汽车
堇舟4 小时前
斯皮尔曼相关(Spearman correlation)系数
人工智能·算法·机器学习
爱写代码的小朋友4 小时前
使用 OpenCV 进行人脸检测
人工智能·opencv·计算机视觉
Cici_ovo5 小时前
摄像头点击器常见问题——摄像头视窗打开慢
人工智能·单片机·嵌入式硬件·物联网·计算机视觉·硬件工程
QQ39575332375 小时前
中阳智能交易系统:创新金融科技赋能投资新时代
人工智能·金融
这个男人是小帅5 小时前
【图神经网络】 AM-GCN论文精讲(全网最细致篇)
人工智能·pytorch·深度学习·神经网络·分类
放松吃羊肉6 小时前
【约束优化】一次搞定拉格朗日,对偶问题,弱对偶定理,Slater条件和KKT条件
人工智能·机器学习·支持向量机·对偶问题·约束优化·拉格朗日·kkt
MJ绘画中文版6 小时前
灵动AI:艺术与科技的融合
人工智能·ai·ai视频