k-d tree即k-dimensionaltree,常用来作空间划分及近邻搜索,是二叉空间划分树的一个特例。通常,对于维度为k,数据点数为N的数据集,k-d tree适用于N>>2^k的情形,kd树是基于欧式距离度量的。
k-d树是每个节点都为k维点的二叉树。所有非叶子节点可以视作用一个超平面把空间分区成两个半空间。节点左边的子树代表在超平面左边的点,节点右边的子树代表在超平面右边的点。选择超平面的方法如下:每个节点都与k维中垂直于超平面的那一维有关。因此,如果选择按照x轴划分,所有x值小于指定值的节点都会出现在左子树,所有x值大于指定值的节点都会出现在右子树。这样,超平面可以用该x值来确定,其法线为x轴的单位向量。
scikit-learning代码演示如下
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from sklearn.neighbors import KDTree
from sklearn.datasets.samples_generator import make_blobs
#生成类似聚类的数据
center = [[1, 1], [-1, 1], [1, -1], [-1,-1]]
x, y = make_blobs(n_samples = 75, centers=center, random_state=1, cluster_std=0.6)
#print(x)
#构造kd树
tree = KDTree(x)
point = [x[8]]
# kNN 找到point的K临近
distance1, index1 = tree.query(point, k=6, return_distance=True)
print(index1)
print(distance1)
# 找到point的半径临近
index2, distance2 = tree.query_radius(point, r=0.6, return_distance=True)
print(index2)
print(distance2)
#数据可视化
fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(111)
plt.title("k-d tree")
plt.xlabel("x", size=17)
plt.ylabel("y", size=17)
#所有随机点
plt.scatter(x[:, 0], x[:, 1],c=y, s=20)
#以point为圆心,半径为0.6画圆
cir = Circle(point[0], 0.6, color='r', fill=False)
for index in index1[0]:
#将point和他的k临近连线
plt.plot([point[0][0], x[index][0]], [point[0][1], x[index][1]], 'k--', linewidth=1.8)
ax.add_patch(cir)
plt.show()
输出:
index1: [[ 8 48 4 0 32 36]]
distance1 :[[0. 0.16666319 0.26908474 0.30946747 0.34290224 0.35278116]]
index2:[array([ 0, 4, 8, 10, 32, 36, 48, 57], dtype=int64)]
distance2 :[array([0.30946747, 0.26908474, 0. , 0.49961602, 0.34290224,
0.35278116, 0.16666319, 0.43102397])]