从零构建 KNN 分类: sklearn 与自定义实现对比

文章目录

    • 理论基础
      • [KNN 算法原理](#KNN 算法原理)
      • [KNN 距离度量](#KNN 距离度量)
      • 损失函数
    • 项目结构
      • 代码实现
        • [1. `my_sklearn/neighbors/knn.py` --- 自定义KNN实现](#1. my_sklearn/neighbors/knn.py — 自定义KNN实现)
        • [2. `my_sklearn/metrics/accuracy.py` --- 自定义准确率计算](#2. my_sklearn/metrics/accuracy.py — 自定义准确率计算)
        • [3. `my_sklearn/neighbors/init.py` --- 导出 `MyKNN`](#3. my_sklearn/neighbors/__init__.py — 导出 MyKNN)
        • [4. `my_sklearn/metrics/init.py` --- 导出准确率函数](#4. my_sklearn/metrics/__init__.py — 导出准确率函数)
        • [5. `my_sklearn/init.py` --- 主包初始化](#5. my_sklearn/__init__.py — 主包初始化)
        • [6. `main.py` --- 主程序文件,包含数据处理和模型对比](#6. main.py — 主程序文件,包含数据处理和模型对比)
      • 项目使用说明
      • 总结

K近邻算法(K-Nearest Neighbors,简称 KNN)是一种基于实例的学习方法,它通过计算样本与其最近邻的距离来进行分类或回归。KNN 的主要特点是简单有效,特别适用于小规模数据集。

本文将详细介绍 KNN 算法的工作原理、如何计算距离、如何根据最近邻投票进行分类,并通过自定义实现与 sklearn 的 KNN 模型进行对比。


理论基础

KNN 算法原理

KNN 算法的核心思想非常简单,适用于分类和回归任务:

  1. 分类任务:给定一个样本,计算该样本与训练集中所有样本的距离,然后选择最近的 k 个样本,最终通过这 k 个邻居的标签来决定该样本的类别。最常用的做法是采用多数投票法,即选择 k 个最近邻中出现次数最多的标签作为该样本的预测标签。

  2. 回归任务:给定一个样本,计算该样本与训练集所有样本的距离,然后选择最近的 k 个样本,最终通过这 k 个邻居的标签(通常是数值)来决定该样本的预测值。最常用的做法是对这些邻居标签取平均。


KNN 距离度量

KNN 算法的核心操作之一是计算样本之间的距离。常见的距离度量方法有:

  • 欧几里得距离:最常用的距离度量方式,适用于数值型数据。

    d ( x , y ) = ∑ i = 1 n ( x i − y i ) 2 d(x, y) = \sqrt{\sum_{i=1}^n (x_i - y_i)^2} d(x,y)=i=1∑n(xi−yi)2

  • 曼哈顿距离:适用于数据点在格状结构中,计算方式是各坐标差值的绝对值之和。

    d ( x , y ) = ∑ i = 1 n ∣ x i − y i ∣ d(x, y) = \sum_{i=1}^n |x_i - y_i| d(x,y)=i=1∑n∣xi−yi∣

  • 闵可夫斯基距离:欧几里得距离和曼哈顿距离的推广,适用于不同的距离度量。

    d ( x , y ) = ( ∑ i = 1 n ∣ x i − y i ∣ p ) 1 / p d(x, y) = \left( \sum_{i=1}^n |x_i - y_i|^p \right)^{1/p} d(x,y)=(i=1∑n∣xi−yi∣p)1/p


损失函数

与许多其他机器学习算法不同,KNN 没有明确的损失函数。KNN 主要通过投票或计算邻居的均值来进行预测。然而,在分类问题中,我们可以使用 准确率 作为评估标准。准确率的计算方法如下:

Accuracy = Correct Predictions Total Predictions \text{Accuracy} = \frac{\text{Correct Predictions}}{\text{Total Predictions}} Accuracy=Total PredictionsCorrect Predictions


项目结构

project_root/
├── my_sklearn/
│   ├── neighbors/
│   │   ├── __init__.py
│   │   └── knn.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   └── accuracy.py
│   └── __init__.py
└── main.py

代码实现

1. my_sklearn/neighbors/knn.py --- 自定义KNN实现
python 复制代码
# my_sklearn/neighbors/knn.py
# 包含 KNN 模型的实现

import numpy as np


class MyKNN:
    def __init__(self, n_neighbors=3):
        self.n_neighbors = n_neighbors
        self.X_train = None
        self.y_train = None

    def fit(self, X, y):
        self.X_train = X
        self.y_train = y

    def predict(self, X):
        predictions = []
        for x in X:
            distances = np.sqrt(np.sum((self.X_train - x) ** 2, axis=1))  # 计算测试样本与所有训练样本的欧氏距离
            k_indices = np.argsort(distances)[:self.n_neighbors]  # 获取距离最近的 k 个样本的索引
            k_nearest_labels = self.y_train[k_indices]  # 获取 k 个最近邻样本的标签
            most_common = np.bincount(k_nearest_labels).argmax()  # 统计 k 个标签中出现次数最多的类别
            predictions.append(most_common)  # 将预测结果加入预测列表
        return np.array(predictions)  # 将预测列表转换为 numpy 数组并返回

2. my_sklearn/metrics/accuracy.py --- 自定义准确率计算
python 复制代码
# my_sklearn/metrics/accuracy.py

def my_accuracy_score(y_true, y_pred):
    """
    计算准确率:正确预测的样本数 / 总样本数
    
    参数:
    y_true -- 真实标签
    y_pred -- 预测标签
    
    返回:
    accuracy -- 准确率
    """
    correct = sum(y_true == y_pred)
    total = len(y_true)
    accuracy = correct / total
    return accuracy

3. my_sklearn/neighbors/__init__.py --- 导出 MyKNN
python 复制代码
# my_sklearn/neighbors/__init__.py

from .knn import MyKNN

# 导出的类列表
__all__ = ['MyKNN']

4. my_sklearn/metrics/__init__.py --- 导出准确率函数
python 复制代码
# my_sklearn/metrics/__init__.py
# metrics 子包初始化文件

# 从 accuracy 模块导入函数
from .accuracy import my_accuracy_score

# 导出的函数列表
__all__ = ['my_accuracy_score']

5. my_sklearn/__init__.py --- 主包初始化
python 复制代码
# my_sklearn/__init__.py
# 主包初始化文件

# 导入子模块,使它们可以通过my_sklearn.xxx直接访问
from . import metrics
from . import neighbors

6. main.py --- 主程序文件,包含数据处理和模型对比
python 复制代码
# main.py

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier as SklearnKNN
from sklearn.metrics import accuracy_score  # 导入 sklearn 的准确率计算函数

# 导入自定义模块
from my_sklearn.neighbors import MyKNN  # 导入自定义 knn 模型
from my_sklearn.metrics import my_accuracy_score  # 导入自定义准确率计算函数

# 1. 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 2. 数据预处理(标准化)
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 3. 切分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 4. 使用自定义实现的 KNN
my_knn = MyKNN(n_neighbors=3)
my_knn.fit(X_train, y_train)

# 5. 使用 sklearn 实现的 KNN
sklearn_knn = SklearnKNN(n_neighbors=3)
sklearn_knn.fit(X_train, y_train)

# 6. 对比自定义模型和 sklearn 模型的预测准确度
my_pred = my_knn.predict(X_test)
sklearn_pred = sklearn_knn.predict(X_test)

# 使用自定义函数计算自定义模型的准确率
my_accuracy = my_accuracy_score(y_test, my_pred)
# 使用 sklearn 的函数计算 sklearn 模型的准确率
sklearn_accuracy = accuracy_score(y_test, sklearn_pred)

print(f"My KNN Accuracy: {my_accuracy * 100:.2f}%")
print(f"Sklearn KNN Accuracy: {sklearn_accuracy * 100:.2f}%")

项目使用说明

  1. 安装依赖

    由于此项目只使用了 numpysklearn 库,你可以通过以下命令安装它们:

    bash 复制代码
    pip install numpy scikit-learn
  2. 运行项目

    运行 main.py 文件来加载数据集、训练模型并对比自定义模型和 sklearn 版本的准确度:

    bash 复制代码
    python main.py

总结

这个项目实现了自定义的 KNN 模型,并与 sklearn 的实现进行了对比。结构上,模仿了 sklearn 的模块化方式,将功能分为 neighborsmetrics 两个子模块。KNN算法是一种简单但有效的分类算法,通过计算测试样本与训练样本之间的距离,找出最近的k个邻居,然后通过多数投票的方式确定测试样本的类别。

相关推荐
艾思科蓝 AiScholar3 小时前
【 IEEE出版 | 快速稳定EI检索 | 往届已EI检索】2025年储能及能源转换国际学术会议(ESEC 2025)
人工智能·计算机网络·自然语言处理·数据挖掘·自动化·云计算·能源
Wis4e7 小时前
数据挖掘导论——第二章:数据
人工智能·数据挖掘
汤姆和佩琦8 小时前
LLMs基础学习(一)概念、模型分类、主流开源框架介绍以及模型的预训练任务
人工智能·学习·算法·分类·数据挖掘
zhulangfly8 小时前
机器学习算法分类及应用场景全解析
算法·机器学习·分类
梦里是谁N8 小时前
以下是基于文章核心命题打造的15个标题方案,根据传播场景分类推荐
人工智能·分类·数据挖掘
lele_ne8 小时前
【深度学习】宠物品种分类Pet Breeds Classifier
人工智能·深度学习·分类
dundunmm1 天前
【数据挖掘】知识蒸馏(Knowledge Distillation, KD)
人工智能·深度学习·数据挖掘·模型·知识蒸馏·蒸馏
AI技术控1 天前
机器学习实战——音乐流派分类(主页有源码)
人工智能·机器学习·分类
L_pyu1 天前
pytorch实现cifar10多分类总结
人工智能·pytorch·分类