KNN算法手写数字识别 网格搜索 交叉验证 机器学习基础2 python人工智能

上一篇:KNN算法基础 机器学习基础1 python人工智能

目录

[网格搜索与交叉验证:如何为 KNN 选择最优超参数](#网格搜索与交叉验证:如何为 KNN 选择最优超参数)

一、为什么不能只划分一次训练集和测试集?

二、交叉验证的核心思想

[三、KNN 中为什么需要调参?](#三、KNN 中为什么需要调参?)

[四、什么是网格搜索(Grid Search)?](#四、什么是网格搜索(Grid Search)?)

[五、网格搜索 + 交叉验证如何协同工作?](#五、网格搜索 + 交叉验证如何协同工作?)

六、交叉验证折数(cv)的含义与权衡

七、最终模型评估的正确姿势

八、以上内容完整代码+详细注释

运行结果

[基于 KNN 的手写数字识别实战解析](#基于 KNN 的手写数字识别实战解析)

一、任务背景与数据形式

[1. 数据来源与结构](#1. 数据来源与结构)

二、像素数据与图像可视化

[1. 从向量还原为图像](#1. 从向量还原为图像)

三、数据预处理:为什么要归一化?

四、训练集与测试集的合理划分

[五、KNN 模型在手写数字识别中的作用](#五、KNN 模型在手写数字识别中的作用)

[1. KNN 的工作机制](#1. KNN 的工作机制)

[2. K 值的选择](#2. K 值的选择)

六、模型评估方式

七、模型持久化:从"训练"到"部署"

[1. 模型保存的意义](#1. 模型保存的意义)

[2. 模型加载与预测流程](#2. 模型加载与预测流程)

八、完整流程回顾

九、以上完整代码+详细注释

运行结果

网格搜索与交叉验证:如何为 KNN 选择最优超参数

在机器学习建模过程中,模型性能不仅取决于数据本身,还高度依赖于超参数的选择 。以 KNN(K 近邻)算法为例,k 的取值直接影响模型的泛化能力。如果 k 选得不合适,模型要么容易过拟合,要么预测能力不足。

本文将结合一个典型的鸢尾花分类任务,系统讲解交叉验证网格搜索的原理、流程以及二者如何协同工作,从而自动、科学地选择最优超参数。

一、为什么不能只划分一次训练集和测试集?

在最基础的建模流程中,通常会将数据集划分为训练集测试集,例如 80% 用于训练,20% 用于测试。这种方式简单直观,但存在一个明显问题:

模型的好坏可能高度依赖于这一次随机划分的结果。

如果测试集"恰好比较容易",模型的评估结果就会偏高;反之亦然。尤其是在样本量较小的情况下,这种不稳定性更加明显。

为了解决这一问题,就引入了------交叉验证(Cross Validation)

二、交叉验证的核心思想

交叉验证的核心思想是:
让每一部分数据都有机会既作为训练数据,也作为验证数据。

k 折交叉验证 为例,其基本流程如下:

  1. 将训练数据平均划分为 k 份;

  2. 每次取其中 1 份作为验证集,其余 k−1 份作为训练集;

  3. 重复 k 次,使每一份数据都恰好当过一次验证集;

  4. k 次评估结果取平均,作为模型性能的最终评估。

这种方式的优势在于:

  • 评估结果更加稳定、可靠;

  • 减少了因一次随机划分带来的偶然性;

  • 更接近模型在"未知数据"上的真实表现。

三、KNN 中为什么需要调参?

KNN 是一种典型的基于距离的非参数模型 ,其核心超参数是 k(邻居个数):

  • k 较小

    • 模型复杂度高

    • 对噪声敏感

    • 容易过拟合

  • k 较大

    • 模型更加平滑

    • 可能忽略局部结构

    • 容易欠拟合

因此,k 并不存在一个"放之四海而皆准"的固定取值,必须结合具体数据,通过实验来选择。

四、什么是网格搜索(Grid Search)?

网格搜索是一种穷举式的超参数搜索方法,其思想非常直接:

提前给定一组可能的超参数取值,然后逐一尝试,选出表现最好的那一组。

在 KNN 的场景下,可以事先设定一组候选的 k 值(例如 1 到 10),然后:

  • 对每一个 k

  • 在训练集上进行模型训练;

  • 通过交叉验证评估该 k 的平均性能。

最终,从所有候选参数中,选择交叉验证评分最高的那一个。

五、网格搜索 + 交叉验证如何协同工作?

当网格搜索与交叉验证结合使用时,整体流程可以概括为:

  1. 先划分训练集和测试集

    • 测试集只在最后使用,用于最终评估

    • 避免"信息泄露"

  2. 在训练集内部进行网格搜索

    • 针对每一个超参数组合

    • 使用交叉验证进行多次训练与验证

  3. 计算每组参数的平均交叉验证得分

    • 比较不同参数组合的性能

    • 排除偶然性带来的影响

  4. 自动选出最优超参数组合

    • 得到最优评分

    • 得到对应的最优模型

这种方式的优势在于:

  • 参数选择过程完全自动化;

  • 评估结果更加稳健;

  • 避免人工"拍脑袋"选参数。

六、交叉验证折数(cv)的含义与权衡

在实践中,交叉验证的折数也是一个需要权衡的问题:

  • 折数较小(如 3 折、4 折)

    • 计算速度快

    • 评估稳定性相对一般

  • 折数较大(如 5 折、10 折)

    • 评估结果更可靠

    • 计算成本更高

在样本量不大的经典数据集(如鸢尾花数据集)中,4 折或 5 折交叉验证通常是一个较为平衡的选择。

七、最终模型评估的正确姿势

需要特别强调的是:

交叉验证的评分 ≠ 模型的最终性能。

正确流程应当是:

  1. 使用训练集 + 网格搜索 + 交叉验证,确定最优超参数;

  2. 用最优参数在完整训练集上重新训练模型;

  3. 从未参与过训练和调参的测试集上进行评估。

只有这样得到的测试集准确率,才是对模型泛化能力的真实衡量。

八、以上内容完整代码+详细注释

python 复制代码
'''
交叉验证:如何划分训练集和测试集
网格搜索:KNN的k这个超参怎么选择
'''

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# 分割训练集和测试集的, 寻找最优超参的(网格搜索 + 交叉验证):
from sklearn.model_selection import train_test_split, GridSearchCV

# 1. 加载鸢尾花数据集.
iris_data = load_iris()
# print(f'数据集的键:{iris_data.keys()}')
# 数据集的键:dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])

# 2. 数据预处理, 这里是: 切分训练集和测试集, 比例: 8:2
# 参1: 数据集的特征数据,   参数2: 数据集的标签数据, 参数3: 测试集的比例, 参数4: 随机种子.
x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.2, random_state=22)

# 3. 特征工程 -> 特征预处理 -> 标准化.
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 4. 模型训练.
# 创建 KNN分类对象.
estimator=KNeighborsClassifier()

# 记录可能出现的K值
param_dict = {'n_neighbors': [i for i in range(1, 11)]}

# 创建 GridSearchCV对象 -> 寻找最优超参, 使用网格搜索 + 交叉验证方式
# 参1: 要计算最优超参的模型对象
# 参2: 该模型超参可能出现的值
# 参3: 交叉验证的折数, 这里的4折表示: 每个超参组合, 都会进行4次交叉验证.  这里共计是 4 * 10 = 40次.
# 返回值 estimator -> 处理后的模型对象.
estimator = GridSearchCV(estimator, param_dict, cv=4)

# 模型训练.
estimator.fit(x_train, y_train)

# 打印
print(f'最优评分: {estimator.best_score_}')
print(f'最优超参组合: {estimator.best_params_}')
print(f'最优的估计器对象: {estimator.best_estimator_}')
print(f'具体的交叉验证结果: {estimator.cv_results_}')
print('-'*30)

# 5. 模型评估.
estimator=KNeighborsClassifier(n_neighbors=3)
estimator.fit(x_train, y_train)
y_pre = estimator.predict(x_test)
print(f'准确率: {accuracy_score(y_test, y_pre)}')

运行结果

基于 KNN 的手写数字识别实战解析

手写数字识别是机器学习和计算机视觉领域中最经典的入门任务之一。该任务不仅直观,而且完整覆盖了从数据读取、特征处理、模型训练到模型部署与预测的全过程,非常适合理解传统机器学习模型在图像任务中的应用方式。

本文将围绕一个 基于 KNN(K 近邻)算法的手写数字识别系统,系统讲解其整体设计思路、数据处理方法以及模型使用流程。

一、任务背景与数据形式

1. 数据来源与结构

手写数字数据以 CSV 文件 的形式存储,其中:

  • 每一行表示一张图片

  • 第 1 列为标签(0--9)

  • 后续 784 列为像素特征

由于图像大小为 28 × 28 ,因此每张图片可被展平成一个长度为 784 的向量,每个元素表示一个像素点的灰度值。

这种"图像 → 向量"的表示方式,是传统机器学习算法处理图像数据的典型方法。

二、像素数据与图像可视化

在建模之前,理解数据本身非常重要。

1. 从向量还原为图像

  • 单张图片的特征向量维度为 (784,)

  • 通过重塑(reshape)可还原为 (28, 28) 的二维矩阵

  • 使用灰度图方式进行可视化

这一过程可以帮助我们:

  • 验证数据是否正确读取

  • 直观理解模型"看到"的输入内容

  • 排查潜在的数据错误或异常样本

三、数据预处理:为什么要归一化?

在原始数据中,像素值范围通常为 0--255

而 KNN 是一种基于距离计算的模型,如果不进行归一化,会带来两个问题:

  1. 距离计算尺度过大,影响数值稳定性

  2. 不利于不同特征之间的公平比较

因此,在模型训练前,需要将像素值进行 归一化处理

  • 将像素值缩放到 [0, 1] 区间

  • 保留原有的相对大小关系

  • 提高模型的鲁棒性

四、训练集与测试集的合理划分

为了客观评估模型的泛化能力,数据集被划分为:

  • 训练集:80%

  • 测试集:20%

在划分过程中,采用了 分层抽样(stratify) 策略,其核心目的是:

保证训练集和测试集中,各个数字类别的比例基本一致

这样可以有效避免:

  • 某些数字在测试集中样本过少

  • 评估结果产生偏差

五、KNN 模型在手写数字识别中的作用

1. KNN 的工作机制

KNN 是一种懒学习算法,其基本思想是:

  1. 对于待预测样本,计算它与训练集中所有样本的距离;

  2. 找出距离最近的 K 个邻居;

  3. 通过多数投票的方式确定最终类别。

在手写数字识别任务中:

  • 每一张图片被看作一个高维空间中的点

  • 相似的数字在特征空间中距离更近

2. K 值的选择

在该实现中,K 值设定为 3,其含义是:

  • 参考距离最近的 3 张已知数字图片

  • 通过投票决定预测结果

较小的 K 值有助于:

  • 捕捉局部结构

  • 提高对细节的敏感度

同时也需要注意:

  • K 过小可能对噪声敏感

  • K 过大可能导致欠拟合

六、模型评估方式

模型训练完成后,使用测试集进行评估,并计算:

  • 模型准确率(Accuracy)

准确率表示:

在测试集中,模型预测正确的样本占总样本的比例

在手写数字这种多分类且类别均衡的问题中,准确率是一个直观且常用的评估指标。

七、模型持久化:从"训练"到"部署"

在实际应用中,模型训练通常是一次性的,而预测可能是反复进行的。因此需要将训练好的模型保存下来。

1. 模型保存的意义

  • 避免重复训练

  • 提高系统运行效率

  • 支持模型复用与部署

2. 模型加载与预测流程

在预测阶段:

  1. 读取外部手写数字图片;

  2. 将图片转换为与训练数据一致的向量格式;

  3. 加载已保存的模型;

  4. 输出预测结果。

这一步标志着模型从"实验阶段"进入"可使用阶段"。

八、完整流程回顾

整个手写数字识别系统可以概括为以下步骤:

  1. 数据读取与结构理解

  2. 像素向量还原与可视化

  3. 特征归一化处理

  4. 分层划分训练集与测试集

  5. 使用 KNN 进行模型训练

  6. 测试集评估模型性能

  7. 保存模型并进行实际预测

这一流程完整体现了传统机器学习在图像任务中的工程实践路径

九、以上完整代码+详细注释

python 复制代码
'''
KNN手写数字识别
28*28像素的图片,即: 我们的csv文件中每一行都有 784个像素点, 表示图片(每个像素)的 颜色.
'''

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import joblib
from collections import Counter

# 忽略警告.
import warnings
warnings.filterwarnings('ignore', module='sklearn') # 参1: 忽略警告, 参2: 忽略的模块.

# 1. 定义函数, 获取用户传入的索引, 展示该索引对应的图片.
def show_digit(idx):
    # 读取数据集
    df = pd.read_csv('./data/手写数字识别.csv')
    # print(f'读取源数据:\n{df}')

    # 判断索引是否越界.
    if idx < 0 or idx > len(df) - 1:
        print('索引越界!')
        return

    x=df.iloc[:, 1:]  # 获取所有特征列.
    y=df.iloc[:, 0]  # 获取标签列.
    print(f'该图片对应的数字是: {y.iloc[idx]}')
    print(f'查看所有的标签的分布情况: {Counter(y)}')
    print(f'用户传入的索引对应的图片的形状: {x.iloc[idx].shape}')

    # 把 (784,) 转换成 (28, 28)
    x = x.iloc[idx].values.reshape(28, 28)
    # print(f'转化成28*28:\n {x}')

    # 绘制图片.
    plt.imshow(x, cmap='gray')  # 灰度图
    plt.axis('off')  # 不显示坐标轴
    plt.show()

# 2. 训练模型, 并保存模型.
def train_model():
    df = pd.read_csv('./data/手写数字识别.csv')

    # 数据预处理.
    x = df.iloc[:, 1:] # 获取所有特征列.
    y = df.iloc[:, 0] # 获取标签列.
    print(f'x的形状: {x.shape}')
    print(f'y的形状: {y.shape}')
    print(f'查看所有的标签的分布情况: {Counter(y)}')

    # 对特征列(拆分前)进行 归一化.
    x = x / 255   # 颜色的数值

    # 拆分训练集和测试集.
    # 参5: 参考y值进行抽取, 保持标签的比例(数据均衡)
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=21, stratify=y)

    # 模型训练.
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)

    # 模型评估.
    print(f'准确率: {estimator.score(x_test, y_test)}')
    print(f'准确率: {accuracy_score(y_test, estimator.predict(x_test))}')

    # 保存模型.
    joblib.dump(estimator, './model/手写数字识别.pkl')
    print('模型保存成功!')


# 3. 测试模型.
def use_model():
    # 加载图片
    x = plt.imread('./data/demo.png')
    # print(f'图片的形状: {x.shape}')

    # 加载模型
    estimator = joblib.load('./model/手写数字识别.pkl')
    print(f'模型加载成功!')

    # 模型预测.
    # 数据集转换.
    print(f'数据集转换:\n{x.reshape(1, 784).shape}')  # reshape(1,-1)效果也一样
    x=x.reshape(1, -1) # 用原始的读取到的像素值, 做预测.
    y_pre = estimator.predict(x)
    print(f'预测值为: {y_pre}')



if __name__ == '__main__':
    # show_digit(9)
    # train_model()
    use_model()

运行结果

相关推荐
AngelPP2 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年2 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
AI探索者2 小时前
LangGraph StateGraph 实战:状态机聊天机器人构建指南
python
AI探索者2 小时前
LangGraph 入门:构建带记忆功能的天气查询 Agent
python
九狼2 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS2 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区4 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈4 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
FishCoderh4 小时前
Python自动化办公实战:批量重命名文件,告别手动操作
python
躺平大鹅4 小时前
Python函数入门详解(定义+调用+参数)
python