一、核心概念与核心思想
1. 核心定义
KNN算法通过计算待预测样本与训练集中所有样本的"距离",筛选出距离最近的K个样本(即"K个邻居"),然后根据这K个邻居的类别(分类任务)或数值(回归任务),采用投票法或平均法得到待预测样本的结果。
2. 核心思想
"物以类聚,人以群分"------ 待预测样本的结果,由其周围最相似的K个样本共同决定。KNN算法不构建显式的模型,而是将训练数据本身作为"模型",通过距离度量实现样本间的相似性判断。
3. 关键特征:懒惰学习
与决策树、逻辑回归等"急切学习"(Eager Learning)算法不同,KNN在训练阶段不进行参数估计或模型构建,仅存储训练数据;直到预测阶段,才通过计算距离完成推理。这种特性使得KNN训练速度快,但预测速度相对较慢(尤其当训练数据量大时)。
二、算法原理与核心要素
1. 距离度量:相似性的判断标准
距离度量是KNN算法的核心,用于量化两个样本之间的相似性------距离越小,样本越相似。常用的距离度量方法包括:
-
欧几里得距离(Euclidean Distance):最常用的距离度量,适用于连续型特征,计算两点在多维空间中的直线距离。公式为:对于样本X=(x₁,x₂,...,xₙ)和Y=(y₁,y₂,...,yₙ),距离D=√[(x₁-y₁)²+(x₂-y₂)²+...+(xₙ-yₙ)²]。
-
曼哈顿距离(Manhattan Distance):适用于高维数据或特征值为整数的场景,计算两点在多维空间中沿坐标轴的"直角距离"。公式为:D=|x₁-y₁|+|x₂-y₂|+...+|xₙ-yₙ|。
-
切比雪夫距离(Chebyshev Distance):适用于需要考虑"最远距离维度"的场景,如棋盘上两点的距离,等于各维度距离的最大值。公式为:D=max(|x₁-y₁|,|x₂-y₂|,...,|xₙ-yₙ|)。
-
余弦相似度(Cosine Similarity):适用于文本分类等特征维度极高的场景,通过计算两个样本向量的夹角余弦值衡量相似性,取值范围[-1,1],值越接近1越相似。公式为:cosθ=(X·Y)/(|X|·|Y|),其中X·Y为向量内积,|X|、|Y|为向量模长。
注意:距离度量的选择需结合数据类型。例如,连续型数据优先用欧几里得距离,稀疏高维数据(如文本TF-IDF特征)优先用余弦相似度。
2. K值的选择:邻居数量的权衡
K值是KNN算法中最关键的超参数,直接影响模型性能,其选择需要在"过拟合"与"欠拟合"之间权衡:
-
K值过小(如K=1):模型对噪声敏感,易发生过拟合。此时仅依赖最近的1个样本,若该样本为异常值,会直接导致预测错误。
-
K值过大(如K=训练集样本数):模型过于平滑,易发生欠拟合。此时所有样本的投票权重趋于平均,会忽略局部数据的特征,导致预测结果偏差。
-
最佳K值选择方法:通常选择奇数(避免分类任务中投票平局),通过交叉验证(如5折、10折交叉验证)遍历多个K值(如1-30),选择验证集准确率最高的K值。
3. 决策规则:结果的生成方式
决策规则根据任务类型(分类/回归)有所不同:
-
分类任务:采用"多数投票法"------统计K个邻居中出现次数最多的类别,作为待预测样本的类别。可优化为"加权投票法",即距离越近的邻居权重越高(权重=1/距离²),降低远邻对结果的干扰。
-
回归任务:采用"平均值法"------计算K个邻居的数值平均值,作为待预测样本的结果。同样可优化为"加权平均法",距离越近的邻居权重越高,提升预测精度。
三、算法实现步骤(以分类任务为例)
-
数据预处理:
-
特征标准化/归一化:由于距离度量对特征尺度敏感(如"身高"以cm为单位,"体重"以kg为单位,未标准化会导致身高对距离的影响主导结果),需将所有特征转换到同一尺度(如标准化:均值为0,方差为1;归一化:取值范围[0,1])。
-
处理缺失值:通过均值填充、中位数填充等方式补全缺失特征。
-
划分训练集与测试集:通常按7:3或8:2的比例划分,避免数据泄露。
-
-
确定超参数:通过交叉验证确定最优K值和距离度量方法。
-
计算距离:对于测试集中的每个待预测样本,计算其与训练集中所有样本的距离。
-
筛选K个近邻:将训练样本按距离从小到大排序,选取前K个样本作为待预测样本的邻居。
-
生成预测结果:采用多数投票法(或加权投票法)统计K个邻居的类别,输出预测类别。
-
模型评估:通过准确率、精确率、召回率、F1值等指标评估模型在测试集上的性能,若效果不佳则重新调整超参数。
四、KNN算法的优缺点
1. 优点
-
原理简单直观,易于理解和实现,无需复杂的数学推导。
-
对异常值不敏感(当K值较大时),鲁棒性较强。
-
适用于多分类任务,且无需修改算法核心逻辑。
-
训练速度快,仅需存储训练数据,无需模型训练过程。
2. 缺点
-
预测速度慢:预测时需计算与所有训练样本的距离,当训练数据量极大(如百万级)时,时间复杂度高。
-
空间复杂度高:需存储全部训练数据,对内存要求较高。
-
对高维数据敏感:高维空间中"距离"的区分度降低(维度灾难),导致模型性能下降。
-
对超参数K和距离度量依赖度高,需通过大量实验确定最优组合。
五、优化方法
-
数据降维:通过PCA(主成分分析)、LDA(线性判别分析)等方法降低特征维度,缓解维度灾难,提升预测速度。
-
索引优化:采用KD树(K-Dimensional Tree)、Ball树等数据结构对训练样本进行索引,减少预测时的距离计算次数,提升预测速度。
-
特征选择:剔除无关特征或冗余特征,保留对结果影响大的核心特征,降低计算成本。
-
加权投票/平均:引入距离权重,提升近邻对结果的贡献,优化预测精度。
六、应用场景
KNN算法适用于数据量不大、特征维度适中、对模型可解释性要求较高的场景,典型应用包括:
-
图像识别:如手写数字识别(MNIST数据集),通过像素特征的距离度量实现分类。
-
文本分类:如垃圾邮件识别、新闻分类,基于文本的TF-IDF特征或词向量计算余弦相似度,完成分类。
-
推荐系统:如基于用户行为的协同过滤推荐,通过计算用户或物品之间的相似度,推荐"相似用户喜欢的物品"。
-
医疗诊断:如基于患者的生理指标(血压、血糖等),对比历史病例数据,辅助判断疾病类型。
-
回归预测:如房价预测、气温预测,通过周边相似样本的数值进行加权平均,得到预测结果。
七、经典代码实现(Python+Scikit-learn)
以下以鸢尾花(Iris)分类数据集为例,展示KNN算法的实现流程:
#导入画图包 import matplotlib.pyplot as plt import numpy as np #1数据点集合 point1 = np.array([[7.7, 6.1], [3.1, 5.9], [8.6, 8.8], [9.5, 7.3], [3.9, 7.4], [5.0, 5.3], [1.0, 7.3]]) point2 = np.array([[0.2, 2.2], [4.5, 4.1], [0.5, 1.1], [2.7, 3.0], [4.7, 0.2], [2.9, 3.3], [7.3, 7.9]]) point3 = np.array([[9.2, 0.7], [9.2, 2.1], [7.3, 4.5], [8.9, 2.9], [9.5, 3.7], [7.7, 3.7], [9.4, 2.4]]) #整合一下数据 np.concatenate()表示合并或者拼接 np_train_data =np.concatenate((point1,point2,point3),axis=0) #添加一下标签 np_train_label=np.array([0]*len(point1)+[1]*len(point2)+[2]*len(point3)) print(np_train_label) #设置要预测的数据点 scatter_predict=np.array([3.35,2.46]) #计算L2距离也就是欧式距离 #第二种方法 def eucalidean_distance_1(x_point,y_point): x_1=np.array(x_point) y_1=np.array(y_point) return np.sqrt(np.sum((x_1-y_1)**2)) #初始化K值(超参数,自己设定的) k=5 #取最近的k个点 (如果想取出最近的k个点,我得计算出所有的距离,取距离最小的3个点) #存储所有的距离 distances=[] #计算预测数据点和所有训练数据的距离 for scatter_train in np_train_data: distances.append(eucalidean_distance_1(scatter_predict,scatter_train)) #取距离最小的k个数据点 nearest=np.argsort(np.array(distances)) #获取最近的k个数据点的数据和标签 #获取数据标签 topk_label= [np_train_label[i] for i in nearest[:k] ] #列表递推表达式 topk_data= [np_train_data[i] for i in nearest[:k] ] # class_0=[] class_1=[] class_2=[] label_predict=0 for item in topk_label: if item==0: class_0.append(0) if item==1: class_1.append(1) if item==2: class_2.append(2) #比较那个类多 index_max=np.argmax([len(class_0),len(class_1),len(class_2)]) print('最后的类别',index_max) #画图 #第一步获得5条线的距离 list_knn_distance=[distances[nearest[:k][j]] for j in range(k)] #annotate 根据坐标写文字 for i in range(k): plt.plot([scatter_predict[0],topk_data[i][0]],[scatter_predict[1],topk_data[i][1]]) #annotate的第一个参数表示要展示的文本内容,第二个参数是要展示的文本内容所在的坐标 plt.annotate("%s" % round(list_knn_distance[i], 2), xy=((scatter_predict[0] + topk_data[i][0]) / 2,(scatter_predict[1] + topk_data[i][1]) / 2)) #画原始的数据点 #下面画原始点 plt.xlabel('x axis label') plt.ylabel('y axis label') plt.scatter(np_train_data[np_train_label==0,0],np_train_data[np_train_label==0,1],marker='*') plt.scatter(np_train_data[np_train_label==1,0],np_train_data[np_train_label==1,1],marker='^') plt.scatter(np_train_data[np_train_label==2,0],np_train_data[np_train_label==2,1],marker='s') plt.scatter(scatter_predict[0],scatter_predict[1],marker='o') plt.show()
八、总结
KNN算法是监督学习中入门级的经典算法,其核心优势在于简单直观、易于实现,适合作为机器学习的入门实践案例。但由于其"懒惰学习"的特性,在处理大规模、高维数据时存在效率瓶颈,需通过降维、索引优化等方法提升性能。在实际应用中,需结合数据特点合理选择距离度量、K值及决策规则,以实现最优的预测效果。