[Python] scikit-learn - K近邻算法介绍和使用案例

什么是K近邻算法?

K近邻算法(K-Nearest Neighbors,简称KNN)是一种基于实例的学习方法,主要用于分类和回归任务。它的基本思想是:给定一个训练数据集,对于一个新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数类别就是该输入实例的类别。

思路:

  1. 计算输入实例与训练数据集中每个实例之间的距离。
  2. 对距离进行排序,找到距离最近的K个实例。
  3. 根据这K个实例的类别进行投票,得到输入实例的类别。

K近邻算法使用场景和注意事项

K近邻算法(K-Nearest Neighbors,简称KNN)是一种基于实例的学习方法,主要用于分类和回归任务。它的使用场景包括:

  1. 数据集较小的情况:当数据集较小时,KNN算法可以快速地进行训练和预测,而不需要大量的计算资源。
  2. 数据集中存在噪声的情况:由于KNN算法是基于实例的,因此它对数据集中的噪声具有一定的容忍度。
  3. 数据集中存在异常值的情况:KNN算法在处理异常值时,会根据邻近实例的类别来进行投票,从而降低了异常值对结果的影响。
  4. 数据集中存在不平衡类别的情况:KNN算法在处理不平衡类别的数据集时,可以通过调整K值来平衡各个类别之间的样本数量。

在使用KNN算法时,需要注意以下几点:

  1. 选择合适的K值:K值的选择对算法的性能有很大影响。通常情况下,可以通过交叉验证等方法来选择合适的K值。
  2. 特征选择:KNN算法对特征的数量和质量要求较高,因此需要对特征进行选择和预处理,以提高算法的性能。
  3. 距离度量:KNN算法需要计算实例之间的距离,因此需要选择合适的距离度量方法,如欧氏距离、曼哈顿距离等。
  4. 性能评估:为了确保算法的性能,需要对算法进行性能评估,如准确率等指标。

K近邻算法python实现

复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import numpy as np
from collections import Counter

def euclidean_distance(x1, x2):
    # 计算欧氏距离
    return np.sqrt(np.sum((x1 - x2) ** 2))

class KNN:
    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        self.X_train = X
        self.y_train = y

    def predict(self, X):
        y_pred = [self._predict(x) for x in X]
        return np.array(y_pred)

    def _predict(self, x):
        # 计算输入实例与训练数据集中每个实例之间的距离
        distances = [euclidean_distance(x, x_train) for x_train in self.X_train]
        # 对距离进行排序,找到距离最近的K个实例的索引
        k_indices = np.argsort(distances)[:self.k]
        # 根据这K个实例的类别进行投票,得到输入实例的类别
        k_nearest_labels = [self.y_train[i] for i in k_indices]
        most_common = Counter(k_nearest_labels).most_common(1)
        return most_common[0][0]


data = load_iris()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

knn = KNN(k=3)
knn.fit(X_train, y_train)
predictions = knn.predict(X_test)

print("Accuracy:", accuracy_score(y_test, predictions))

scikit-learn中的K近邻算法

K近邻算法用于分类任务

sklearn.neighbors.KNeighborsClassifier --- scikit-learn 1.4.0 documentation

复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

data = load_iris()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

knc = KNeighborsClassifier(n_neighbors=3)
knc.fit(X_train, y_train)
predictions = knc.predict(X_test)

print("Accuracy:", accuracy_score(y_test, predictions))

在这个示例中,我们首先从scikit-learn库中加载了iris花卉数据集,并将其划分为训练集和测试集。然后,我们创建了一个KNeighborsClassifier对象,并设置了K值为3。接下来,我们使用训练集对模型进行训练,并使用测试集进行预测。最后,我们计算了预测结果的准确度。

K近邻算法用于回归任务

sklearn.neighbors.KNeighborsRegressor --- scikit-learn 1.4.0 documentation

复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error

# 加载iris花卉数据集
data = load_iris()
X = data.data
y = data.target

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建KNeighborsRegressor对象,设置K值为3
knn = KNeighborsRegressor(n_neighbors=3)

# 使用训练集对模型进行训练
knn.fit(X_train, y_train)

# 使用测试集进行预测
y_pred = knn.predict(X_test)

# 计算预测结果的均方误差
mse = mean_squared_error(y_test, y_pred)
print("均方误差:", mse)

在这个示例中,我们首先从scikit-learn库中加载了iris花卉数据集,并将其划分为训练集和测试集。然后,我们创建了一个KNeighborsRegressor对象,并设置了K值为3。接下来,我们使用训练集对模型进行训练,并使用测试集进行预测。最后,我们计算了预测结果的均方误差。

相关推荐
木卫二号Coding2 分钟前
第七十二篇-V100-32G+WebUI+Flux.1-Schnell+Lora+文生图
开发语言·人工智能·python
墨笔之风2 分钟前
基于python 实现的小游戏
开发语言·python·pygame
多米Domi0113 分钟前
0x3f 第24天 黑马web (安了半天程序 )hot100普通数组
数据结构·python·算法·leetcode
BoBoZz194 分钟前
AnatomicalOrientation 3D人体模型及三个人体标准解剖学平面展示
python·vtk·图形渲染·图形处理
love530love5 分钟前
EPGF 新手教程 11在 PyCharm(中文版 GUI)中创建 uv 环境,并把 uv 做到“项目自包含”(工具本地化为必做环节)
ide·人工智能·python·pycharm·conda·uv·epgf
jackylzh5 分钟前
cmd或其它终端的dos命令 & events.out.tfevents文件怎么打开
python
gis_rc6 分钟前
python下shp转3dtiles
python·3d·cesium·3dtiles·数字孪生模型
廖圣平6 分钟前
直播间福袋脚本,研究json格式【一】
python
Lkygo8 分钟前
ragflow 构建本地知识库指南
人工智能·python·语言模型
TTGGGFF2 小时前
Supertonic 部署与使用全流程保姆级指南(附已部署镜像)
开发语言·python