深度剖析 K 近邻算法:分类、回归实战及优劣势分析

简介:本文全面且细致地讲解了 K 近邻算法,开篇点明掌握该算法需先理解特征空间这一关键概念,通过水果、鸢尾花数据集实例,助读者明晰特征空间维度构成;接着深入剖析 K 近邻算法原理,涵盖分类、回归两大应用方向。分类部分,详述从导包、样本生成、绘图到预测及邻近点绘制的完整实操流程,还贴心指出 sklearn 导包易错点;回归板块,依次展示数据集生成、模型训练、预测及结果可视化步骤。文末精准总结 KNN 算法参数、优缺点及改进方法,是新手入门、老手温故 K 近邻算法的优质参考。

这里写目录标题

K近邻算法

要掌握K近邻算法先要理解一个概念:特征空间

什么是特征空间

假设我们要对水果进行分类,有两个特征:"形状"(圆形、椭圆形等)和 "颜色"(红色、绿色等)。那么特征空间就是一个二维空间,其中一个维度代表 "形状",另一个维度代表 "颜色"。

比如鸢尾花数据集。它有四个特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度。那么特征空间就是一个四维空间。每一朵鸢尾花都可以在这个四维空间中找到一个对应的位置(点)。

K近邻算法的原理

在特征空间中一个样本的附近的K个最近样本大多属于某一个类别,则该样本也属于这个类别。

K近邻算法属于有监督的机器学习算法,如果您不了解什么是有监督可以阅读我的文章:一文吃透监督学习:从原理到实战,攻克过拟合与欠拟合难题 。有监督学习的算法一般用于分类和回归。

K近邻分类

因为正方形最近的三个分类好的邻居是两个三角形和一个圆形,所以他被分类为三角形

下面,将按照步骤实现一个分类的案例

(第一步)导入 sklearn 和matplotlib和numpy包

这里导包的时候会遇到一个问题关于sklearn包的,一些课本上使用 make_blobs的时候导入步骤如下:

python 复制代码
from sklearn.datasets.samples_generator import make_blobs

现在会导致warnings,具体来说,你正在使用的代码中涉及到了 sklearn.datasets.samples_generator 模块,但是从 scikit-learn 库的版本 0.22 开始,这个模块就已经被标记为弃用(deprecated)了,并且计划在版本 0.24 中移除它。提示建议相应的类或函数应该从 sklearn.datasets 中去导入,而那些无法从 sklearn.datasets 导入的部分现在已经属于私有应用程序接口(private API)了。

所以应该按如下步骤导入:

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

(第二步)生成样本

要生成的60个数据将围绕三个中心聚类。这三个中心在二维空间中的坐标分别是(-2,2)、(2,2)和(0,4)。random_state是随机种子,设置了相同的随机种子会得到相同的结果

python 复制代码
centers = [[-2,2],[2,2],[0,4]]
X,y = make_blobs(n_samples = 60 ,centers = centers , random_state = 0 ,cluster_std = 0.60)

(第三步)画出生成的样本

涉及到了一个与plt.scatter散点图的知识,需要先对他进行讲解

(第三步补充)plt.scatter

他是matplotlib中用于画散点图的函数 第一个参数是横坐标 第二个参数 是纵坐标,第三个参数c代表着不同的类别 ,第四个参数s代表着点的大小 ,cmap是上色方式

接着回到第三步开始画出三个不同颜色的样本与聚类中心

python 复制代码
plt.figure(figsize=(5,3),dpi = 144)
c = np.array(centers)
# 样本
plt.scatter(X[:,0],X[:,1],c = y,s=10,cmap='cool')
# 中心点
plt.scatter(c[:,0],c[:,1],s = 50,marker='*',c = 'orange')
plt.show()

(第四步)进行预测

预测之前要学习一个函数 kneighbors(X_sample, return_distance = False)

X_sample是输入的样本数据点的坐标。目的是找到这个样本点在训练数据(X和y)中的k个最近邻。

return_distance = False是一个控制返回值内容的参数。当设置为False时,表示不返回样本点到其最近邻的距离信息,而只返回最近邻的索引(即这些最近邻在训练数据集X中的位置索引。

python 复制代码
from sklearn.neighbors import KNeighborsClassifier

k = 5
clf = KNeighborsClassifier(n_neighbors= k)
clf.fit(X,y)
X_sample = ([[0,2]])
y_sample = clf.predict(X_sample)
neighbors = clf.kneighbors(X_sample,return_distance= False)

(第五步)把邻近点画出来

python 复制代码
plt.figure(figsize = (5,3),dpi = 144)
c = np.array(centers)
# 样本
plt.scatter(X[:,0],X[:,1],c = y,s=10,cmap='cool')
# 中心点
plt.scatter(c[:,0],c[:,1],s = 50,marker='*',c = 'orange')
# 需要被预测的点
plt.scatter(X_sample[0][0],X_sample[0][1],marker='^',s = 100,cmap = 'cool')

# 把距离被预测点最近的五个点和被预测的点连成线
for i in neighbors[0]:
    plt.plot([X[i][0],X_sample[0][0]],[X[i][1],X_sample[0][1]],'k--',linewidth = 0.6)
    
plt.show()

K近邻回归

KNN算法用于回归的思路是,找出一个样本的K个最近邻居,讲这些 邻居的属性的平均值赋给该样本,就可以得到该样本的值,下面将按步骤进行回归:

(第一步)生成数据集

使用余弦曲线,并且添加一些噪声:

python 复制代码
import numpy as np
np.random.seed(30)
X = 5*np.random.rand(42,1)
y = np.cos(X).ravel()
y += 0.2*np.random.rand(42)-0.1

ravel起到将多维数组压扁成为一维数组的作用

(第二步)训练模型

python 复制代码
from sklearn.neighbors import KNeighborsRegressor
k = 5
knn = KNeighborsRegressor(k)
knn.fit(X,y)

(第三步)生成足够密集的点进行预测

python 复制代码
T = np.linspace(0,5,500)[:,np.newaxis]
y_pred = knn.predict(T)

(第四步)把预测结果画成图

python 复制代码
import matplotlib.pyplot as plt
plt.figure(figsize=(5,3),dpi =144)
# 这是训练集的样本点
plt.scatter(X,y,c='g',label="Train data",s =10)

# zh
plt.plot(T,y_pred,c='k',label='Test prediction',lw =2)

# plt.axis('tight')的主要作用是自动调整坐标轴的范围,使得坐标轴紧密地包围数据点或者图形元素。
plt.axis('tight')

plt.title(f"KNN Reg(k ={k})")
# 显示label
plt.legend()
plt.show()

KNN算法参数与优缺点

  1. KNeighbors有两个参数一个是邻居的数量,一个是计算距离的方式。一般邻居在3-5个有较好的效果,K越大越容易掺杂进噪声,K太小会过拟合,默认使用欧式距离,他在许多情况下比较好。
  2. 他的却带你在于如果特征空间有很多0,数据集比较稀疏的情况下分类效果不太好,还有就是数据集大的时候时间比较长。
  3. 改进的常规方法目前有两个,针对不同的邻居使用不同的距离权重,另一个是选一定半径内距离最近K个的邻居。
相关推荐
KingDol_MIni2 小时前
Transformer-LSTM混合模型在时序回归中的完整流程研究
回归·lstm·transformer
极客智谷4 小时前
Spring AI 系列——使用大模型对文本内容分类归纳并标签化输出
人工智能·spring·分类
奋斗者1号5 小时前
神经网络中之多类别分类:从基础到高级应用
大数据·神经网络·分类
码记大虾7 小时前
机器学习:支持向量机 二分类的基本思想
机器学习·支持向量机·分类
白杆杆红伞伞9 小时前
02_线性模型(回归线性模型)
人工智能·数据挖掘·回归
海森大数据1 天前
人工智能可信度新突破:MIT改进共形分类助力高风险医学诊断
大数据·人工智能·分类
槑辉_2 天前
【se-res模块学习】结合CIFAR-10分类任务学习
图像处理·人工智能·pytorch·深度学习·机器学习·分类
KingDol_MIni2 天前
transformer➕lstm训练回归模型
回归·lstm·transformer
A林玖2 天前
【机器学习】Logistic 回归
人工智能·机器学习·回归
QQ676580083 天前
PyTorch和torchvision为例,如何使用预训练的ResNet模型来训练水稻虫害分类数据集 14类 从数据准备到模型训练、评估全流程
人工智能·pytorch·分类