深度剖析 K 近邻算法:分类、回归实战及优劣势分析

简介:本文全面且细致地讲解了 K 近邻算法,开篇点明掌握该算法需先理解特征空间这一关键概念,通过水果、鸢尾花数据集实例,助读者明晰特征空间维度构成;接着深入剖析 K 近邻算法原理,涵盖分类、回归两大应用方向。分类部分,详述从导包、样本生成、绘图到预测及邻近点绘制的完整实操流程,还贴心指出 sklearn 导包易错点;回归板块,依次展示数据集生成、模型训练、预测及结果可视化步骤。文末精准总结 KNN 算法参数、优缺点及改进方法,是新手入门、老手温故 K 近邻算法的优质参考。

这里写目录标题

K近邻算法

要掌握K近邻算法先要理解一个概念:特征空间

什么是特征空间

假设我们要对水果进行分类,有两个特征:"形状"(圆形、椭圆形等)和 "颜色"(红色、绿色等)。那么特征空间就是一个二维空间,其中一个维度代表 "形状",另一个维度代表 "颜色"。

比如鸢尾花数据集。它有四个特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度。那么特征空间就是一个四维空间。每一朵鸢尾花都可以在这个四维空间中找到一个对应的位置(点)。

K近邻算法的原理

在特征空间中一个样本的附近的K个最近样本大多属于某一个类别,则该样本也属于这个类别。

K近邻算法属于有监督的机器学习算法,如果您不了解什么是有监督可以阅读我的文章:一文吃透监督学习:从原理到实战,攻克过拟合与欠拟合难题 。有监督学习的算法一般用于分类和回归。

K近邻分类

因为正方形最近的三个分类好的邻居是两个三角形和一个圆形,所以他被分类为三角形

下面,将按照步骤实现一个分类的案例

(第一步)导入 sklearn 和matplotlib和numpy包

这里导包的时候会遇到一个问题关于sklearn包的,一些课本上使用 make_blobs的时候导入步骤如下:

python 复制代码
from sklearn.datasets.samples_generator import make_blobs

现在会导致warnings,具体来说,你正在使用的代码中涉及到了 sklearn.datasets.samples_generator 模块,但是从 scikit-learn 库的版本 0.22 开始,这个模块就已经被标记为弃用(deprecated)了,并且计划在版本 0.24 中移除它。提示建议相应的类或函数应该从 sklearn.datasets 中去导入,而那些无法从 sklearn.datasets 导入的部分现在已经属于私有应用程序接口(private API)了。

所以应该按如下步骤导入:

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

(第二步)生成样本

要生成的60个数据将围绕三个中心聚类。这三个中心在二维空间中的坐标分别是(-2,2)、(2,2)和(0,4)。random_state是随机种子,设置了相同的随机种子会得到相同的结果

python 复制代码
centers = [[-2,2],[2,2],[0,4]]
X,y = make_blobs(n_samples = 60 ,centers = centers , random_state = 0 ,cluster_std = 0.60)

(第三步)画出生成的样本

涉及到了一个与plt.scatter散点图的知识,需要先对他进行讲解

(第三步补充)plt.scatter

他是matplotlib中用于画散点图的函数 第一个参数是横坐标 第二个参数 是纵坐标,第三个参数c代表着不同的类别 ,第四个参数s代表着点的大小 ,cmap是上色方式

接着回到第三步开始画出三个不同颜色的样本与聚类中心

python 复制代码
plt.figure(figsize=(5,3),dpi = 144)
c = np.array(centers)
# 样本
plt.scatter(X[:,0],X[:,1],c = y,s=10,cmap='cool')
# 中心点
plt.scatter(c[:,0],c[:,1],s = 50,marker='*',c = 'orange')
plt.show()

(第四步)进行预测

预测之前要学习一个函数 kneighbors(X_sample, return_distance = False)

X_sample是输入的样本数据点的坐标。目的是找到这个样本点在训练数据(X和y)中的k个最近邻。

return_distance = False是一个控制返回值内容的参数。当设置为False时,表示不返回样本点到其最近邻的距离信息,而只返回最近邻的索引(即这些最近邻在训练数据集X中的位置索引。

python 复制代码
from sklearn.neighbors import KNeighborsClassifier

k = 5
clf = KNeighborsClassifier(n_neighbors= k)
clf.fit(X,y)
X_sample = ([[0,2]])
y_sample = clf.predict(X_sample)
neighbors = clf.kneighbors(X_sample,return_distance= False)

(第五步)把邻近点画出来

python 复制代码
plt.figure(figsize = (5,3),dpi = 144)
c = np.array(centers)
# 样本
plt.scatter(X[:,0],X[:,1],c = y,s=10,cmap='cool')
# 中心点
plt.scatter(c[:,0],c[:,1],s = 50,marker='*',c = 'orange')
# 需要被预测的点
plt.scatter(X_sample[0][0],X_sample[0][1],marker='^',s = 100,cmap = 'cool')

# 把距离被预测点最近的五个点和被预测的点连成线
for i in neighbors[0]:
    plt.plot([X[i][0],X_sample[0][0]],[X[i][1],X_sample[0][1]],'k--',linewidth = 0.6)
    
plt.show()

K近邻回归

KNN算法用于回归的思路是,找出一个样本的K个最近邻居,讲这些 邻居的属性的平均值赋给该样本,就可以得到该样本的值,下面将按步骤进行回归:

(第一步)生成数据集

使用余弦曲线,并且添加一些噪声:

python 复制代码
import numpy as np
np.random.seed(30)
X = 5*np.random.rand(42,1)
y = np.cos(X).ravel()
y += 0.2*np.random.rand(42)-0.1

ravel起到将多维数组压扁成为一维数组的作用

(第二步)训练模型

python 复制代码
from sklearn.neighbors import KNeighborsRegressor
k = 5
knn = KNeighborsRegressor(k)
knn.fit(X,y)

(第三步)生成足够密集的点进行预测

python 复制代码
T = np.linspace(0,5,500)[:,np.newaxis]
y_pred = knn.predict(T)

(第四步)把预测结果画成图

python 复制代码
import matplotlib.pyplot as plt
plt.figure(figsize=(5,3),dpi =144)
# 这是训练集的样本点
plt.scatter(X,y,c='g',label="Train data",s =10)

# zh
plt.plot(T,y_pred,c='k',label='Test prediction',lw =2)

# plt.axis('tight')的主要作用是自动调整坐标轴的范围,使得坐标轴紧密地包围数据点或者图形元素。
plt.axis('tight')

plt.title(f"KNN Reg(k ={k})")
# 显示label
plt.legend()
plt.show()

KNN算法参数与优缺点

  1. KNeighbors有两个参数一个是邻居的数量,一个是计算距离的方式。一般邻居在3-5个有较好的效果,K越大越容易掺杂进噪声,K太小会过拟合,默认使用欧式距离,他在许多情况下比较好。
  2. 他的却带你在于如果特征空间有很多0,数据集比较稀疏的情况下分类效果不太好,还有就是数据集大的时候时间比较长。
  3. 改进的常规方法目前有两个,针对不同的邻居使用不同的距离权重,另一个是选一定半径内距离最近K个的邻居。
相关推荐
Giser探索家1 小时前
遥感卫星升轨 / 降轨技术解析:对图像光照、对比度的影响及工程化应用
大数据·人工智能·算法·安全·计算机视觉·分类
B站计算机毕业设计之家2 小时前
深度学习实战:python动物识别分类检测系统 计算机视觉 Django框架 CNN算法 深度学习 卷积神经网络 TensorFlow 毕业设计(建议收藏)✅
python·深度学习·算法·计算机视觉·分类·毕业设计·动物识别
好奇龙猫3 小时前
【学习AI-相关路程-mnist手写数字分类-一段学习的结束:自我学习AI-复盘-代码-了解原理-综述(5) 】
人工智能·学习·分类
A-大程序员3 小时前
【Pytorch】分类问题交叉熵
人工智能·pytorch·分类
wu_jing_sheng03 小时前
ai 作物分类
人工智能·分类·数据挖掘
飞翔的佩奇3 小时前
【完整源码+数据集+部署教程】烟叶植株计数与分类系统源码和数据集:改进yolo11-TADDH
python·yolo·计算机视觉·目标跟踪·分类·数据集·yolo11
csuzhucong4 小时前
人类知识体系分类
人工智能·分类·数据挖掘
JJJJ_iii6 小时前
【机器学习03】学习率与特征工程、多项式回归、逻辑回归
人工智能·pytorch·笔记·学习·机器学习·回归·逻辑回归
云端FFF7 小时前
论文理解 【LLM-回归】—— Decoding-based Regression
人工智能·数据挖掘·回归
IT小哥哥呀20 小时前
基于深度学习的数字图像分类实验与分析
人工智能·深度学习·分类