机器学习-基于KNN算法手动实现kd树

目录

一、概括

二、KD树的构建流程

1.循环选轴

2.选择分裂点

三、kd树的查询

1.输入我们要搜索的点

2.递归向下遍历:

3.记录最近点

4.回溯父节点:

四、KD树的优化与变种:

五、KD树代码:


上一章我们将了机器学习-手搓KNN算法,这一章我们加上kd树对它进行优化,下面先来讲讲kd树。

KD 树(K-Dimensional Tree)是一种高效的K 维空间数据索引结构,主要用于最近邻搜索和范围搜索。以下从原理、构建、查询、优化等方面详细讲解:

一、概括

KD树通过递归划分k维空间,将数据点组织成二叉树结构:每一个节点代表一个k维超矩形空,比如在二维空间中,就是一个矩形包围一个点,三维就是一个体来包围一个点。然后使用二叉树将这些点连接起来,父节点选择一个维度作为分裂轴,用该维度的中位数将区域划分维左子树(小于分裂轴的点)和右子树(大于等于分裂轴的点)

二、KD树的构建流程

以X = [(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)]为例

1.循环选轴

先计算每个维度的方差,X的x轴数据有(2,5,9,4,8,7)方差为5.8055。x轴数据有(3,4,6,7,1,2)方差为4.4722。因为x轴的方差大于y轴,那么先选择x轴。等到x轴分裂后下一次就是y轴,如果还有别的维度那么继续循环,循环结束后又回到y轴开始下一轮的循环

2.选择分裂点

第一次在上面选择完成后的轴x上,选择该轴的中位数,数据为((2,5,9,4,8,7)那么中位数为(5,4)那么在该点上分裂,分裂后的左子树为[(2,3)],右子树为[(9,6), (4,7), (8,1), (7,2)]

第二次选择y轴:在上面的右子树中的中位数为(7,2),那么根据中位数分裂后左子树为[(4,7), (8,1)],右子树为:[(9,6)]。继续循环,循环结束后树结构为:

复制代码
      (5,4) (x轴分裂)
     /        \
(2,3)      (7,2) (y轴分裂)
             /      \
        (4,7)    (9,6)
             \
              (8,1)

三、kd树的查询

既然设计到树,那么肯定有增删改查。

1.输入我们要搜索的点

最近邻搜索的目的是找到我们要查询的点的最近的K个点,那么目标就变成了在我们的KD树中寻找到距离搜索点的最小距离的K个点。

2.递归向下遍历:

从根节点开始,根据当前分裂轴比较我们要搜索的点,如果比我们要搜索的点大就去右子树,小就去左子树。

3.记录最近点

等到第二步递归到叶子节点时,那么这个叶子节点就是距离我们要搜索的点最近的点,将这个点记录下来

4.回溯父节点:

计算我们搜索到的点到我们要搜索的点的距离,因为还要遍历另外一边的最近点,比如刚刚遍历的是左子树,那么现在要遍历右子树了,每次回溯到父节点后都要将搜索到的点与上一次搜索的最近点比较距离大小,将小的留下

示例:

以上面的例子为例:比如查找(6,3)的最近点

1.从根节点(5,4)出发,x 轴分裂,6>5,进入右子树(7,2)。

2.(7,2)是 y 轴分裂,3>2,进入右子树(9,6),记录最近点为(9,6)(距离√[(6-9)²+(3-6)²]=√18)。

3.回溯到(7,2),计算 y 轴分裂超平面距离为 | 3-2|=1 < √18,检查左子树(4,7)和(8,1)。

在左子树中,(8,1)距离为√[(6-8)²+(3-1)²]=√8,更近,更新最近点。

4.回溯到根节点(5,4),计算 x 轴分裂超平面距离为 | 6-5|=1 < √8,检查左子树(2,3),距离√[(6-2)²+(3-3)²]=4 > √8,不更新。最终最近点为 (8,1)。

2.范围搜索

目标:找到所有在k维超矩形区域内的点。

这个方法是先设置一个距离,然后递归遍历树,若当前节点的分裂轴到我们查询点的距离超过了我们设置的距离,则直接剪枝就是不去遍历这个节点以后的点了,如果这个节点在查询区域内则加入结果集,继续搜索子树

四、KD树的优化与变种:

1.BBF算法:使用有线队列优化最近邻搜索,减少回溯次数

2.Ball树:用超球体代替超矩形,更高效处理高维数据(普通KD树在维度>20时性能明显下降)

3.k-d-B树:结合KD树和B树,支持动态插入和删除

五、KD树代码:

python 复制代码
import numpy as np
from collections import deque

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


class KDNode:
    def __init__(self,point,left=None,right=None,axis=None):
        self.point=point  # 数据点[]
        self.left=left
        self.right=right
        self.axis=axis

class KDTree:
    def __init__(self,data,labels):
        self.data=np.c_[data,labels]
        self.root=self.build_tree(self.data)

    def build_tree(self,points,depth=0):
        if len(points)==0:
            return None

        k=points.shape[1]-1
        axis=depth%k

        sorted_points=points[points[:,axis].argsort()]
        median_idx=len(sorted_points)//2
        median_point=sorted_points[median_idx]

        left=self.build_tree(sorted_points[:median_idx],depth+1)
        right=self.build_tree(sorted_points[median_idx+1:],depth+1)

        return KDNode(median_point,left,right,axis)

    def query_knn(self, target, k):
        best_candidates = []  # 保存最近的k个邻居(按距离倒序存储)
        candidates = deque()  # 使用双端队列实现非递归遍历
        candidates.append((self.root, False))  # (当前节点, 是否已访问)

        while candidates:
            node, visited = candidates.pop()
            if node is None:
                continue

            if not visited:
                # 计算当前节点到目标的欧氏距离(排除标签列)
                distance = np.sqrt(np.sum((node.point[:-1] - target)  ** 2))

                # 维护长度为k的优先队列(使用负距离实现最大堆)
                if len(best_candidates) < k:
                    best_candidates.append((-distance, node.point))
                    best_candidates.sort(reverse=True)  # 按距离从大到小排序
                else:
                    if distance < -best_candidates[0][0]:
                        best_candidates.pop()  # 移除最远候选
                        best_candidates.append((-distance, node.point))
                        best_candidates.sort(reverse=True)

                # 根据切分维度决定搜索路径(类似二叉搜索树)
                axis = node.axis
                if target[axis] < node.point[axis]:
                    candidates.append((node, True))  # 标记当前节点已访问
                    candidates.append((node.left, False))  # 先搜索左子树
                else:
                    candidates.append((node, True))
                    candidates.append((node.right, False))  # 先搜索右子树
            else:
                # 回溯检查另一侧子树是否需要搜索(剪枝优化)
                axis = node.axis
                worst_dist = -best_candidates[0][0] if best_candidates else np.inf
                # 判断目标点到分割超平面的距离是否小于当前最远邻居距离
                if (len(best_candidates) < k) or \
                        (abs(target[axis] - node.point[axis]) < worst_dist):
                    if target[axis] < node.point[axis]:
                        candidates.append((node.right, False))  # 搜索右子树
                    else:
                        candidates.append((node.left, False))  # 搜索左子树

        # 返回前k个邻居的标签(按距离从近到远排序)
        return [point[-1] for (dist, point) in sorted(best_candidates, reverse=True)]


class KNNWithKDTree:
    def __init__(self, k=5):
        self.k = k  # 最近邻数量K
        self.kdtree = None  # 存储构建好的KD树

    def fit(self, X, y):
        # 构建KD树(将训练数据和标签传入)
        self.kdtree = KDTree(X, y)

    def predict(self, X_test):
        predictions = []
        for x in X_test:
            # 获取当前测试样本的K个最近邻标签
            neighbors = self.kdtree.query_knn(x, self.k)
            # 多数投票(取出现次数最多的类别)
            most_common = max(set(neighbors), key=neighbors.count)
            predictions.append(most_common)
        return np.array(predictions)


if __name__ == '__main__':
    # 加载鸢尾花数据集
    iris = load_iris()
    X, y = iris.data, iris.target

    # 数据标准化(消除量纲影响)
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    # 划分训练集和测试集(70%训练,30%测试)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # 初始化KNN分类器(K=5)
    knn = KNNWithKDTree(k=5)
    knn.fit(X_train, y_train)  # 训练模型(构建KD树)

    # 预测测试集结果
    y_pred = knn.predict(X_test)

    # 计算准确率
    accuracy = np.sum(y_pred == y_test) / len(y_test)
    print(f"准确率: {accuracy:.4f}")  # 输出如:准确率: 0.9778
相关推荐
救救孩子把12 分钟前
集成开发环境革新:IntelliJ IDEA与Cursor AI的智能演进
java·人工智能·intellij-idea
jndingxin18 分钟前
OpenCV图像拼接(6)图像拼接模块的用于创建权重图函数createWeightMap()
人工智能·opencv·计算机视觉
九亿AI算法优化工作室&1 小时前
SA模拟退火算法优化高斯回归回归预测matlab代码
人工智能·python·算法·随机森林·matlab·数据挖掘·模拟退火算法
么耶咩_5151 小时前
排序复习_代码纯享
数据结构·算法
Blossom.1181 小时前
基于Python的机器学习入门指南
开发语言·人工智能·经验分享·python·其他·机器学习·个人开发
藍海琴泉2 小时前
蓝桥杯算法精讲:二分查找实战与变种解析
python·算法
默 语2 小时前
10分钟打造专属AI助手!ToDesk云电脑/顺网云/海马云操作DeepSeek哪家强?
人工智能·电脑·todesk
大刀爱敲代码3 小时前
基础算法01——二分查找(Binary Search)
java·算法
Donvink4 小时前
【Dive Into Stable Diffusion v3.5】2:Stable Diffusion v3.5原理介绍
人工智能·深度学习·语言模型·stable diffusion·aigc·transformer
宇灵梦4 小时前
大模型金融企业场景落地应用
人工智能