KNN算法学习笔记

一、核心概念与核心思想

1. 核心定义

KNN算法通过计算待预测样本与训练集中所有样本的"距离",筛选出距离最近的K个样本(即"K个邻居"),然后根据这K个邻居的类别(分类任务)或数值(回归任务),采用投票法或平均法得到待预测样本的结果。

2. 核心思想

"物以类聚,人以群分"------ 待预测样本的结果,由其周围最相似的K个样本共同决定。KNN算法不构建显式的模型,而是将训练数据本身作为"模型",通过距离度量实现样本间的相似性判断。

3. 关键特征:懒惰学习

与决策树、逻辑回归等"急切学习"(Eager Learning)算法不同,KNN在训练阶段不进行参数估计或模型构建,仅存储训练数据;直到预测阶段,才通过计算距离完成推理。这种特性使得KNN训练速度快,但预测速度相对较慢(尤其当训练数据量大时)。

二、算法原理与核心要素

1. 距离度量:相似性的判断标准

距离度量是KNN算法的核心,用于量化两个样本之间的相似性------距离越小,样本越相似。常用的距离度量方法包括:

  • 欧几里得距离(Euclidean Distance):最常用的距离度量,适用于连续型特征,计算两点在多维空间中的直线距离。公式为:对于样本X=(x₁,x₂,...,xₙ)和Y=(y₁,y₂,...,yₙ),距离D=√[(x₁-y₁)²+(x₂-y₂)²+...+(xₙ-yₙ)²]。

  • 曼哈顿距离(Manhattan Distance):适用于高维数据或特征值为整数的场景,计算两点在多维空间中沿坐标轴的"直角距离"。公式为:D=|x₁-y₁|+|x₂-y₂|+...+|xₙ-yₙ|。

  • 切比雪夫距离(Chebyshev Distance):适用于需要考虑"最远距离维度"的场景,如棋盘上两点的距离,等于各维度距离的最大值。公式为:D=max(|x₁-y₁|,|x₂-y₂|,...,|xₙ-yₙ|)。

  • 余弦相似度(Cosine Similarity):适用于文本分类等特征维度极高的场景,通过计算两个样本向量的夹角余弦值衡量相似性,取值范围[-1,1],值越接近1越相似。公式为:cosθ=(X·Y)/(|X|·|Y|),其中X·Y为向量内积,|X|、|Y|为向量模长。

注意:距离度量的选择需结合数据类型。例如,连续型数据优先用欧几里得距离,稀疏高维数据(如文本TF-IDF特征)优先用余弦相似度。

2. K值的选择:邻居数量的权衡

K值是KNN算法中最关键的超参数,直接影响模型性能,其选择需要在"过拟合"与"欠拟合"之间权衡:

  • K值过小(如K=1):模型对噪声敏感,易发生过拟合。此时仅依赖最近的1个样本,若该样本为异常值,会直接导致预测错误。

  • K值过大(如K=训练集样本数):模型过于平滑,易发生欠拟合。此时所有样本的投票权重趋于平均,会忽略局部数据的特征,导致预测结果偏差。

  • 最佳K值选择方法:通常选择奇数(避免分类任务中投票平局),通过交叉验证(如5折、10折交叉验证)遍历多个K值(如1-30),选择验证集准确率最高的K值。

3. 决策规则:结果的生成方式

决策规则根据任务类型(分类/回归)有所不同:

  • 分类任务:采用"多数投票法"------统计K个邻居中出现次数最多的类别,作为待预测样本的类别。可优化为"加权投票法",即距离越近的邻居权重越高(权重=1/距离²),降低远邻对结果的干扰。

  • 回归任务:采用"平均值法"------计算K个邻居的数值平均值,作为待预测样本的结果。同样可优化为"加权平均法",距离越近的邻居权重越高,提升预测精度。

三、算法实现步骤(以分类任务为例)

  1. 数据预处理

    1. 特征标准化/归一化:由于距离度量对特征尺度敏感(如"身高"以cm为单位,"体重"以kg为单位,未标准化会导致身高对距离的影响主导结果),需将所有特征转换到同一尺度(如标准化:均值为0,方差为1;归一化:取值范围[0,1])。

    2. 处理缺失值:通过均值填充、中位数填充等方式补全缺失特征。

    3. 划分训练集与测试集:通常按7:3或8:2的比例划分,避免数据泄露。

  2. 确定超参数:通过交叉验证确定最优K值和距离度量方法。

  3. 计算距离:对于测试集中的每个待预测样本,计算其与训练集中所有样本的距离。

  4. 筛选K个近邻:将训练样本按距离从小到大排序,选取前K个样本作为待预测样本的邻居。

  5. 生成预测结果:采用多数投票法(或加权投票法)统计K个邻居的类别,输出预测类别。

  6. 模型评估:通过准确率、精确率、召回率、F1值等指标评估模型在测试集上的性能,若效果不佳则重新调整超参数。

四、KNN算法的优缺点

1. 优点

  • 原理简单直观,易于理解和实现,无需复杂的数学推导。

  • 对异常值不敏感(当K值较大时),鲁棒性较强。

  • 适用于多分类任务,且无需修改算法核心逻辑。

  • 训练速度快,仅需存储训练数据,无需模型训练过程。

2. 缺点

  • 预测速度慢:预测时需计算与所有训练样本的距离,当训练数据量极大(如百万级)时,时间复杂度高。

  • 空间复杂度高:需存储全部训练数据,对内存要求较高。

  • 对高维数据敏感:高维空间中"距离"的区分度降低(维度灾难),导致模型性能下降。

  • 对超参数K和距离度量依赖度高,需通过大量实验确定最优组合。

五、优化方法

  • 数据降维:通过PCA(主成分分析)、LDA(线性判别分析)等方法降低特征维度,缓解维度灾难,提升预测速度。

  • 索引优化:采用KD树(K-Dimensional Tree)、Ball树等数据结构对训练样本进行索引,减少预测时的距离计算次数,提升预测速度。

  • 特征选择:剔除无关特征或冗余特征,保留对结果影响大的核心特征,降低计算成本。

  • 加权投票/平均:引入距离权重,提升近邻对结果的贡献,优化预测精度。

六、应用场景

KNN算法适用于数据量不大、特征维度适中、对模型可解释性要求较高的场景,典型应用包括:

  • 图像识别:如手写数字识别(MNIST数据集),通过像素特征的距离度量实现分类。

  • 文本分类:如垃圾邮件识别、新闻分类,基于文本的TF-IDF特征或词向量计算余弦相似度,完成分类。

  • 推荐系统:如基于用户行为的协同过滤推荐,通过计算用户或物品之间的相似度,推荐"相似用户喜欢的物品"。

  • 医疗诊断:如基于患者的生理指标(血压、血糖等),对比历史病例数据,辅助判断疾病类型。

  • 回归预测:如房价预测、气温预测,通过周边相似样本的数值进行加权平均,得到预测结果。

七、经典代码实现(Python+Scikit-learn)

以下以鸢尾花(Iris)分类数据集为例,展示KNN算法的实现流程:

复制代码
#导入画图包
import matplotlib.pyplot as plt
import numpy as np
#1数据点集合
point1 = np.array([[7.7, 6.1], [3.1, 5.9], [8.6, 8.8], [9.5, 7.3], [3.9, 7.4], [5.0, 5.3], [1.0, 7.3]])
point2 = np.array([[0.2, 2.2], [4.5, 4.1], [0.5, 1.1], [2.7, 3.0], [4.7, 0.2], [2.9, 3.3], [7.3, 7.9]])
point3 = np.array([[9.2, 0.7], [9.2, 2.1], [7.3, 4.5], [8.9, 2.9], [9.5, 3.7], [7.7, 3.7], [9.4, 2.4]])

#整合一下数据 np.concatenate()表示合并或者拼接
np_train_data =np.concatenate((point1,point2,point3),axis=0)
#添加一下标签
np_train_label=np.array([0]*len(point1)+[1]*len(point2)+[2]*len(point3))
print(np_train_label)

#设置要预测的数据点
scatter_predict=np.array([3.35,2.46])
#计算L2距离也就是欧式距离
#第二种方法
def eucalidean_distance_1(x_point,y_point):
    x_1=np.array(x_point)
    y_1=np.array(y_point)
    return np.sqrt(np.sum((x_1-y_1)**2))

#初始化K值(超参数,自己设定的)
k=5
#取最近的k个点 (如果想取出最近的k个点,我得计算出所有的距离,取距离最小的3个点)
#存储所有的距离
distances=[]
#计算预测数据点和所有训练数据的距离
for scatter_train in np_train_data:
    distances.append(eucalidean_distance_1(scatter_predict,scatter_train))
#取距离最小的k个数据点
nearest=np.argsort(np.array(distances))
#获取最近的k个数据点的数据和标签
#获取数据标签
topk_label= [np_train_label[i]  for i in nearest[:k] ] #列表递推表达式
topk_data= [np_train_data[i]  for i in nearest[:k] ] #
class_0=[]
class_1=[]
class_2=[]
label_predict=0
for item in topk_label:
    if item==0:
        class_0.append(0)
    if item==1:
        class_1.append(1)
    if item==2:
        class_2.append(2)
#比较那个类多
index_max=np.argmax([len(class_0),len(class_1),len(class_2)])
print('最后的类别',index_max)

#画图
#第一步获得5条线的距离
list_knn_distance=[distances[nearest[:k][j]]  for j in range(k)]
#annotate 根据坐标写文字
for i in range(k):
    plt.plot([scatter_predict[0],topk_data[i][0]],[scatter_predict[1],topk_data[i][1]])
    #annotate的第一个参数表示要展示的文本内容,第二个参数是要展示的文本内容所在的坐标
    plt.annotate("%s" % round(list_knn_distance[i], 2), xy=((scatter_predict[0] + topk_data[i][0]) / 2,(scatter_predict[1] + topk_data[i][1]) / 2))

#画原始的数据点
#下面画原始点
plt.xlabel('x axis label')
plt.ylabel('y axis label')
plt.scatter(np_train_data[np_train_label==0,0],np_train_data[np_train_label==0,1],marker='*')
plt.scatter(np_train_data[np_train_label==1,0],np_train_data[np_train_label==1,1],marker='^')
plt.scatter(np_train_data[np_train_label==2,0],np_train_data[np_train_label==2,1],marker='s')
plt.scatter(scatter_predict[0],scatter_predict[1],marker='o')
plt.show()

八、总结

KNN算法是监督学习中入门级的经典算法,其核心优势在于简单直观、易于实现,适合作为机器学习的入门实践案例。但由于其"懒惰学习"的特性,在处理大规模、高维数据时存在效率瓶颈,需通过降维、索引优化等方法提升性能。在实际应用中,需结合数据特点合理选择距离度量、K值及决策规则,以实现最优的预测效果。

相关推荐
祁思妙想4 小时前
数据分析三剑客:NumPy、Pandas、Matplotlib
数据分析·numpy·pandas
猪在黑魔纹里1 天前
解决VSCode无法高亮、解析numpy中的部分接口(如pi、deg2rad)
ide·vscode·python·numpy
九死九歌1 天前
【Sympydantic】使用sympydantic,利用pydantic告别numpy与pytorch编程中,tensor形状带来的烦人痛点!
开发语言·pytorch·python·机器学习·numpy·pydantic
qq19226382 天前
探索图像滤波去噪:MATLAB GUI的奇妙之旅
numpy
Python大数据分析@2 天前
Numpy基础20问
numpy
Cat God 0072 天前
CentOS 搭建 SFTP 服务器(二)
服务器·centos·numpy
fresh hacker3 天前
【Python数据分析】速通NumPy
开发语言·python·数据挖掘·数据分析·numpy
maycho1233 天前
探索 Buck DCDC:自适应恒定导通时间控制的降压变换器之旅
numpy
裤裤兔3 天前
python2与python3的兼容
开发语言·python·numpy