Python 机器学习核心入门与实战进阶 Day 2 - KNN(K-近邻算法)分类实战与调参

✅ 今日目标

  • 理解 KNN 的原理与"以邻为近"的思想
  • 掌握 K 值选择与模型效果的关系
  • 学会使用 sklearn 训练 KNN 模型
  • 实现 KNN 分类 + 模型评估 + 超参数调优

📘 一、KNN 算法原理

KNN(K-Nearest Neighbors)核心思想:

给定一个待预测样本,找到训练集中"距离它最近"的 K 个样本,用这些样本的类别进行多数投票预测。

特点 描述
模型类型 懒惰学习(无显式训练过程)
距离度量 欧几里得距离(默认)或自定义
参数调优 K 值、距离函数、权重方式
适用场景 数据量不大,维度不高,需快速建模时

🧪 二、KNN 分类流程(代码实践)

python 复制代码
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 生成数据
X = [[1], [2], [3], [10], [11], [12]]
y = [0, 0, 0, 1, 1, 1]

# 训练测试划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

# 建模
model = KNeighborsClassifier(n_neighbors=3)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)
print("准确率:", accuracy_score(y_test, y_pred))

🧠 三、K 值选择对模型的影响

K 值 模型表现
K 太小 模型过拟合,受噪声影响大
K 太大 模型过于平滑,泛化能力下降
一般建议 使用奇数,避免投票平局;通过交叉验证选择最佳 K

🔧 四、模型调参建议(使用 GridSearchCV)

python 复制代码
from sklearn.model_selection import GridSearchCV

param_grid = {'n_neighbors': list(range(1, 11))}
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
grid_search.fit(X_train, y_train)

print("最优K值:", grid_search.best_params_)
print("最佳准确率:", grid_search.best_score_)

🧾 今日总结

技能 工具
快速建模 KNeighborsClassifier
评估效果 accuracy_score()
参数调优 GridSearchCV()
可视化分类边界 使用 matplotlibseaborn

🧪 建议练习脚本

  • 使用 sklearn 中的 KNN 模型实现学生是否及格分类

  • 尝试多种 K 值进行训练,并绘制准确率变化图

  • 使用 GridSearchCV 找出最优 K

  • 可视化分类边界(二维特征时)

    python 复制代码
    # KNN 分类实战演示:学生是否及格预测
    
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split, GridSearchCV
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.metrics import accuracy_score, classification_report
    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    
    plt.rcParams['font.family'] = 'Arial Unicode MS'  # Mac 用户可用
    plt.rcParams['axes.unicode_minus'] = False
    # 1. 模拟学生成绩数据(两个特征:成绩 + 性别)
    np.random.seed(42)
    size = 100
    scores = np.random.randint(40, 100, size)
    genders = np.random.choice([0, 1], size=size)  # 0=女, 1=男
    pass_label = (scores >= 60).astype(int)
    
    X = np.column_stack(((scores - scores.mean()) / scores.std(), genders))  # 标准化+性别
    y = pass_label
    
    # 2. 拆分训练集与测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # 3. 不同 K 值准确率比较
    acc_list = []
    k_values = range(1, 16)
    
    for k in k_values:
        model = KNeighborsClassifier(n_neighbors=k)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        acc_list.append(acc)
    
    # 4. 可视化不同 K 值的准确率
    plt.plot(k_values, acc_list, marker='o', linestyle='--')
    plt.title("不同 K 值下的准确率")
    plt.xlabel("K 值")
    plt.ylabel("准确率")
    plt.xticks(k_values)
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    
    # 5. 使用 GridSearchCV 找最佳 K
    param_grid = {'n_neighbors': list(range(1, 16))}
    grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
    grid_search.fit(X_train, y_train)
    
    print("✅ 最佳 K 值:", grid_search.best_params_)
    print("📋 最佳交叉验证准确率:", grid_search.best_score_)
    
    # 6. 在测试集上评估
    best_model = grid_search.best_estimator_
    y_pred = best_model.predict(X_test)
    
    print("\\n=== 最终模型评估(测试集) ===")
    print("准确率:", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred))

    运行输出:

    python 复制代码
    ✅ 最佳 K 值: {'n_neighbors': 1}
    📋 最佳交叉验证准确率: 0.9875
    \n=== 最终模型评估(测试集) ===
    准确率: 0.95
                  precision    recall  f1-score   support
    
               0       0.88      1.00      0.93         7
               1       1.00      0.92      0.96        13
    
        accuracy                           0.95        20
       macro avg       0.94      0.96      0.95        20
    weighted avg       0.96      0.95      0.95        20
相关推荐
测试19989 小时前
Web自动化测试之测试用例流程设计
自动化测试·软件测试·python·selenium·测试工具·职场和发展·测试用例
没有口袋啦9 小时前
《机器学习与深度学习》入门
人工智能·深度学习·机器学习
Juchecar10 小时前
示例说明 Flask 调试模式的安全隐患
python
大翻哥哥10 小时前
Python 2025:数据分析平台智能化转型与新范式
人工智能·python·数据分析
love530love11 小时前
EPGF 架构为什么能保持长效和稳定?
运维·开发语言·人工智能·windows·python·架构·系统架构
我要学习别拦我~11 小时前
逻辑回归中的决策边界解析与应用实例
经验分享·机器学习·逻辑回归
傻啦嘿哟11 小时前
用Requests+BeautifulSoup实现天气预报数据采集:从入门到实战
开发语言·chrome·python
兆。12 小时前
python全栈-数据可视化
开发语言·python·信息可视化
Q_Q51100828512 小时前
python+nodejs+springboot在线车辆租赁信息管理信息可视化系统
spring boot·python·信息可视化·django·flask·node.js·php
站大爷IP12 小时前
Python条件控制:让程序学会"思考"的魔法
python