基于 KNN 算法的手写数字识别项目实践

一、项目背景与意义​

在数字化时代,手写数字识别是一个经典的模式识别问题,广泛应用于邮政编码识别、银行支票金额识别、手写答题卡批改等场景。实现高效、准确的手写数字识别,不仅能提升相关业务的自动化处理效率,还能为更复杂的图像识别任务奠定基础。​

KNN(K - 近邻)算法作为一种简单易懂的监督学习算法,在分类任务中有着广泛应用。它无需复杂的模型训练过程,核心思想是 "物以类聚"------ 通过计算待识别样本与已知样本的距离,找到最近的 K 个样本,再根据这 K 个样本的类别来判断待识别样本的类别。本次项目就将基于 KNN 算法实现手写数字识别,带你从零开始体验算法落地的全过程。

二、项目实现步骤

1.环境搭建:筑牢项目基础

首先,明确所需库及其作用:numpy用于高效的数值计算和矩阵操作,是处理数据集的核心工具;opencv负责将手写数字图像及识别结果以可视化形式呈现,便于直观观察;sklearn(scikit-learn)则提供了 KNN 模型、数据集加载和模型评估等一站式工具,能大幅简化开发流程。

2.数据处理:为模型输入 "优质食材"

1.数据是模型的 "粮食",数据处理的质量直接影响模型性能,这一环节包括数据集加载、格式转换、归一化和分割等步骤。MNIST 数据集中的每个手写数字图像是 20×20 像素的二维矩阵,而 KNN 算法需要一维特征向量作为输入。因此,需将二维图像转换为一维数组,20×20 的矩阵转换后会得到 2500 个特征值的一维数组。

2.为了评估模型的泛化能力,需要将数据集分割为训练集和测试集。训练集用于让模型 "学习" 数据规律,测试集用于检验模型的识别效果。使用train_test_split函数进行分割

3、模型构建与训练:打造 "识别利器"

KNN 算法的一大特点是 "懒惰学习",无需复杂的训练过程,其核心是存储训练数据,在预测时通过计算距离找到近邻。​使用sklearn的KNeighborsClassifier构建模型

4、模型预测与评估:检验 "利器" 性能

使用测试集对模型进行预测,得到模型对测试样本的识别结果,predict方法会根据训练集中的近邻信息,为测试集中的每个样本预测对应的数字标签。

数据集照片:

三、代码分析

1.库导入:基础工具准备

python 复制代码
import numpy as np  # 用于数组操作和数值计算​
from sklearn.neighbors import KNeighborsClassifier  # 导入KNN分类器​
import cv2  # 用于图像读取、处理和转换(OpenCV库)

这部分代码导入了三个核心库:numpy负责数据结构处理,sklearn提供 KNN 算法实现,cv2(OpenCV)则专门用于图像相关的操作,是整个代码的 "工具基础"。

2.数据集准备:从图像到可训练数据

这部分是代码的核心,主要实现 "原始图像→训练集 / 测试集→特征向量" 的转换,共分为 5 个步骤:

(1)读取并预处理原始数据集图像

python 复制代码
# 读取包含手写数字的数据集图像(假设是一个包含多个数字的拼接图)
img = cv2.imread('23ea179d663b57f1b10ff385449f038.png')
# 将彩色图像转换为灰度图像(减少计算量,保留核心特征)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

把大量手写数字样本按行列排列成一张大图(类似答题卡的布局),后续需要将其分割成单个数字样本。转换为灰度图是因为颜色信息对数字识别无意义,灰度图(单通道)能减少数据量。

(2)分割图像为单个数字样本

python 复制代码
# 将灰度图按行分割为50份(假设原图是50行×100列的数字网格)
a = np.vsplit(gray, 50)
# 先按行分割为50行,再将每行分割为100列,得到50×100个数字样本
cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)]

这部分是关键的分割步骤:​

np.vsplit(gray, 50):将图像垂直(Vertical)分割成 50 个等高大图(每行 1 个大图);​

np.hsplit(row, 100):将每行的大图水平(Horizontal)分割成 100 个等宽小图;​

最终得到50行×100列共 5000 个小图像(每个小图是一个手写数字样本)。

将其按照这样切分将每个手写数字单独拿出来,使其前一半做训练集,后一半做测试集。

(3)转换为数组并划分训练集和测试集

python 复制代码
# 将分割后的图像列表转换为numpy数组(形状为50×100×宽×高,便于索引)
x = np.array(cells)

# 取前50列作为训练集(50行×50列=2500个样本)
train = x[:, :50, :, :]
# 取后50列作为测试集(50行×50列=2500个样本)
test = x[:, 50:, :, :]

这里每个小图像的尺寸是固定的(比如 20×20 像素,20×20=400,后续特征维度对应 400)。通过数组切片将 5000 个样本按 "列" 分为两部分:前 50 列 2500 个样本用于训练,后 50 列 2500 个样本用于测试。

(4)将图像转换为特征向量

python 复制代码
# 将训练集图像从二维(宽×高)转换为一维特征向量(2500个样本,每个400维)
train_new = train.reshape(2500, 400).astype(np.float32)
# 测试集同样转换为一维特征向量(2500个样本,每个400维)
test_new = test.reshape(2500, 400).astype(np.float32)

KNN 算法需要一维特征向量输入,因此将每个二维图像(如 20×20)展平为 1×400 的一维数组。astype(np.float32)是为了统一数据类型,避免计算时类型错误。

(5)生成样本标签(数字类别)

python 复制代码
# 生成0-9的基础数字数组
b = np.arange(10)
# 为训练集生成标签:每个数字(0-9)重复250次(2500个样本=10类×250个)
lables = np.repeat(b, 250)
# 转换为二维数组(形状为2500×1,符合sklearn标签格式)
train_lables = lables[:, np.newaxis]

# 测试集标签生成逻辑与训练集相同(10类×250个=2500个)
test_lables = np.repeat(b, 250)[:, np.newaxis]

标签是样本对应的真实数字(0-9)。这里假设数据集中每个数字有 250 个样本(2500 个样本 ÷10 类 = 250),因此用np.repeat生成重复标签,确保每个样本对应正确的数字类别。

3、模型训练与评估:验证 KNN 效果

python 复制代码
# 初始化KNN模型(K=5,即取最近的5个邻居)
knn = KNeighborsClassifier(n_neighbors=5)

knn.fit(train_new, train_lables)

# 计算模型在测试集上的准确率(正确预测样本数÷总样本数)
score = knn.score(test_new, test_lables)
print(score)  # 输出准确率(0-1之间)

KNeighborsClassifier(n_neighbors=5):定义 K=5 的 KNN 模型;​

fit方法是模型 "训练"(实际是存储训练样本)​score方法通过测试集计算准确率,用于评估模型效果。

4.单样本预测:实际应用测试

python 复制代码
# 读取待预测的单个数字图像(如手写的"2")
img1 = cv2.imread('1.png')
# 转换为灰度图(与训练数据格式一致)
jiangwei = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)

# 转换为特征向量(形状为1×400,与训练样本维度一致)
test_sample = jiangwei.reshape(-1, 400).astype(np.float32)

# 用训练好的模型预测数字
result = knn.predict(test_sample)
# 输出预测结果(转换为整数)
print(f'预测数字是:{int(result[0])}')

这部分是实际应用环节:​

1.读取用户输入的单个数字图像(如1.png);​

2.按训练数据的格式预处理(灰度化、展平为 400 维向量);​

3.用predict方法得到预测结果,最终输出数字类别。

四、完整代码

python 复制代码
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
import cv2
img=cv2.imread('23ea179d663b57f1b10ff385449f038.png')
gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
a=np.vsplit(gray,50)
cells=[np.hsplit(row,100) for row in np.vsplit(gray,50)]
x=np.array(cells)
train=x[:,:50,:,:]
test=x[:,50:,:,:]
train_new=train.reshape(2500,400).astype(np.float32)
test_new=test.reshape(2500,400).astype(np.float32)
b=np.arange(10)
lables=np.repeat(b,250)
train_lables=lables[:,np.newaxis]
test_lables=np.repeat(b,250)[:,np.newaxis]
knn=KNeighborsClassifier(n_neighbors=5)
knn.fit(train_new,train_lables)
score=knn.score(test_new,test_lables)
print(score)

img1=cv2.imread('1.png')
jiangwei=cv2.cvtColor(img1,cv2.COLOR_BGR2GRAY)
test_sample = jiangwei.reshape(-1, 400).astype(np.float32)
result = knn.predict(test_sample)
print(f'预测数字是:{int(result[0])}')

五、总结与注意事项​

(一)核心逻辑​

该代码实现了一个完整的 "自定义数据集→KNN 训练→预测" 流程,核心是将拼接图像分割为单个样本,转换为特征向量后用 KNN 进行分类,适合小批量自定义手写数字识别场景。​

(二)关键注意事项​

1.数据格式一致性:所有图像(训练集、测试集、待预测图像)必须有相同的尺寸(最终展平为 400 维),否则会因特征维度不匹配报错;​

2.标签对应性:标签生成依赖 "每个数字有 250 个样本" 的假设,若实际数据集样本数量不同,需修改np.repeat的参数;​

(三)适用场景​

适合个人制作的手写数字数据集(如自己手写的 0-9 数字拼接图),无需依赖公开数据集,灵活性较高,但需保证数据集分割逻辑(50×100)与实际图像一致。​

相关推荐
朝朝又沐沐3 小时前
算法竞赛阶段二-数据结构(36)数据结构双向链表模拟实现
开发语言·数据结构·c++·算法·链表
Ronin-Lotus3 小时前
深度学习篇---剪裁&缩放
图像处理·人工智能·缩放·剪裁
cpsvps4 小时前
3D芯片香港集成:技术突破与产业机遇全景分析
人工智能·3d
薰衣草23334 小时前
一天两道力扣(6)
算法·leetcode
剪一朵云爱着4 小时前
力扣946. 验证栈序列
算法·
遇见尚硅谷4 小时前
C语言:*p++与p++有何区别
c语言·开发语言·笔记·学习·算法
国科安芯4 小时前
抗辐照芯片在低轨卫星星座CAN总线通讯及供电系统的应用探讨
运维·网络·人工智能·单片机·自动化
AKAMAI4 小时前
利用DataStream和TrafficPeak实现大数据可观察性
人工智能·云原生·云计算
天天开心(∩_∩)4 小时前
代码随想录算法训练营第三十二天
算法
微光-沫年4 小时前
150-SWT-MCNN-BiGRU-Attention分类预测模型等!
机器学习·matlab·分类