sklearn 笔记 BallTree/KD Tree

由NearestNeighbors类包装

1 主要使用方法

python 复制代码
sklearn.neighbors.BallTree(X, leaf_size=40, metric='minkowski', **kwargs)

|-----------|----------------------------------------------|
| X | 数据集中的点数 |
| leaf_size | 改变 leaf_size 不会影响查询的结果,但可以显著影响查询的速度和构建树所需的内存 |
| metric | 用于距离计算的度量。默认为 "minkowski" |

2 主要方法

2.1 get_arrays

python 复制代码
import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((10, 3))
tree = BallTree(X)                
tree.get_arrays()

'''
(array([[0.90651098, 0.68471698, 0.6299996 ],
        [0.82751465, 0.31739009, 0.61572299],
        [0.22778906, 0.63614041, 0.73672184],
        [0.64655758, 0.9729849 , 0.68232389],
        [0.94992886, 0.72604933, 0.45649069],
        [0.34932115, 0.95985124, 0.41451989],
        [0.45131894, 0.21650206, 0.82466273],
        [0.87047096, 0.48403116, 0.58119046],
        [0.94468825, 0.14985636, 0.12132986],
        [0.62717326, 0.12924198, 0.23928098]]),
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64),
 array([(0, 10, 1, 0.61638879)],
       dtype=[('idx_start', '<i8'), ('idx_end', '<i8'), ('is_leaf', '<i8'), ('radius', '<f8')]),
 array([[[0.68012737, 0.52767645, 0.53022429]]]))
'''
  • 返回了4个数组
    • 第一个数组:原始数据点数组

    • 第二个数组:整数数组,代表每个点的索引

    • 第三个数组:结构化数组,包含了 BallTree 的内部树结构的信息

      • idx_startidx_end:定义了存储在当前节点的点的索引范围。
      • is_leaf:表明当前节点是否是叶节点。
      • radius:当前节点中所有点到节点中心点的最大距离
    • 第四个数组:树的每个节点的中心点

2.2 get_tree_stats

获取 BallTree 的状态信息:树的剪枝次数、叶节点的数量、分裂次数

2.3 query

查询树以找到 k 个最近邻居

python 复制代码
query(X, k=1, return_distance=True, dualtree=False, breadth_first=False)

|-----------------------|------------------------------------------------------------------------------------------|
| X | 要查询的点的数组 |
| k | (int,默认为1) 要返回的最近邻居的数量 |
| return_distance | (bool,默认为True) 如果为 True,返回一个包含距离和索引的元组 (d, i); 如果为 False,只返回数组 i |
| dualtree | (bool,默认为False): 如果为 True,使用双树形式进行查询:为查询点构建一个树,并使用这对树来高效地搜索这个空间当点的数量变得很大时,这可以带来更好的性能 |
| breadth_first | (bool,默认为False) 如果为 True,则以广度优先的方式查询节点。否则,以深度优先的方式查询 |
| sort_results | (bool,默认为True) 如果为 True,则在返回时对每个点的距离和索引进行排序,使得第一列包含最近的点 |

python 复制代码
import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((100, 3))
tree = BallTree(X)                
tree.query(X[:3],k=3)
'''
(array([[0.        , 0.08335798, 0.15625817],
        [0.        , 0.06843236, 0.10825558],
        [0.        , 0.0968137 , 0.10245125]]),
 array([[ 0, 59, 88],
        [ 1, 70,  5],
        [ 2, 43, 20]], dtype=int64))
'''

2.4 query_radius

  • 进行半径查询的功能
  • 查询树,以找出在指定半径 r 内的邻居点
python 复制代码
query_radius(X, r, return_distance=False, count_only=False, sort_results=False)

|-----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| X | 要查询的点的数组 |
| r | 返回邻居的距离范围 r 可以是单个值,也可以是一个数组,形状为 x.shape[:-1],如果每个点需要不同的半径 |
| return_distance | (bool,默认为False) 如果为 True,则返回每个点的邻居距离;如果为 False,则只返回邻居 query() 方法不同, 这里设置 return_distance=True 会增加计算时间。如果 return_distance=False,并不需要显式计算所有距离 |
| count_only | (bool,默认为False) 如果为 True,则只返回距离 r 内的点的数量; 如果为 False,则返回距离 r 内所有点的索引 |
| sort_results | (bool,默认为False) 如果为 True,则在返回之前对距离和索引进行排序。如果为 False,则结果不排序 |

python 复制代码
import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((100, 3))
tree = BallTree(X)                
tree.query_radius(X[:3],r=0.3)
'''
array([array([ 0, 68, 11, 31, 46, 19, 36, 63, 16, 86, 79], dtype=int64),
       array([26, 64, 20, 94,  1,  4, 13,  3], dtype=int64),
       array([35, 50, 30, 83, 85, 18, 15, 53,  2, 96, 81], dtype=int64)],
      dtype=object)
'''

2.5 two_point_correlation

计算距离小于等于r[i]的点的数量

python 复制代码
two_point_correlation(X, r, dualtree=False)

|----------------|------------------------------------------------------|
| X | 要查询的点集 |
| r | 一维数组,包含距离值 |
| dualtree | 如果为 True,则使用双树算法;否则,使用单树算法。 对于大量数据点(N),双树算法可能有更好的扩展性 |

返回值

counts (ndarray) : counts[i] 包含距离小于或等于 r[i] 的点对数

python 复制代码
import numpy as np
from sklearn.neighbors import BallTree
X = np.random.random((100, 3))
r=np.linspace(0.1,1,5)
tree = BallTree(X)                
tree.two_point_correlation(X[:3],r=r)
#array([  4,  34,  99, 196, 263], dtype=int64)
'''
返回的第一个值:和X[0]的距离小于r[0]的数量+和X[1]的距离小于r[0]的数量+和X[2]的距离小于r[0]的数量
'''

3 KD-Tree

和Ball-Tree 一模一样

相关推荐
LeonDL1687 分钟前
【通用视觉框架】基于C#+Winform+OpencvSharp开发的视觉框架软件,全套源码,开箱即用
人工智能·c#·winform·opencvsharp·机器视觉软件框架·通用视觉框架·机器视觉框架
AI纪元故事会10 分钟前
《目标检测全解析:从R-CNN到DETR,六大经典模型深度对比与实战指南》
人工智能·yolo·目标检测·r语言·cnn
Shang1809893572634 分钟前
T41LQ 一款高性能、低功耗的系统级芯片(SoC) 适用于各种AIoT应用智能安防、智能家居方案优选T41L
人工智能·驱动开发·嵌入式硬件·fpga开发·信息与通信·信号处理·t41lq
安全不再安全42 分钟前
免杀技巧 - 早鸟注入详细学习笔记
linux·windows·笔记·学习·测试工具·web安全·网络安全
Bony-1 小时前
用于糖尿病视网膜病变图像生成的GAN
人工智能·神经网络·生成对抗网络
罗西的思考1 小时前
【Agent】 ACE(Agentic Context Engineering)源码阅读笔记---(3)关键创新
人工智能·算法
Elastic 中国社区官方博客1 小时前
通过混合搜索重排序提升多语言嵌入模型的相关性
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
猫头虎1 小时前
昆仑芯 X HAMi X 百度智能云 | 昆仑芯 P800 XPU/vXPU 双模式算力调度方案落地
人工智能·百度·开源·aigc·文心一言·gpu算力·agi
大千AI助手1 小时前
探索LoSA:动态低秩稀疏自适应——大模型高效微调的新突破
人工智能·神经网络·lora·大模型·llm·大千ai助手·稀疏微调
说私域1 小时前
“开源链动2+1模式AI智能名片S2B2C商城小程序”在拉群营销中的应用与效果
人工智能·小程序