Finding nearest neighbors using K-D Tree

Given N points {Pi} on a 1-D straight line, with absolute positions {xi} along the line, which are the k nearest points to a randomly selected test point Pj on the line? We compute the absolute distance of Pj from every other point, i.e. |xj – xi| for all i ≠ j, sort the points in increasing order of distance and take the first k. We do not actually need to store and sort all the distances: a max-heap of size k is enough, as discussed later.

The same logic extends to points in two dimensions with co-ordinates {(xi, yi)}. Instead of distances in one dimension, we now compute Euclidean distances in two dimensions, i.e. $D(P_j, P_i) = \sqrt{(x_j - x_i)^2 + (y_j - y_i)^2}$. For points in d dimensions, we compute the d-dimensional distance of the test point Pj from every other point Pi:

$D(P_j, P_i) = \sqrt{(x_{1,j} - x_{1,i})^2 + (x_{2,j} - x_{2,i})^2 + \dots + (x_{d,j} - x_{d,i})^2}$, where $x_{m,j}$ denotes the position of point Pj along the mth dimension

and then either sort or use a max-heap to get the k points nearest to Pj. Each distance computation takes O(d) time and there are N points, so all the distance computations together cost O(N*d). Sorting the distances adds a further O(N*log N), compared with O(N*log k) when keeping a max-heap of size k. Thus for each test point Pj the total run-time complexity of finding the k nearest neighbours comes to O(N*d + N*log k).
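
The brute-force procedure above can be sketched as follows. This is only an illustration under a few assumptions: the names Point, squaredDistance and bruteForceKnn are hypothetical, the query point is passed separately instead of being one of the N points, and squared distances are used because they order the points the same way as the true Euclidean distances.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

using Point = std::vector<double>;

// Squared Euclidean distance between two points of equal dimension, O(d).
double squaredDistance(const Point& a, const Point& b) {
    assert(a.size() == b.size());
    double sum = 0.0;
    for (std::size_t m = 0; m < a.size(); ++m) {
        double diff = a[m] - b[m];
        sum += diff * diff;
    }
    return sum;
}

// Returns the indices of the k points nearest to `query`, nearest first.
std::vector<std::size_t> bruteForceKnn(const std::vector<Point>& points,
                                       const Point& query, std::size_t k) {
    // Max-heap keyed on distance: the root is the farthest current candidate.
    std::priority_queue<std::pair<double, std::size_t>> heap;
    for (std::size_t i = 0; i < points.size(); ++i) {
        double dist = squaredDistance(points[i], query);
        if (heap.size() < k) {
            heap.emplace(dist, i);                 // O(log k)
        } else if (k > 0 && dist < heap.top().first) {
            heap.pop();                            // drop the farthest candidate
            heap.emplace(dist, i);
        }
    }
    // Popping yields farthest first, so reverse at the end.
    std::vector<std::size_t> result;
    while (!heap.empty()) {
        result.push_back(heap.top().second);
        heap.pop();
    }
    std::reverse(result.begin(), result.end());
    return result;
}
```

The heap never grows beyond k entries, which is where the O(N*log k) term comes from.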

This kind of problem also appears frequently in classification. For example in text classification, given a document to classify, we find the “distances” of the document to the other documents in the training set. Then, for a chosen value of k, we select the k nearest documents and assign our test document to the class that holds the majority among these k documents. The concept of distance comes from the fact that a document can be considered a vector in d-dimensional space, where ‘d’ is the number of features present in the document set. In the simplest representation, the entry for document D and feature F is 1 if feature F is present in document D, else 0. Other approaches weight the entries using Term Frequencies and Inverse Document Frequencies.
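
The voting step can be sketched as below, assuming the class labels of the k nearest documents have already been collected. The function name majorityClass is hypothetical, and resolving ties in favour of the label that sorts first is an arbitrary choice made for this sketch.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Returns the label that occurs most often among the k nearest neighbours.
// Ties go to the label that sorts first (arbitrary choice for this sketch).
std::string majorityClass(const std::vector<std::string>& neighbourLabels) {
    assert(!neighbourLabels.empty());
    std::map<std::string, int> votes;
    for (const std::string& label : neighbourLabels) {
        ++votes[label];
    }
    std::string best;
    int bestCount = 0;
    for (const auto& entry : votes) {
        if (entry.second > bestCount) {
            best = entry.first;
            bestCount = entry.second;
        }
    }
    return best;
}
```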

To understand the K-D tree, let's assume that we only need to find the single nearest point to a given test point. In our initial problem, where all the points lie on a 1-D straight line, we randomly select one of the points and make it the root of a BST. All points to the left of this point on the number line go to the left sub-tree of the BST and all points to the right go to the right sub-tree. To find the nearest point to the test point Pj with position xj, we compare xj with the root. If the root is less than xj, we search for the nearest point in the right sub-tree, since any point in the left sub-tree is further away than the root itself. We recursively traverse the right sub-tree until we cannot traverse further, in which case we return the node where we stopped or one of its ancestors on the search path (typically its parent), whichever is “nearer”. Similar logic applies if xj is less than the root, in which case we search the left sub-tree of the root. Thus the search for the nearest neighbour reduces to a search in a BST when the points can be represented as nodes of a BST.
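
A minimal sketch of this 1-D search is given below. The names BstNode, insert and nearestOnLine are hypothetical, and the points are inserted in the order given instead of picking the root at random. The search keeps the nearest value seen anywhere on the root-to-leaf path, which covers both the node where the search stops and its ancestors.

```cpp
#include <cassert>
#include <cmath>
#include <memory>

// Hypothetical BST node for points on a line.
struct BstNode {
    double x;
    std::unique_ptr<BstNode> left, right;
    explicit BstNode(double value) : x(value) {}
};

// Smaller values go to the left sub-tree, equal or larger values to the right.
void insert(std::unique_ptr<BstNode>& root, double x) {
    if (!root) {
        root.reset(new BstNode(x));
        return;
    }
    insert(x < root->x ? root->left : root->right, x);
}

// Walks one root-to-leaf path and returns the nearest position seen on it.
double nearestOnLine(const BstNode* root, double xj) {
    assert(root != nullptr);
    double best = root->x;
    for (const BstNode* node = root; node != nullptr;) {
        if (std::fabs(node->x - xj) < std::fabs(best - xj)) {
            best = node->x;
        }
        node = (xj < node->x) ? node->left.get() : node->right.get();
    }
    return best;
}
```

Only one path is followed, so the cost is proportional to the height of the tree.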


But what if the points are two-dimensional or multidimensional? An ordinary BST compares keys along a single dimension only, so how do we split the tree at a node with multidimensional values? For two-dimensional points, we randomly select a point (x0, y0) as the root and choose the x-axis as the splitting axis: all other points with xi < x0 go to the left of the root and all other points with xi >= x0 go to the right of the root.


Then at the next level, i.e. in the left sub-tree and the right sub-tree, we repeat the above step: we choose a random point (x1, y1) in the left sub-tree, but now split on the y-axis instead of the x-axis, i.e. all other points in the left sub-tree with yi < y1 go to the left of (x1, y1) and all points with yi >= y1 go to the right of (x1, y1). The same step is followed for the right sub-tree. We alternate between splitting on the x-axis and the y-axis at each level of the two-dimensional BST. For d dimensions, at each level L of the tree we select the axis L%d (assuming that the axes are represented as indices, i.e. {0, 1, 2, …, d-1}) as the splitting axis. The resulting structure is known as a K-D Tree.
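
The construction can be sketched as below. As a simplifying assumption, points are inserted one at a time in the order given (so the first point becomes the root) instead of choosing a random point at each level; if the input order is random the effect is similar. The names KdNode, insert and buildKdTree are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical K-D Tree node with owning child pointers.
struct KdNode {
    std::vector<double> values;   // position along each axis
    int splittingAxis;            // axis used for the split at this node
    std::unique_ptr<KdNode> left, right;
    KdNode(std::vector<double> v, int axis)
        : values(std::move(v)), splittingAxis(axis) {}
};

// Inserts `point` below `node`; a node created at level L splits on axis L % d.
void insert(std::unique_ptr<KdNode>& node, const std::vector<double>& point,
            int level = 0) {
    assert(!point.empty());
    if (!node) {
        int axis = level % static_cast<int>(point.size());
        node.reset(new KdNode(point, axis));
        return;
    }
    int axis = node->splittingAxis;
    // Smaller coordinates go left; equal or larger coordinates go right.
    insert(point[axis] < node->values[axis] ? node->left : node->right,
           point, level + 1);
}

std::unique_ptr<KdNode> buildKdTree(const std::vector<std::vector<double>>& points) {
    std::unique_ptr<KdNode> root;
    for (const auto& p : points) {
        insert(root, p);
    }
    return root;
}
```

A tree built this way is not guaranteed to be balanced; its shape depends on the order of the input points.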


For an exact search of a point in a K-D Tree, we recursively traverse the tree and either find the point we are looking for or conclude that the point does not exist, in which case we return false. For the nearest neighbour, we saw earlier that for points on a 1-dimensional straight line we compared the distance of the test point to the best candidate node and to its parent (more generally, its ancestors on the search path): since the test point typically lies between them, each is a potential nearest neighbour.
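
The exact search can be sketched as a single descent through the tree. The KdNode type and the function name contains are hypothetical, and the sketch assumes the tree was built with the rule described earlier (smaller coordinate to the left, equal or larger to the right).

```cpp
#include <cassert>
#include <vector>

// Hypothetical K-D Tree node with non-owning child pointers.
struct KdNode {
    std::vector<double> values;   // position along each axis
    int splittingAxis;            // axis used for the split at this node
    KdNode* left;
    KdNode* right;
};

// Returns true if a node with exactly the co-ordinates `target` exists.
bool contains(const KdNode* node, const std::vector<double>& target) {
    while (node != nullptr) {
        assert(node->values.size() == target.size());
        if (node->values == target) {
            return true;
        }
        int axis = node->splittingAxis;
        // Follow the same side the point would have been inserted on.
        node = (target[axis] < node->values[axis]) ? node->left : node->right;
    }
    return false;
}
```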


In the case of two dimensions, we consider a circle around the test point with radius equal to its distance from the best candidate node (our current best guess for the nearest neighbour). If any other node lies within this circle, then the nearest neighbour is the node within the circle that has the minimum distance from the test point. For the general case of d-dimensional space, we have a hypersphere centred at the test point and passing through the nearest neighbour estimated after recursively traversing down the tree.


But how do we prune the tree so that only potential nearest neighbours are examined? Suppose we have found our current best estimate of the nearest neighbour after the split on the ith axis. If the hypersphere centred at the test point and passing through this best estimate cuts through the splitting hyperplane of the ith axis, then a nearer point could also lie on the “other side” of the split, and we have to search the “other side” as well. Otherwise we only continue to search recursively on the “same side” as the current best estimate.

But how do we determine whether the hypersphere crosses the splitting hyperplane? In two dimensions, if the split is on the y-axis, the splitting line is given by the equation y = y0. A circle centred at (x1, y1) crosses this line only if |y1 – y0| is less than the radius of the circle. The same reasoning holds in d dimensions, where instead of a circle and a line we have a hypersphere and a hyperplane.
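
The crossing test compares the perpendicular distance from the centre of the hypersphere to the splitting hyperplane with the radius: the sphere reaches across the plane exactly when that distance is smaller than the radius. A small sketch, with the hypothetical name crossesSplit:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// True when the hypersphere of the given radius centred at `centre`
// crosses the hyperplane x[axis] = splitValue.
bool crossesSplit(const std::vector<double>& centre, double radius,
                  std::size_t axis, double splitValue) {
    assert(axis < centre.size());
    return std::fabs(centre[axis] - splitValue) < radius;
}
```

For example, a circle of radius 2 centred at (3, 4) crosses the line y = 5, since |4 – 5| = 1 is less than 2, but it does not cross the line y = 7.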

A node in the K-D Tree would have the following structure:


struct Node { Node* left; Node* right; int splittingAxis; std::vector<double> values; };


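where “values” holds the position of the Node along each axis and “splittingAxis” is the axis that was used for splitting at this Node. Following is the pseudocode for finding the nearest neighbour in a K-D Tree (“distance” is assumed to return the Euclidean distance between two Nodes and “Inf” to denote positive infinity):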

Node* NearestNeighbourSearch(Node* curr, Node* testNode) {

    // An empty sub-tree contributes no candidate.
    if (curr == nullptr) return nullptr;

    Node* currentNode = curr;
    Node* otherSide;
    Node* bestGuess1;
    Node* bestGuess2 = nullptr;

    int axis = currentNode->splittingAxis;
    double currentNodeValue = currentNode->values[axis];
    double currentDistance = distance(testNode, currentNode);

    // Signed offset of the splitting hyperplane from the test point.
    double diff = currentNodeValue - testNode->values[axis];

    if (diff < 0) { curr = currentNode->right; otherSide = currentNode->left; }
    else { curr = currentNode->left; otherSide = currentNode->right; }

    // Search the side of the split that contains the test point first.
    bestGuess1 = NearestNeighbourSearch(curr, testNode);
    double distance1 = (bestGuess1 != nullptr) ? distance(testNode, bestGuess1) : Inf;
    double distance2 = Inf;

    // Search the other side only if the hypersphere crosses the split.
    if (std::min(distance1, currentDistance) > std::abs(diff)) {
        bestGuess2 = NearestNeighbourSearch(otherSide, testNode);
        if (bestGuess2 != nullptr) distance2 = distance(testNode, bestGuess2);
    }

    if (distance1 < distance2 && distance1 < currentDistance) return bestGuess1;
    else if (distance2 < distance1 && distance2 < currentDistance) return bestGuess2;
    else return currentNode;
}

To find the k nearest neighbours, we modify the above pseudocode to maintain a max-heap (priority queue) of at most k Nodes, ordered by their distances from the test Node, instead of returning a single Node. Every Node visited during the recursion (“currentNode”) is offered to the heap. If the heap holds fewer than k Nodes, we insert the Node. Otherwise we compare “currentDistance” with the distance of the root of the max-heap, which is the largest distance among the Nodes in the heap: if “currentDistance” is smaller, we pop the root and insert “currentNode”, else we do nothing. The pruning test changes accordingly: the “other side” has to be searched if the heap is not yet full or if abs(diff) is smaller than the distance of the root of the heap. After we come out of the recursion, the max-heap contains the k nearest neighbours of the test Node.
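
A minimal sketch of the k-nearest-neighbour search with a max-heap follows. It is one possible reading of the modification, not a definitive implementation: the names KdNode, Candidate, FartherFirst, MaxHeap, euclidean, insert and kNearest are hypothetical, and the tree is built by inserting points in the order given.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

struct KdNode {
    std::vector<double> values;   // position along each axis
    int splittingAxis;            // axis used for the split at this node
    std::unique_ptr<KdNode> left, right;
    KdNode(std::vector<double> v, int axis)
        : values(std::move(v)), splittingAxis(axis) {}
};

// Heap entry: (distance to the test point, node). The farthest entry is on top.
using Candidate = std::pair<double, const KdNode*>;
struct FartherFirst {
    bool operator()(const Candidate& a, const Candidate& b) const {
        return a.first < b.first;
    }
};
using MaxHeap = std::priority_queue<Candidate, std::vector<Candidate>, FartherFirst>;

double euclidean(const std::vector<double>& a, const std::vector<double>& b) {
    assert(a.size() == b.size());
    double sum = 0.0;
    for (std::size_t m = 0; m < a.size(); ++m) {
        sum += (a[m] - b[m]) * (a[m] - b[m]);
    }
    return std::sqrt(sum);
}

// A node created at level L splits on axis L % d.
void insert(std::unique_ptr<KdNode>& node, const std::vector<double>& p, int level = 0) {
    if (!node) {
        node.reset(new KdNode(p, level % static_cast<int>(p.size())));
        return;
    }
    int axis = node->splittingAxis;
    insert(p[axis] < node->values[axis] ? node->left : node->right, p, level + 1);
}

// Collects the k nodes nearest to `test` into `heap`.
void kNearest(const KdNode* node, const std::vector<double>& test,
              std::size_t k, MaxHeap& heap) {
    if (node == nullptr || k == 0) return;

    // Offer the current node to the heap.
    double currentDistance = euclidean(test, node->values);
    if (heap.size() < k) {
        heap.emplace(currentDistance, node);
    } else if (currentDistance < heap.top().first) {
        heap.pop();
        heap.emplace(currentDistance, node);
    }

    int axis = node->splittingAxis;
    double diff = node->values[axis] - test[axis];
    const KdNode* sameSide = (diff <= 0) ? node->right.get() : node->left.get();
    const KdNode* otherSide = (diff <= 0) ? node->left.get() : node->right.get();

    kNearest(sameSide, test, k, heap);

    // The other side matters only if the heap is not full yet, or if the
    // hypersphere through the current k-th nearest candidate crosses the split.
    if (heap.size() < k || std::fabs(diff) < heap.top().first) {
        kNearest(otherSide, test, k, heap);
    }
}
```

Popping the heap afterwards yields the neighbours from farthest to nearest, so the results need to be reversed if they are wanted nearest first.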