大数据-195 KNN/K近邻算法实战:欧氏距离+投票机制手写实现,含可视化与调参要点

TL;DR

  • 场景:小样本、低维特征的监督分类,用"相似样本投票"快速出结果
  • 结论:KNN核心是距离度量+K个邻居投票;特征尺度决定距离可信度,K影响偏差/方差
  • 产出:一套可复现实验(葡萄酒二分类)、距离计算/排序/投票流程、KNN函数封装

监督学习算法

KNN/K近邻算法

K近邻算法(K-Nearest Neighbors, KNN)的核心思想是基于样本间的距离度量来判断相似性。具体来说,它通过计算待分类样本与训练集中各个样本之间的特征空间距离(常用的距离度量包括欧式距离、曼哈顿距离或余弦相似度等),如果两个样本在特征空间中的距离足够接近,就认为它们具有较高的相似度,很可能属于同一类别。

在实际应用中,仅依靠单个最近邻样本进行分类容易受到噪声和异常值的影响,导致分类结果不稳定。因此KNN算法会选取距离最近的K个样本(即K个最近邻),这些近邻样本构成了待分类样本的局部邻域。算法会统计这些近邻样本的类别标签分布情况(标签代表样本所属的真实类别,如"猫"、"狗"等分类结果),然后采用投票机制:将K个近邻中出现次数最多的类别作为待分类样本的预测结果。

举例来说,在图像分类任务中,假设K=5,待分类图片的5个最近邻中包含3张"猫"和2张"狗"的图片,那么算法就会将该图片分类为"猫"。这种基于局部邻域投票的机制使得KNN算法对噪声数据具有较好的鲁棒性,同时K值的选择(通常通过交叉验证确定)会直接影响算法的分类性能。

实现过程

假设 X_test 待标记的数据样本,X_train 为已标记的数据集。

  • 遍历已标记数据集中的所有样本,计算每个样本与待标记的点的距离,并把距离保存在 Distance 数组中。
  • 对 Distance 数组进行排序,取距离最近的 K 个点,记为X_knn。
  • 在 X_knn 中统计每个类别的个数,即 class0 在 X_knn 中有几个样本,class1 在 X_knn中有几个样本
  • 待标记样本的类别,就是在 K_knn 中样本个数最多的那类别。

距离的确定

该算法的【距离】在二维坐标轴就表示两点之间的距离,计算距离的公式有很多。我们常说的欧拉公式,即"欧氏距离",回忆一下,一个平面直角坐标系上,如何计算两点之间的距离?一个立体直角坐标系上,又如何计算两点之间的距离?

当特征数量有很多个形式多维空间时,再用上述的写法就不方便了,我们换一个写法,用 X 加下角标的方式表示特征维度,则在 N 维空间中,有两个点 A 和B,它们坐标分别为:

而在机器学习中,坐标轴上的 x1、x2、x3等,正是我们样本上的N 个特征。

算法优点

算法参数k(通常称为k近邻算法中的k值)是一个关键的超参数,它决定了在预测时要考虑多少个最近的邻居样本。具体来说:

  1. k值的作用机制

    • 当k=1时,模型仅考虑最近的一个样本点,这会使决策边界变得非常复杂
    • 随着k值增大,模型会考虑更多邻居的投票结果,使得决策边界趋于平滑
  2. k值与模型偏差的关系

    • 较大的k值(如k=15)会使模型偏差增大,因为决策会基于更多样本的平均值
    • 这会使模型对个别噪声数据点(如标注错误的样本)的敏感度降低
    • 极端情况下,当k接近训练集大小时,模型会简单地预测多数类
  3. k值与模型方差的关系

    • 较小的k值(如k=1或3)会使模型方差增大
    • 模型容易捕捉到训练数据中的随机波动和噪声
    • 在图像分类任务中,过小的k值可能导致模型对像素级别的变化过于敏感
  4. 参数选择的实践经验

    • 通常从k=5开始尝试,这是经验法则
    • 对于特征维度较高的数据(如文本数据),可能需要更大的k值
    • 可以通过交叉验证来寻找最优k值,常用方法是绘制k值与准确率的曲线图
    • 在sklearn中,可以使用GridSearchCV进行k值调优
  5. 不同场景下的k值选择

    • 对于噪声较多的数据集(如传感器数据),建议使用较大的k值(7-15)
    • 对于清晰可分的数据(如MNIST手写数字),较小的k值(3-5)可能更合适
    • 当类别分布不平衡时,k值不宜过小,否则容易受到少数类样本的影响
  6. 与其他参数的协同作用

    • k值选择还需考虑距离度量方式(欧式距离、曼哈顿距离等)
    • 在特征缩放后,k值的效果可能会发生变化
    • 加权kNN中,k值的影响会相对减弱

算法变种

变种 1

默认情况下,在计算距离时,权重都是相同的,但实际上可以针对不同的领居指定不同的距离权重,比如距离越近权重越高。 可以通过指定算法的 weights 参数来实现。

变种2

使用一定半径内的点取代距离最近的 k 个点

  • 在 scikit-learn 中,RadiusNeighborsClassifier 实现了这种算法的变种
  • 当数据采样不均匀时,该算法变种可以获得更好的性能

代码实现

导入相关包

python 复制代码
# 全部行都能输出
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 解决坐标轴刻度负号乱码
plt.rcParams['axes.unicode_minus'] = False
# 解决中文乱码问题
plt.rcParams['font.sans-serif'] = ['Simhei']
plt.style.use('ggplot')
# plt.figure(figsize=(2,3),dpi=720)

执行结果如下:

构建已经分类好的原始数据集

首先随机设置十个样本表示十杯酒,这里去了部分样本点。 为了方便验证,使用 Python 的字典 dict 构建数据集,然后再将其转换为DataFrame 格式。

python 复制代码
rowdata = {
    '颜色深度': [14.13,13.2,13.16,14.27,13.24,12.07,12.43,11.79,12.37,12.04],
    '酒精浓度': [5.64,4.28,5.68,4.80,4.22,2.76,3.94,3.1,2.12,2.6],
    '品种': [0,0,0,0,0,1,1,1,1,1]
}
# 0 代表 "黑皮诺",1 代表 "赤霞珠"
wine_data = pd.DataFrame(rowdata)

执行结果如下图所示: 编写代码进行处理:

python 复制代码
X = np.array(wine_data.iloc[:,0:2]) #我们把特征(酒的属性)放在X
y = np.array(wine_data.iloc[:,-1]) #把标签(酒的类别)放在Y
#探索数据,假如我们给出新数据[12.03,4.1] ,你能猜出这杯红酒是什么类别么?
new_data = np.array([12.03,4.1])
plt.scatter(X[y==1,0], X[y==1,1], color='red', label='赤霞珠') #画出标签y为1
的、关于"赤霞珠"的散点
plt.scatter(X[y==0,0], X[y==0,1], color='purple', label='黑皮诺') #画出标签y为0
的、关于"黑皮诺"的散点
plt.scatter(new_data[0],new_data[1], color='yellow') # 新数据点
new_data
plt.xlabel('酒精浓度')
plt.ylabel('颜色深度')
plt.legend(loc='lower right')
plt.savefig('葡萄酒样本.png')

执行结果如下如下所示:

计算已知类别数据集中的点与当前之间的距离

我们使用欧式距离公式,计算新数据点 new_data 与现存的 X 数据集每一个点的距离:

python 复制代码
from math import sqrt
distance = [sqrt(np.sum((x-new_data)**2)) for x in X ]
distance

执行结果如下:

shell 复制代码
[2.6041505332833594,
 1.1837651794169315,
 1.9424983912477256,
 2.3468276459936295,
 1.2159358535712326,
 1.3405968819895113,
 0.4308131845707605,
 1.0283968105745949,
 2.0089798406156287,
 1.500033332962971]

运行结果如下所示:

将距离升序排列 选取距离最小的 K 个点

python 复制代码
sort_dist = np.argsort(distance)
sort_dist

执行结果如下所示: 所以,6、7、1 为最近的3个"数据点"的索引值,那么这些索引值对应的原数据的标签是什么?

python 复制代码
k = 3
topK = [y[i] for i in sort_dist[:k]]
topK

执行结果如下所示:

确定前k个点所在类别的计数

python 复制代码
pd.Series(topK).value_counts().index[0]

执行结果如下所示:

封装函数

将数据的过程封装成一个函数:

python 复制代码
def KNN(new_data,dataSet,k):
    '''
    函数功能:KNN分类器
    参数说明:
    new_data: 需要预测分类的数据集
    dataSet: 已知分类标签的数据集
    k: k-近邻算法参数,选择距离最小的k个点
    return:
    result: 分类结果
    '''
    from math import sqrt
    from collections import Counter
    import numpy as np
    import pandas as pd
    result = []
    distance = [sqrt(np.sum((x-new_data)**2)) for x in
                np.array(dataSet.iloc[:,0:2])]
    sort_dist = np.argsort(distance)
    topK = [dataSet.iloc[:,-1][i] for i in sort_dist[:k]]
    result.append(pd.Series(topK).value_counts().index[0])
    return result

测试函数的运行结果:

python 复制代码
new_data=np.array([12.03,4.1])
k = 3
KNN(new_data,wine_data,k)

执行结果如下所示:

错误速查

症状 根因 定位方法 修复方案
中文/负号乱码 字体未安装或字体名不匹配;axes.unicode_minus未生效 查看控制台字体 fallback 警告;图中中文变方块 安装/指定可用中文字体(如 SimHei/微软雅黑);保留 plt.rcParams['axes.unicode_minus']=False
散点图坐标轴含义对不上 X列顺序与 xlabel/ylabel语义不一致 对照 wine_data.iloc[:,0:2] 的列顺序与 plt.xlabel/plt.ylabel 统一"X的第0/1列"与坐标轴标签;或交换 plt.scatter(..., X[...,0], X[...,1]) 的轴含义
代码运行直接报 SyntaxError 注释被断行,形成裸文本(如"的、关于...") 报错行通常在 plt.scatter(... label=...) 附近 把断行文字改为 # 注释或移到同一行注释里
Simhei 不生效/报找不到字体 字体名大小写或字体不存在;Linux常见无该字体 matplotlib 日志 findfont/Font family 改为系统存在字体名;或用字体文件路径加载
保存图片为空白或缺元素 savefig 与绘制/显示顺序不当,或画布被清空 检查是否在 plt.show() 后保存;是否多次创建 figure 先绘制完再 plt.savefig(...);必要时显式 plt.figure() 管理画布
KNN 结果不稳定/偏某一类 K过小/过大;类别不平衡;距离被尺度主导 观察不同 K 的预测变化;看特征量纲差异 做特征缩放;用交叉验证选 K;必要时改用加权KNN
高维数据效果明显变差 距离集中现象(维度灾难),欧氏距离区分度下降 维度升高后近邻距离差距变小 降维/特征选择;换距离度量;或改用更适合高维的模型
半径近邻找不到邻居/报空集 半径设置不合理、采样不均匀 输出邻居数量或检查半径覆盖范围 调整 radius;设置回退策略(无邻居时用默认类/扩大半径)

其他系列

🚀 AI篇持续更新中(长期更新)

AI炼丹日志-29 - 字节跳动 DeerFlow 深度研究框斜体样式架 私有部署 测试上手 架构研究 ,持续打造实用AI工具指南! AI研究-132 Java 生态前沿 2025:Spring、Quarkus、GraalVM、CRaC 与云原生落地

💻 Java篇持续更新中(长期更新)

Java-207 RabbitMQ Direct 交换器路由:RoutingKey 精确匹配、队列多绑定与日志分流实战 MyBatis 已完结,Spring 已完结,Nginx已完结,Tomcat已完结,分布式服务已完结,Dubbo已完结,MySQL已完结,MongoDB已完结,Neo4j已完结,FastDFS 已完结,OSS已完结,GuavaCache已完结,EVCache已完结,RabbitMQ正在更新... 深入浅出助你打牢基础!

📊 大数据板块已完成多项干货更新(300篇):

包括 Hadoop、Hive、Kafka、Flink、ClickHouse、Elasticsearch 等二十余项核心组件,覆盖离线+实时数仓全栈! 大数据-278 Spark MLib - 基础介绍 机器学习算法 梯度提升树 GBDT案例 详解

相关推荐
Godspeed Zhao31 分钟前
Level 4自动驾驶系统设计2——功能与场景2
人工智能·机器学习·自动驾驶
Jerry.张蒙33 分钟前
AI工具Opencode助力SAP提质增效实践
大数据·运维·服务器·人工智能·运维开发
JustHappy8 小时前
古法编程秘籍(七):互联网到底是什么?把两台电脑怎么说话搞懂就够了
前端·后端·网络协议
Hommy888 小时前
【剪映小助手】添加图片接口(Add Images)
后端·github·剪映小助手·视频剪辑自动化
Data-Miner8 小时前
集团数字化转型大数据平台整体建设方案
大数据
GetcharZp9 小时前
别再盲目用 OpenCV 读图了,这才是 CV 预处理的终极杀手锏!
后端
guslegend9 小时前
大模型驱动大数据SRE智能运维
大数据·运维
硅谷秋水9 小时前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶
跨境小彭10 小时前
2026 Temu 合规新玩法,凌风 ERP 优化 POD 运营效率
大数据·跨境电商·temu·shein
weixin_3975740910 小时前
从“点状试点“到“全面智能化“:制造企业AI落地的现实路径
大数据·人工智能·制造