1. kdtree概念
kd树(k-dimensional树的简称),是一种分割k维数据空间的数据结构,主要应用于多维空间关键数据的搜索,如范围搜索和最近邻搜索。
如下图所示,在既定的分割维度上,每一个根节点的值均大于其左子树,并小于其右子树。这样的二叉树,对于搜索某个点的最临近点或k近邻点,是十分高效快速的。
2. 建立kdtree
建立kdtree,主要有两步操作:选择合适的分割维度,选择中值节点作为分割节点。分割维度的选择遵循的原则是,选择范围最大的纬度,也即是方差最大的纬度作为分割维度;分割节点的选择原则是,将这一维度的数据进行排序,选择正中间的节点作为分割节点,确保节点左边的点的维度值小于节点的维度值,节点右边的点的维度值大于节点的维度值。
建立kdtree可遵循以下步骤:
1) 建立一维数组,存储每一个点的索引,并进行随机打乱。
2) 定义合适的kdtree函数定义,方便进行递归建树。
3) 编写分割维度函数
4) 编写选择分割节点函数
5) kdtree函数功能实现:选择分割维度,选择分割节点,将节点左边的数据进行递归建立左子树,将节点右边的数据进行递归建立右子树
下面通过实际代码,讲解kdtree建立的过程:
1)数据及索引的存储定义
无论是数据还是索引均存储在一维数组中,通过二维指针数组来索引,用一个指针数组来存储每一维数据的起始地址,用另一个指针数组来存储每一类索引的起始位置,比如分割维度、父节点、左子树、右子树
/* * dataPtr一维数组表示多维数组 * 数据排布方式:{[x1, x2, x3……], [y1, y2, y3……], [z1, z2, z3……], ……} */ /* * 所有数据存储在一维数组dataPtr里,data分别是x/y/z等数据的起始地址 * 因此,建树及knn只需传递数据的索引编号即可 */ float **data; float *dataPtr; int **tree; // 4 * n :分割维度、父节点、左子树、右子树 int *treePtr; // 使用一维数据表示二维数组,存储建立的kdtree索引
对定义的数组进行初始化操作:
1 int ZtKDTree::setSize(int dimension, unsigned int sz) 2 { 3 nDimension = dimension; // 数据的维度 4 treeSize = sz; // 数据的总数 5 6 if (nDimension > 0 && treeSize > 0) 7 { 8 offset = new double[nDimension]; 9 10 tree = new int *[4]; 11 treePtr = new int[4 * treeSize]; 12 for (int i = 0; i < 4; i++) 13 { 14 tree[i] = treePtr + i * treeSize; 15 } 16 17 data = new float *[nDimension]; 18 dataPtr = new float[nDimension * sz]; 19 for (int i = 0; i < nDimension; i++) 20 { 21 data[i] = dataPtr + i * treeSize; 22 } 23 } 24 25 return 0; 26 }
2) kdtree建立准备,建立一维数组存储数据索引,定义建树函数
使用一维数组存储每一个数据的索引,并进行随机打乱,建树过程中,可以通过索引来访问数据,并且不会打乱原来数据的顺序,快速排序等操作也不必操作数据,只需操作索引即可
1 int buildTree() 2 { 3 std::vector<int> vtr(treeSize); 4 5 for (int i = 0; i < treeSize; i++) 6 { 7 vtr[i] = i; 8 } 9 10 std::random_shuffle(vtr.begin(), vtr.end()); 11 12 treeRoot = buildTree(&vtr[0], treeSize, -1); // 根节点的父节点是-1 13 14 return treeRoot; 15 } 16 17 // 建立kdtree函数 18 int buildTree(int *indices, int count, int parent)
3)分割维度函数编写
分割维度的选择至关重要,选择合适的维度,可提高建树效率及搜索效率。计算当前空间的所有数据每一维度的方差,选择方差最大的维度作为分割维度,并顺便传出维度均值,以用于节点选择函数。
1 int chooseSplitDimension(int *ids, int sz, float &key) 2 { 3 int split = 0; 4 5 float *var = new float[nDimension]; 6 float *mean = new float[nDimension]; 7 8 int cnt = std::min((int)SAMPLE_MEAN, sz);/* cnt = sz;*/ 9 double rt = 1.0 / cnt; 10 11 for (int i = 0; i < nDimension; i++) 12 { 13 double sum1 = 0, sum2 = 0; 14 for (int j = 0; j < cnt; j++) 15 { 16 sum1 += rt * data[i][ids[j]] * data[i][ids[j]]; 17 sum2 += rt * data[i][ids[j]]; 18 } 19 var[i] = sum1 - sum2 * sum2; 20 mean[i] = sum2; 21 } 22 23 double max = 0; 24 25 for (int i = 0; i < nDimension; i++) 26 { 27 if (var[i] > max) 28 { 29 key = mean[i]; 30 max = var[i]; 31 split = i; 32 } 33 } 34 35 delete[] var; 36 delete[] mean; 37 38 return split; 39 }