【机器学习】K近邻算法

目录

算法引入:

KNN算法的核心思想

KNN算法的步骤

KNN常用的距离度量方法

KNN算法的优缺点

优点:

缺点:

K值的选择

KNN的C++实现

复杂度分析:


**K近邻算法(K-Nearest Neighbors, KNN)**是一种简单但非常实用的监督学习算法,主要用于分类和回归问题。KNN 基于相似性度量(如欧几里得距离)来进行预测,核心思想是给定一个样本,找到与其最接近的 K 个邻居,根据这些邻居的类别或特征对该样本进行分类或预测。

算法引入:

我们假设平面上有两类点的集合,一类属于A类,一类属于B类,A类有三个点,B类有三个点。如果我们要加入一个橙色的点,那么它是属于A类还是B类?

我们不去考虑算法的话,如果去归类A类还是B类,那么我们肯定想到的就是这个点它离哪个点最近,就属于哪一类。到这里就与我们的K近邻算法大概相同了,只不过我们要选取一个范围,在这个范围里找点进行判断,比如,我们选择它的邻域三个点,即K=3,那么这个区域里面有一个点属于A类,两个点属于B类,根据少数服从多数,我们就可以把它归于B类点。

K近邻算法有三个要素,如下图所示,第一个就是距离度量,这个距离有很多种距离,比如:欧几里得距离、曼哈顿距离、闵可夫斯基距离等。上面的例子中我们选择的欧几里得距离。第二个是K值,就是这个一个范围的的点个数,上面例子中,K=3。第三个是少数服从多数规则,上面例子中区域里面有一个点属于A类,两个点属于B类,根据少数服从多数,我们就可以把它归于B类点。


KNN算法的核心思想

1. 分类问题:

  • 给定一个未标记的数据点,通过计算该数据点与已标记的训练数据集中的每一个数据点的距离,选择距离最近的 K 个邻居。

  • 根据这 K 个邻居的类别,采用"多数投票"的方式来决定未标记数据点的类别。

2. 回归问题:

  • 计算未标记数据点的 K 个最近邻居的值,然后取这些邻居的平均值或加权平均值作为该点的预测值。

KNN算法的步骤

1. 数据准备: 准备好训练数据集,包括特征和标签(分类问题中为类别,回归问题中为数值)。
2. 选择K值: 选择邻居的数量K,一般是正整数。
3. 计算距离: 对每个未标记数据点,计算它与训练集中每一个数据点的距离(常见的距离度量方法有欧几里得距离、曼哈顿距离等)。
4. 选择K个最近邻居: 根据距离从小到大排序,选择距离最近的K个邻居。
**5. 投票或平均:**分类问题中,根据K个邻居的类别进行投票选择类别;回归问题中,计算邻居的平均值作为预测结果。


KNN常用的距离度量方法

1. 欧几里得距离:

欧几里得距离是最常用的距离度量方法,适用于连续变量的情况。

2. 曼哈顿距离:

曼哈顿距离适用于某些特定场景,尤其是当特征值的变化范围不均匀时。

3. 闵可夫斯基距离:

闵可夫斯基距离是欧几里得距离和曼哈顿距离的推广形式,其中 p 是一个参数,当 p=2 时,便是欧几里得距离,当 p=1 时便是曼哈顿距离。


KNN算法的优缺点

优点:

1. 简单易理解: KNN算法实现简单,易于理解和解释。
2. 无参数模型: KNN不需要训练过程,可以直接使用数据进行预测。
**3. 适用性广泛:**KNN可以用于分类和回归问题,并且对非线性数据有较好的适应性。

缺点:

1. 计算复杂度高: KNN算法需要对每一个测试样本都计算与所有训练样本的距离,因此在大数据集下计算开销较大。
2. 内存开销大: KNN需要存储整个训练数据集,占用较大的存储空间。
**3. 对噪声敏感:**KNN对噪声数据较为敏感,特别是在K值较小的情况下,少量噪声数据可能会对结果产生很大影响。


K值的选择

- 小K值 :K值较小时,模型会更加复杂,可能会过拟合。即使有少量噪声数据,也会对分类结果产生较大的影响。
- 大K值:K值较大时,模型会更加平滑,可能会欠拟合。K值过大会忽略数据的局部结构。

通常,K值通过交叉验证等方法来选择合适的值。


KNN的C++实现

下面是一个简单的KNN算法的C++实现,用于分类问题,采用欧几里得距离来计算邻居之间的距离。

cpp 复制代码
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;

// 定义一个点,包含特征和标签
struct Point {
    vector<double> features;
    int label;
};

// 计算欧几里得距离
double euclideanDistance(const vector<double>& a, const vector<double>& b) {
    double distance = 0.0;
    for (int i = 0; i < a.size(); i++) {
        distance += pow(a[i] - b[i], 2);
    }
    return sqrt(distance);
}

// KNN算法实现
int knn(const vector<Point>& train_data, const vector<double>& test_point, int k) {
    vector<pair<double, int>> distances; // 距离和标签的对
    
    // 计算每个训练数据点到测试点的距离
    for (const auto& point : train_data) {
        double distance = euclideanDistance(point.features, test_point);
        distances.push_back({distance, point.label});
    }

    // 按距离排序
    sort(distances.begin(), distances.end());

    // 统计前k个最近邻的类别
    vector<int> label_count(100, 0); // 假设标签在0-99之间
    for (int i = 0; i < k; i++) {
        label_count[distances[i].second]++;
    }

    // 返回出现次数最多的类别
    int max_count = 0;
    int predicted_label = -1;
    for (int i = 0; i < label_count.size(); i++) {
        if (label_count[i] > max_count) {
            max_count = label_count[i];
            predicted_label = i;
        }
    }

    return predicted_label;
}

int main() {
    int n, m, k;
    cin >> n;//训练数据的个数
    cin >> m;//测试数据的维度
    cin >> k;
    
    // 输入训练数据
    vector<Point> train_data(n);
    for (int i = 0; i < n; i++) {
        train_data[i].features.resize(m);
        for (int j = 0; j < m; j++) {
            cin >> train_data[i].features[j];
        }
        cin >> train_data[i].label;
    }

    // 输入测试点
    vector<double> test_point(m);
    for (int j = 0; j < m; j++) {
        cin >> test_point[j];
    }

    // 使用KNN进行分类
    int predicted_label = knn(train_data, test_point, k);
    cout << predicted_label << endl;

    return 0;
}

复杂度分析:

  • 时间复杂度 :对于每个测试点,KNN需要计算与所有训练点的距离,因此时间复杂度为 O(n * m),其中 n 是训练集大小,m 是特征维度。

  • 空间复杂度:主要用于存储训练数据和距离结果,空间复杂度为 O(n)。


K近邻算法是一个简单直观的非参数分类算法,适用于低维、小数据集的情况。然而,由于它的计算复杂性较高,KNN在大数据集或高维数据上的表现不佳。因此,KNN算法通常被用作基准模型或在小规模数据集上使用。

相关推荐
南宫生6 分钟前
力扣-图论-17【算法学习day.67】
java·学习·算法·leetcode·图论
不想当程序猿_18 分钟前
【蓝桥杯每日一题】求和——前缀和
算法·前缀和·蓝桥杯
IT古董28 分钟前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
落魄君子29 分钟前
GA-BP分类-遗传算法(Genetic Algorithm)和反向传播算法(Backpropagation)
算法·分类·数据挖掘
centurysee30 分钟前
【最佳实践】Anthropic:Agentic系统实践案例
人工智能
mahuifa30 分钟前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai
四口鲸鱼爱吃盐31 分钟前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
冷眼看人间恩怨35 分钟前
【Qt笔记】QDockWidget控件详解
c++·笔记·qt·qdockwidget
菜鸡中的奋斗鸡→挣扎鸡37 分钟前
滑动窗口 + 算法复习
数据结构·算法
红龙创客44 分钟前
某狐畅游24校招-C++开发岗笔试(单选题)
开发语言·c++