大数据-195 KNN/K近邻算法实战:欧氏距离+投票机制手写实现,含可视化与调参要点

TL;DR

  • 场景:小样本、低维特征的监督分类,用"相似样本投票"快速出结果
  • 结论:KNN核心是距离度量+K个邻居投票;特征尺度决定距离可信度,K影响偏差/方差
  • 产出:一套可复现实验(葡萄酒二分类)、距离计算/排序/投票流程、KNN函数封装

监督学习算法

KNN/K近邻算法

K近邻算法(K-Nearest Neighbors, KNN)的核心思想是基于样本间的距离度量来判断相似性。具体来说,它通过计算待分类样本与训练集中各个样本之间的特征空间距离(常用的距离度量包括欧式距离、曼哈顿距离或余弦相似度等),如果两个样本在特征空间中的距离足够接近,就认为它们具有较高的相似度,很可能属于同一类别。

在实际应用中,仅依靠单个最近邻样本进行分类容易受到噪声和异常值的影响,导致分类结果不稳定。因此KNN算法会选取距离最近的K个样本(即K个最近邻),这些近邻样本构成了待分类样本的局部邻域。算法会统计这些近邻样本的类别标签分布情况(标签代表样本所属的真实类别,如"猫"、"狗"等分类结果),然后采用投票机制:将K个近邻中出现次数最多的类别作为待分类样本的预测结果。

举例来说,在图像分类任务中,假设K=5,待分类图片的5个最近邻中包含3张"猫"和2张"狗"的图片,那么算法就会将该图片分类为"猫"。这种基于局部邻域投票的机制使得KNN算法对噪声数据具有较好的鲁棒性,同时K值的选择(通常通过交叉验证确定)会直接影响算法的分类性能。

实现过程

假设 X_test 待标记的数据样本,X_train 为已标记的数据集。

  • 遍历已标记数据集中的所有样本,计算每个样本与待标记的点的距离,并把距离保存在 Distance 数组中。
  • 对 Distance 数组进行排序,取距离最近的 K 个点,记为X_knn。
  • 在 X_knn 中统计每个类别的个数,即 class0 在 X_knn 中有几个样本,class1 在 X_knn中有几个样本
  • 待标记样本的类别,就是在 K_knn 中样本个数最多的那类别。

距离的确定

该算法的【距离】在二维坐标轴就表示两点之间的距离,计算距离的公式有很多。我们常说的欧拉公式,即"欧氏距离",回忆一下,一个平面直角坐标系上,如何计算两点之间的距离?一个立体直角坐标系上,又如何计算两点之间的距离?

当特征数量有很多个形式多维空间时,再用上述的写法就不方便了,我们换一个写法,用 X 加下角标的方式表示特征维度,则在 N 维空间中,有两个点 A 和B,它们坐标分别为:

而在机器学习中,坐标轴上的 x1、x2、x3等,正是我们样本上的N 个特征。

算法优点

算法参数k(通常称为k近邻算法中的k值)是一个关键的超参数,它决定了在预测时要考虑多少个最近的邻居样本。具体来说:

  1. k值的作用机制

    • 当k=1时,模型仅考虑最近的一个样本点,这会使决策边界变得非常复杂
    • 随着k值增大,模型会考虑更多邻居的投票结果,使得决策边界趋于平滑
  2. k值与模型偏差的关系

    • 较大的k值(如k=15)会使模型偏差增大,因为决策会基于更多样本的平均值
    • 这会使模型对个别噪声数据点(如标注错误的样本)的敏感度降低
    • 极端情况下,当k接近训练集大小时,模型会简单地预测多数类
  3. k值与模型方差的关系

    • 较小的k值(如k=1或3)会使模型方差增大
    • 模型容易捕捉到训练数据中的随机波动和噪声
    • 在图像分类任务中,过小的k值可能导致模型对像素级别的变化过于敏感
  4. 参数选择的实践经验

    • 通常从k=5开始尝试,这是经验法则
    • 对于特征维度较高的数据(如文本数据),可能需要更大的k值
    • 可以通过交叉验证来寻找最优k值,常用方法是绘制k值与准确率的曲线图
    • 在sklearn中,可以使用GridSearchCV进行k值调优
  5. 不同场景下的k值选择

    • 对于噪声较多的数据集(如传感器数据),建议使用较大的k值(7-15)
    • 对于清晰可分的数据(如MNIST手写数字),较小的k值(3-5)可能更合适
    • 当类别分布不平衡时,k值不宜过小,否则容易受到少数类样本的影响
  6. 与其他参数的协同作用

    • k值选择还需考虑距离度量方式(欧式距离、曼哈顿距离等)
    • 在特征缩放后,k值的效果可能会发生变化
    • 加权kNN中,k值的影响会相对减弱

算法变种

变种 1

默认情况下,在计算距离时,权重都是相同的,但实际上可以针对不同的领居指定不同的距离权重,比如距离越近权重越高。 可以通过指定算法的 weights 参数来实现。

变种2

使用一定半径内的点取代距离最近的 k 个点

  • 在 scikit-learn 中,RadiusNeighborsClassifier 实现了这种算法的变种
  • 当数据采样不均匀时,该算法变种可以获得更好的性能

代码实现

导入相关包

python 复制代码
# 全部行都能输出
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 解决坐标轴刻度负号乱码
plt.rcParams['axes.unicode_minus'] = False
# 解决中文乱码问题
plt.rcParams['font.sans-serif'] = ['Simhei']
plt.style.use('ggplot')
# plt.figure(figsize=(2,3),dpi=720)

执行结果如下:

构建已经分类好的原始数据集

首先随机设置十个样本表示十杯酒,这里去了部分样本点。 为了方便验证,使用 Python 的字典 dict 构建数据集,然后再将其转换为DataFrame 格式。

python 复制代码
rowdata = {
    '颜色深度': [14.13,13.2,13.16,14.27,13.24,12.07,12.43,11.79,12.37,12.04],
    '酒精浓度': [5.64,4.28,5.68,4.80,4.22,2.76,3.94,3.1,2.12,2.6],
    '品种': [0,0,0,0,0,1,1,1,1,1]
}
# 0 代表 "黑皮诺",1 代表 "赤霞珠"
wine_data = pd.DataFrame(rowdata)

执行结果如下图所示: 编写代码进行处理:

python 复制代码
X = np.array(wine_data.iloc[:,0:2]) #我们把特征(酒的属性)放在X
y = np.array(wine_data.iloc[:,-1]) #把标签(酒的类别)放在Y
#探索数据,假如我们给出新数据[12.03,4.1] ,你能猜出这杯红酒是什么类别么?
new_data = np.array([12.03,4.1])
plt.scatter(X[y==1,0], X[y==1,1], color='red', label='赤霞珠') #画出标签y为1
的、关于"赤霞珠"的散点
plt.scatter(X[y==0,0], X[y==0,1], color='purple', label='黑皮诺') #画出标签y为0
的、关于"黑皮诺"的散点
plt.scatter(new_data[0],new_data[1], color='yellow') # 新数据点
new_data
plt.xlabel('酒精浓度')
plt.ylabel('颜色深度')
plt.legend(loc='lower right')
plt.savefig('葡萄酒样本.png')

执行结果如下如下所示:

计算已知类别数据集中的点与当前之间的距离

我们使用欧式距离公式,计算新数据点 new_data 与现存的 X 数据集每一个点的距离:

python 复制代码
from math import sqrt
distance = [sqrt(np.sum((x-new_data)**2)) for x in X ]
distance

执行结果如下:

shell 复制代码
[2.6041505332833594,
 1.1837651794169315,
 1.9424983912477256,
 2.3468276459936295,
 1.2159358535712326,
 1.3405968819895113,
 0.4308131845707605,
 1.0283968105745949,
 2.0089798406156287,
 1.500033332962971]

运行结果如下所示:

将距离升序排列 选取距离最小的 K 个点

python 复制代码
sort_dist = np.argsort(distance)
sort_dist

执行结果如下所示: 所以,6、7、1 为最近的3个"数据点"的索引值,那么这些索引值对应的原数据的标签是什么?

python 复制代码
k = 3
topK = [y[i] for i in sort_dist[:k]]
topK

执行结果如下所示:

确定前k个点所在类别的计数

python 复制代码
pd.Series(topK).value_counts().index[0]

执行结果如下所示:

封装函数

将数据的过程封装成一个函数:

python 复制代码
def KNN(new_data,dataSet,k):
    '''
    函数功能:KNN分类器
    参数说明:
    new_data: 需要预测分类的数据集
    dataSet: 已知分类标签的数据集
    k: k-近邻算法参数,选择距离最小的k个点
    return:
    result: 分类结果
    '''
    from math import sqrt
    from collections import Counter
    import numpy as np
    import pandas as pd
    result = []
    distance = [sqrt(np.sum((x-new_data)**2)) for x in
                np.array(dataSet.iloc[:,0:2])]
    sort_dist = np.argsort(distance)
    topK = [dataSet.iloc[:,-1][i] for i in sort_dist[:k]]
    result.append(pd.Series(topK).value_counts().index[0])
    return result

测试函数的运行结果:

python 复制代码
new_data=np.array([12.03,4.1])
k = 3
KNN(new_data,wine_data,k)

执行结果如下所示:

错误速查

症状 根因 定位方法 修复方案
中文/负号乱码 字体未安装或字体名不匹配;axes.unicode_minus未生效 查看控制台字体 fallback 警告;图中中文变方块 安装/指定可用中文字体(如 SimHei/微软雅黑);保留 plt.rcParams['axes.unicode_minus']=False
散点图坐标轴含义对不上 X列顺序与 xlabel/ylabel语义不一致 对照 wine_data.iloc[:,0:2] 的列顺序与 plt.xlabel/plt.ylabel 统一"X的第0/1列"与坐标轴标签;或交换 plt.scatter(..., X[...,0], X[...,1]) 的轴含义
代码运行直接报 SyntaxError 注释被断行,形成裸文本(如"的、关于...") 报错行通常在 plt.scatter(... label=...) 附近 把断行文字改为 # 注释或移到同一行注释里
Simhei 不生效/报找不到字体 字体名大小写或字体不存在;Linux常见无该字体 matplotlib 日志 findfont/Font family 改为系统存在字体名;或用字体文件路径加载
保存图片为空白或缺元素 savefig 与绘制/显示顺序不当,或画布被清空 检查是否在 plt.show() 后保存;是否多次创建 figure 先绘制完再 plt.savefig(...);必要时显式 plt.figure() 管理画布
KNN 结果不稳定/偏某一类 K过小/过大;类别不平衡;距离被尺度主导 观察不同 K 的预测变化;看特征量纲差异 做特征缩放;用交叉验证选 K;必要时改用加权KNN
高维数据效果明显变差 距离集中现象(维度灾难),欧氏距离区分度下降 维度升高后近邻距离差距变小 降维/特征选择;换距离度量;或改用更适合高维的模型
半径近邻找不到邻居/报空集 半径设置不合理、采样不均匀 输出邻居数量或检查半径覆盖范围 调整 radius;设置回退策略(无邻居时用默认类/扩大半径)

其他系列

🚀 AI篇持续更新中(长期更新)

AI炼丹日志-29 - 字节跳动 DeerFlow 深度研究框斜体样式架 私有部署 测试上手 架构研究 ,持续打造实用AI工具指南! AI研究-132 Java 生态前沿 2025:Spring、Quarkus、GraalVM、CRaC 与云原生落地

💻 Java篇持续更新中(长期更新)

Java-207 RabbitMQ Direct 交换器路由:RoutingKey 精确匹配、队列多绑定与日志分流实战 MyBatis 已完结,Spring 已完结,Nginx已完结,Tomcat已完结,分布式服务已完结,Dubbo已完结,MySQL已完结,MongoDB已完结,Neo4j已完结,FastDFS 已完结,OSS已完结,GuavaCache已完结,EVCache已完结,RabbitMQ正在更新... 深入浅出助你打牢基础!

📊 大数据板块已完成多项干货更新(300篇):

包括 Hadoop、Hive、Kafka、Flink、ClickHouse、Elasticsearch 等二十余项核心组件,覆盖离线+实时数仓全栈! 大数据-278 Spark MLib - 基础介绍 机器学习算法 梯度提升树 GBDT案例 详解

相关推荐
毕小宝2 小时前
Elasticsearch 条件字段为 date 类型时注意事项
大数据·elasticsearch·搜索引擎
最贪吃的虎2 小时前
JVM扫盲:内存模型
java·运维·jvm·后端
ONExiaobaijs2 小时前
基于Spring Boot的校园闲置物品交易系统
java·spring boot·后端
Penge6662 小时前
Go JSON 序列化大整数丢失精度分析
后端·go
爬山算法2 小时前
Hibernate(2)Hibernate的核心组件有哪些?
java·后端·hibernate
码界奇点2 小时前
基于Spring Boot和Vue的多通道支付网关系统设计与实现
vue.js·spring boot·后端·毕业设计·鸿蒙系统·源代码管理
IT 行者2 小时前
Spring Boot 升级之HTTP客户端调整:HttpExchange 与 Feign Client 深度对比分析
spring boot·后端·http
小蒜学长2 小时前
python基于Python的医疗机构药品及耗材信息管理系统(代码+数据库+LW)
数据库·spring boot·后端·python
seekCat2 小时前
C#中的Linq(Language Integrated Query)
后端