# 机器学习: 使用 KNeighborsClassifier 实现手写数字识别 (handwritten digit recognition with scikit-learn's KNeighborsClassifier)

import numpy as np

import cv2

from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

from sklearn.neighbors import KNeighborsClassifier

from sklearn.metrics import accuracy_score

import os

import gzip

import urllib.request

# ============ Configure a CJK-capable matplotlib font (avoids missing-glyph warnings for Chinese titles) ============

# Candidate fonts in order of preference; DejaVu Sans is the last-resort fallback.
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'Arial Unicode MS', 'DejaVu Sans']

# Draw the minus sign as ASCII '-' so it is not rendered as an empty box by CJK fonts.
plt.rcParams['axes.unicode_minus'] = False

# ============ Load the MNIST dataset from local files or the network ============

# Base URL of the gzipped IDX files used as the last-resort download source.
_MNIST_BASE_URL = "https://ossci-datasets.s3.amazonaws.com/mnist/"

# The four gzipped IDX files making up MNIST: train images/labels, test images/labels.
_MNIST_FILES = (
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
)

# IDX magic numbers: 0x0803 = unsigned-byte data with 3 dims (images),
# 0x0801 = unsigned-byte data with 1 dim (labels).
_IDX_IMAGES_MAGIC = 0x00000803
_IDX_LABELS_MAGIC = 0x00000801


def _read_be_uint32(f):
    """Read one big-endian unsigned 32-bit integer from binary stream *f*."""
    return int.from_bytes(f.read(4), 'big')


def _download_file(data_dir, filename, base_url=_MNIST_BASE_URL):
    """Download *filename* from *base_url* into *data_dir* unless it already exists.

    The data is written to a ``.part`` file and renamed only once the download
    has finished, so an interrupted download never leaves a truncated file that
    a later run would mistake for a complete one.

    Returns:
        The local path of the file.
    """
    local_path = os.path.join(data_dir, filename)
    if not os.path.exists(local_path):
        url = base_url + filename
        print(f"下载 {filename} ...")
        tmp_path = local_path + ".part"
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, local_path)
    return local_path


def _parse_idx_images(path):
    """Parse a gzipped IDX image file into float32 rows of shape (num, rows*cols) in [0, 1].

    Raises:
        ValueError: if the magic number is wrong or the pixel data is truncated.
    """
    with gzip.open(path, 'rb') as f:
        magic = _read_be_uint32(f)
        if magic != _IDX_IMAGES_MAGIC:
            raise ValueError(f"{path}: IDX 图像文件魔数错误 ({magic})")
        num = _read_be_uint32(f)
        rows = _read_be_uint32(f)
        cols = _read_be_uint32(f)
        expected = num * rows * cols
        buf = f.read(expected)
    if len(buf) != expected:
        raise ValueError(f"{path}: 图像数据不完整 ({len(buf)}/{expected} 字节)")
    pixels = np.frombuffer(buf, dtype=np.uint8).reshape(num, rows * cols)
    return pixels.astype(np.float32) / 255.0


def _parse_idx_labels(path):
    """Parse a gzipped IDX label file into an int32 vector.

    Raises:
        ValueError: if the magic number is wrong or the label data is truncated.
    """
    with gzip.open(path, 'rb') as f:
        magic = _read_be_uint32(f)
        if magic != _IDX_LABELS_MAGIC:
            raise ValueError(f"{path}: IDX 标签文件魔数错误 ({magic})")
        num = _read_be_uint32(f)
        buf = f.read(num)
    if len(buf) != num:
        raise ValueError(f"{path}: 标签数据不完整 ({len(buf)}/{num} 字节)")
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int32)


def _arrays_from_npz(data):
    """Extract (X, y) from an opened ``.npz`` archive.

    Supports two layouts: this script's own ``X``/``y`` arrays (``X`` returned
    as stored), or keras' ``x_train``/``y_train``/``x_test``/``y_test`` layout
    (flattened and scaled to [0, 1] float32). Labels are always cast to int32.

    Raises:
        KeyError: if neither layout is present.
    """
    keys = set(data.files)
    if {"X", "y"} <= keys:
        return data["X"], data["y"].astype(np.int32)
    if {"x_train", "y_train", "x_test", "y_test"} <= keys:
        X = np.concatenate((data["x_train"], data["x_test"]), axis=0)
        y = np.concatenate((data["y_train"], data["y_test"]), axis=0)
        X = X.reshape(X.shape[0], -1).astype(np.float32) / 255.0
        return X, y.astype(np.int32)
    raise KeyError(f"mnist.npz 缺少 X/y 或 x_train/y_train/x_test/y_test 数组, 实际为 {sorted(keys)}")


def load_mnist_local(data_dir="mnist_data"):
    """Load MNIST (train and test splits concatenated), trying three sources in order.

    1. ``keras.datasets.mnist`` (keras downloads the data on first use and
       caches it locally afterwards).
    2. ``<data_dir>/mnist.npz`` with either ``X``/``y`` arrays or keras'
       ``x_train``/``y_train``/``x_test``/``y_test`` layout.
    3. The four raw IDX ``.gz`` files, downloaded into ``data_dir`` if missing.

    Args:
        data_dir: directory holding ``mnist.npz`` and/or the IDX files.

    Returns:
        (X, y): ``X`` is float32 of shape (n_samples, n_pixels) scaled to [0, 1]
        (an ``X`` array stored in ``mnist.npz`` is returned unchanged);
        ``y`` is an int32 label vector.

    Raises:
        RuntimeError: if downloading or parsing the IDX files fails.
        KeyError: if ``mnist.npz`` exists but has neither supported layout.
    """
    # 1) keras. Best effort: keras may be missing or unable to download, so fall through.
    try:
        from keras.datasets import mnist

        (X_train, y_train), (X_test, y_test) = mnist.load_data()
    except Exception as e:
        print(f"keras 加载 MNIST 失败, 改用本地/下载方式: {e}")
    else:
        X = np.concatenate((X_train, X_test), axis=0)
        y = np.concatenate((y_train, y_test), axis=0)
        X = X.reshape(X.shape[0], -1).astype(np.float32) / 255.0
        print("✅ 通过 keras 加载 MNIST 成功")
        return X, y.astype(np.int32)

    # 2) Local .npz file; the context manager closes the archive's file handle.
    local_file = os.path.join(data_dir, "mnist.npz")
    if os.path.exists(local_file):
        with np.load(local_file) as data:
            X, y = _arrays_from_npz(data)
        print("✅ 从本地 mnist.npz 加载成功")
        return X, y

    # 3) Raw IDX files: download any that are missing, then parse all four.
    os.makedirs(data_dir, exist_ok=True)
    try:
        print("尝试从 AWS 下载 MNIST ...")
        train_img, train_lbl, test_img, test_lbl = (
            _download_file(data_dir, name) for name in _MNIST_FILES
        )
        X = np.concatenate((_parse_idx_images(train_img), _parse_idx_images(test_img)), axis=0)
        y = np.concatenate((_parse_idx_labels(train_lbl), _parse_idx_labels(test_lbl)), axis=0)
    except Exception as e:
        raise RuntimeError(f"无法加载 MNIST 数据集: {e}") from e
    print("✅ MNIST 下载并加载成功")
    return X, y

# ============ 1. Train the KNN model ============

print("加载MNIST数据集并训练模型...")

X, y = load_mnist_local()

# Split into train/test sets; the fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# KNN prediction cost grows with the training-set size, so only a subset is used to keep it fast.
TRAIN_SUBSET = 15000
EVAL_SUBSET = 3000

X_train = X_train[:TRAIN_SUBSET]
y_train = y_train[:TRAIN_SUBSET]

# Fit the classifier (5 nearest neighbours).
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)

# Evaluate accuracy on a slice of the held-out test set.
y_pred = knn.predict(X_test[:EVAL_SUBSET])
acc = accuracy_score(y_test[:EVAL_SUBSET], y_pred)

print(f"模型测试准确率: {acc:.4f}")
print("模型训练完成,开始识别本地图片\n")

# ============ 2. Image preprocessing (binarize + denoise + auto-invert + crop to the digit) ============

def _digit_bbox(binary, min_area_ratio):
    """Return the (x, y, w, h) box enclosing every significant foreground blob, or None.

    A contour counts as significant when its area is at least ``min_area_ratio``
    times the area of the largest external contour; smaller ones are treated as noise.
    """
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    areas = [cv2.contourArea(c) for c in contours]
    min_area = max(areas) * min_area_ratio
    kept = [c for c, area in zip(contours, areas) if area >= min_area]
    # boundingRect accepts any point set, so stacking the kept contours gives their union box.
    return cv2.boundingRect(np.vstack(kept))


def preprocess_image(img_path, debug=False, *, min_area_ratio=0.1):
    """Turn an arbitrary photo of a handwritten digit into an MNIST-style 28x28 input.

    Pipeline: grayscale -> Otsu binarization -> invert to white-on-black when the
    foreground covers more than half the image -> 3x3 morphological opening ->
    crop to the digit -> scale the longer side to 20 px -> center on a 28x28
    black canvas -> scale to [0, 1].

    Args:
        img_path: path of the image file to read.
        debug: if True, print statistics for the intermediate steps.
        min_area_ratio: when cropping, contours smaller than this fraction of the
            largest contour are ignored as noise. Larger blobs are kept, so
            disconnected strokes of one digit (e.g. a detached top bar of a "5")
            stay in the crop. Use 1.0 to crop to the largest contour only.

    Returns:
        (img_flat, result): ``img_flat`` is the float32 feature vector of shape
        (1, 784); ``result`` is the same data as a (28, 28) float32 array.
    """
    # 1) Read and convert to grayscale; `with` closes the file handle PIL keeps open.
    with Image.open(img_path) as img:
        img_array = np.array(img.convert('L'), dtype=np.float32)

    if debug:
        print(f"[DEBUG] 原始尺寸: {img_array.shape}, 像素范围: [{img_array.min():.0f}, {img_array.max():.0f}]")

    # 2) Otsu binarization chooses the foreground/background threshold automatically.
    img_uint8 = img_array.astype(np.uint8)
    _, binary = cv2.threshold(img_uint8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    if debug:
        fg_ratio = (binary > 0).sum() / binary.size
        print(f"[DEBUG] OTSU阈值二值化后,前景像素占比: {fg_ratio:.3f}")

    # 3) MNIST digits are white (255) on black. If more than half of the pixels are
    #    "on", the photo was dark ink on a light background, so invert it.
    if (binary > 0).sum() > binary.size * 0.5:
        binary = 255 - binary
        if debug:
            print("[DEBUG] 检测到白底黑字,已反色为黑底白字")

    # 4) Morphological opening removes isolated noise specks.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

    # 5) Crop to the digit plus a small margin on every side, clamped to the image bounds.
    bbox = _digit_bbox(binary, min_area_ratio)
    if bbox is not None:
        x, y, w, h = bbox
        pad = 4
        img_h, img_w = binary.shape
        x0, y0 = max(0, x - pad), max(0, y - pad)
        x1, y1 = min(img_w, x + w + pad), min(img_h, y + h + pad)
        binary = binary[y0:y1, x0:x1]
        if debug:
            print(f"[DEBUG] 裁剪数字区域: x={x0}, y={y0}, w={x1 - x0}, h={y1 - y0}")

    # 6) Scale so the longer side becomes 20 px (leaving a border like MNIST digits),
    #    preserving the aspect ratio, then center it on a 28x28 black canvas.
    h, w = binary.shape
    scale = 20.0 / max(h, w)
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    resized = cv2.resize(binary, (new_w, new_h), interpolation=cv2.INTER_AREA)

    canvas = np.zeros((28, 28), dtype=np.uint8)
    offset_y = (28 - new_h) // 2
    offset_x = (28 - new_w) // 2
    canvas[offset_y:offset_y + new_h, offset_x:offset_x + new_w] = resized

    # 7) Normalize to [0, 1].
    result = canvas.astype(np.float32) / 255.0

    if debug:
        print(f"[DEBUG] 最终形状: {result.shape}, 前景像素: {(result > 0.1).sum()}")

    # Flatten to a single-row (1, 784) feature matrix for the classifier.
    img_flat = result.reshape(1, -1)
    return img_flat, result

# ============ 3. Visualization: original image and preprocessing result side by side ============

def show_processed_img(original_path, processed_arr):
    """Display the original image (in grayscale) next to its preprocessed version.

    Args:
        original_path: path of the source image file, shown on the left.
        processed_arr: 2-D array shown on the right, e.g. the second value
            returned by ``preprocess_image``.
    """
    _, axes = plt.subplots(1, 2, figsize=(8, 4))

    # Left: the original photo. Convert to an array inside `with` so the file gets closed.
    with Image.open(original_path) as orig:
        orig_gray = np.asarray(orig.convert('L'))
    axes[0].imshow(orig_gray, cmap='gray')
    axes[0].set_title("原始图片")
    axes[0].axis('off')

    # Right: the preprocessed result.
    axes[1].imshow(processed_arr, cmap='gray')
    axes[1].set_title("28×28 预处理后")
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

# ============ 4. Recognition ============

def predict_digit(img_path, debug=False):
    """Preprocess *img_path*, classify it with the module-level ``knn`` model and report it.

    Prints the predicted digit with its predicted probability and the top-3
    candidate classes, then shows the original and preprocessed images.

    Args:
        img_path: path of the image to recognize.
        debug: forwarded to ``preprocess_image`` to print intermediate statistics.

    Returns:
        The predicted class label, or None if recognition failed. Errors are
        printed rather than raised, so a batch of images keeps going.
    """
    try:
        img_data, processed_img = preprocess_image(img_path, debug=debug)

        # The model was trained on 784-pixel (28x28) feature vectors.
        if img_data.shape[1] != 784:
            print(f"错误:特征维度 {img_data.shape[1]},要求784")
            return None

        pred = knn.predict(img_data)[0]
        proba = knn.predict_proba(img_data)[0]
        # Report the probability of the class predict() returned. With tied votes,
        # the first entry of an arbitrary argsort need not be that class.
        pred_score = proba[np.flatnonzero(knn.classes_ == pred)[0]]
        # Stable sort on negated scores: highest first, ties stay in class order.
        top3_idx = np.argsort(-proba, kind="stable")[:3]
        top3_classes = knn.classes_[top3_idx]
        top3_scores = proba[top3_idx]

        print(f"识别结果:数字 {pred}(置信度 {pred_score:.2%})")
        print("Top-3 候选:", end="")
        for cls, sc in zip(top3_classes, top3_scores):
            print(f" [{cls}]:{sc:.2%}", end="")
        print()

        show_processed_img(img_path, processed_img)
        return pred
    except Exception as e:
        # Top-level boundary for one image: report the failure and let the caller continue.
        print(f"识别失败:{e}")
        return None

if __name__ == "__main__":
    # Batch-test several images; relative and absolute paths both work.
    test_images = [
        "digit1.jpg",
        "digit2.jpg",
        "digit3.jpg",
    ]

    # To debug a single image first, enable debug=True to print each preprocessing step:
    # predict_digit("digit7.jpg", debug=True)

    for img_path in test_images:
        if not os.path.exists(img_path):
            print(f"文件不存在: {img_path}")
            continue
        print(f"\n{'='*40}")
        print(f"测试图片: {img_path}")
        print(f"{'='*40}")
        predict_digit(img_path)
# ---------------------------------------------------------------------------
# The lines below are "related posts" links captured from the web page this
# code was copied from; they are not part of the program.
#
# 相关推荐
# jerryinwuhan1 小时前
# marker BiBERTo解释
# java·前端·人工智能
# 掘金安东尼1 小时前
# 如果你真能 7×24 小时运行最顶级的大模型,你会想用它来干嘛
# 人工智能
# 翼龙云_cloud1 小时前
# 云服务器代理商:2026 年云计算趋势 AI 算力需求激增下的云服务器选择
# 服务器·人工智能·云计算·ai智能体
# 数智工坊1 小时前
# 周志华《Machine Learning》学习笔记--第四章--决策树
# 笔记·学习·机器学习
# m沐沐1 小时前
# 【机器学习】NLP---用 Python+TF-IDF 给《红楼梦》自动提取关键词
# 人工智能·python·机器学习·自然语言处理·nlp·中文分词·tf-idf
# 小脑斧1231 小时前
# 自媒体内容工业化:基于AI Skills低代码实现穿搭账号矩阵自动化量产
# 人工智能·低代码·媒体·skills·openclaw·hermes·marvis
# 填满你的记忆1 小时前
# 《为什么 MySQL 不适合做 AI 检索?》
# 数据库·人工智能·mysql·ai·向量数据库
# 威尔逊·柏斯科·希伯理1 小时前
# 机器学习第二天(KNN)
# 人工智能·机器学习
# winlife_1 小时前
# 让 AI 自动跑 PlayMode 回归测试:从 BUG 注入到自动判 FAIL 的完整闭环
# 人工智能·unity·bug·ai编程·mcp·回归测试·游戏测试