Python 机器学习核心入门与实战进阶 Day 2 - KNN(K-近邻算法)分类实战与调参

✅ 今日目标

  • 理解 KNN 的原理与"以邻为近"的思想
  • 掌握 K 值选择与模型效果的关系
  • 学会使用 sklearn 训练 KNN 模型
  • 实现 KNN 分类 + 模型评估 + 超参数调优

📘 一、KNN 算法原理

KNN(K-Nearest Neighbors)核心思想:

给定一个待预测样本,找到训练集中"距离它最近"的 K 个样本,用这些样本的类别进行多数投票预测。

特点 描述
模型类型 懒惰学习(无显式训练过程)
距离度量 欧几里得距离(默认)或自定义
参数调优 K 值、距离函数、权重方式
适用场景 数据量不大,维度不高,需快速建模时

🧪 二、KNN 分类流程(代码实践)

python 复制代码
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 生成数据
X = [[1], [2], [3], [10], [11], [12]]
y = [0, 0, 0, 1, 1, 1]

# 训练测试划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

# 建模
model = KNeighborsClassifier(n_neighbors=3)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)
print("准确率:", accuracy_score(y_test, y_pred))

🧠 三、K 值选择对模型的影响

K 值 模型表现
K 太小 模型过拟合,受噪声影响大
K 太大 模型过于平滑,泛化能力下降
一般建议 使用奇数,避免投票平局;通过交叉验证选择最佳 K

🔧 四、模型调参建议(使用 GridSearchCV)

python 复制代码
from sklearn.model_selection import GridSearchCV

param_grid = {'n_neighbors': list(range(1, 11))}
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
grid_search.fit(X_train, y_train)

print("最优K值:", grid_search.best_params_)
print("最佳准确率:", grid_search.best_score_)

🧾 今日总结

技能 工具
快速建模 KNeighborsClassifier
评估效果 accuracy_score()
参数调优 GridSearchCV()
可视化分类边界 使用 matplotlibseaborn

🧪 建议练习脚本

  • 使用 sklearn 中的 KNN 模型实现学生是否及格分类

  • 尝试多种 K 值进行训练,并绘制准确率变化图

  • 使用 GridSearchCV 找出最优 K

  • 可视化分类边界(二维特征时)

    python 复制代码
    # KNN 分类实战演示:学生是否及格预测
    
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split, GridSearchCV
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.metrics import accuracy_score, classification_report
    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    
    plt.rcParams['font.family'] = 'Arial Unicode MS'  # Mac 用户可用
    plt.rcParams['axes.unicode_minus'] = False
    # 1. 模拟学生成绩数据(两个特征:成绩 + 性别)
    np.random.seed(42)
    size = 100
    scores = np.random.randint(40, 100, size)
    genders = np.random.choice([0, 1], size=size)  # 0=女, 1=男
    pass_label = (scores >= 60).astype(int)
    
    X = np.column_stack(((scores - scores.mean()) / scores.std(), genders))  # 标准化+性别
    y = pass_label
    
    # 2. 拆分训练集与测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # 3. 不同 K 值准确率比较
    acc_list = []
    k_values = range(1, 16)
    
    for k in k_values:
        model = KNeighborsClassifier(n_neighbors=k)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        acc_list.append(acc)
    
    # 4. 可视化不同 K 值的准确率
    plt.plot(k_values, acc_list, marker='o', linestyle='--')
    plt.title("不同 K 值下的准确率")
    plt.xlabel("K 值")
    plt.ylabel("准确率")
    plt.xticks(k_values)
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    
    # 5. 使用 GridSearchCV 找最佳 K
    param_grid = {'n_neighbors': list(range(1, 16))}
    grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
    grid_search.fit(X_train, y_train)
    
    print("✅ 最佳 K 值:", grid_search.best_params_)
    print("📋 最佳交叉验证准确率:", grid_search.best_score_)
    
    # 6. 在测试集上评估
    best_model = grid_search.best_estimator_
    y_pred = best_model.predict(X_test)
    
    print("\\n=== 最终模型评估(测试集) ===")
    print("准确率:", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred))

    运行输出:

    python 复制代码
    ✅ 最佳 K 值: {'n_neighbors': 1}
    📋 最佳交叉验证准确率: 0.9875
    \n=== 最终模型评估(测试集) ===
    准确率: 0.95
                  precision    recall  f1-score   support
    
               0       0.88      1.00      0.93         7
               1       1.00      0.92      0.96        13
    
        accuracy                           0.95        20
       macro avg       0.94      0.96      0.95        20
    weighted avg       0.96      0.95      0.95        20
相关推荐
我没胡说八道4 小时前
高校论文AI检测优化工具对比研究与实测分析(2026)
人工智能·深度学习·机器学习·计算机视觉·aigc·论文
love530love4 小时前
LiveTalking 数字人项目 Windows 部署完全指南(EPGF 架构)
人工智能·windows·python·架构·livetalking·epgf
遇事不決洛必達5 小时前
【Python基础】GIL 锁是什么及其对爬虫的影响
爬虫·python·线程·进程·gil锁
CryptoPP5 小时前
快速对接东京证券交易所API数据:实战指南与代码示例
开发语言·人工智能·windows·python·信息可视化·区块链
探物 AI6 小时前
把 MambaOut 塞进 YOLOv11:会有什么样的反应
python·yolo·计算机视觉
unicrom_深圳市由你创科技6 小时前
基于Spring AI框架的RAG应用
人工智能·spring·机器学习
如竟没有火炬6 小时前
最大矩阵——单调栈
数据结构·python·线性代数·算法·leetcode·矩阵
阳区欠7 小时前
【LangChain】LLM基础介绍
开发语言·python·langchain
Cosolar7 小时前
保姆级 CrewAI 教程:从零构建多智能体协作系统
人工智能·python·架构
GDAL7 小时前
使用 uv 管理 Python 版本
python·uv·版本