问题描述
我正在实现 KDTree 类中查找最近邻的函数。我认为自己知道如何向下遍历到叶节点,以及如何评估当前节点。但算法中还有一部分要检查另一棵子树是否可能存在更近的点,这部分代码我完全不知道该怎么写。下面是我目前写出的方法:
/**
 * Recursive nearest-neighbor search over the array-backed KD-tree stored in
 * `tree`, restricted to the inclusive index range [left, right].
 *
 * The range is treated as a subtree whose root is its median element.
 * Elements before the median form the left subtree and elements after it
 * form the right subtree, split on coordinate `dimension`.
 * NOTE(review): assumes the tree was built with this median layout and with
 * smallerDimVal ordering — confirm against the constructor.
 *
 * @param query         point whose nearest neighbor is sought
 * @param current_best  best candidate found so far (returned if nothing in
 *                      this range is closer)
 * @param left          first index of the subtree range (inclusive)
 * @param right         last index of the subtree range (inclusive); an empty
 *                      range (left > right) returns current_best unchanged
 * @param dimension     splitting dimension at this level
 * @return the closest point found, ties broken by shouldReplace
 */
template <int Dim>
Point<Dim> KDTree<Dim>::nearestNeighbor(const Point<Dim>& query,const Point<Dim>& current_best,int left,int right,int dimension) const {
    // Empty range: this subtree holds no points.
    if (left > right) {
        return current_best;
    }

    // current_best is a const reference, so work on a local copy.
    Point<Dim> best = current_best;

    // Same index as (left + right) / 2 for non-negative bounds, but cannot overflow.
    const int median_index = left + (right - left) / 2;
    const Point<Dim>& median = tree[median_index];
    const int next_dimension = (dimension + 1) % Dim;

    // The side of the splitting plane is decided by the QUERY relative to the
    // subtree root (not by the current best).
    const bool query_goes_left = smallerDimVal(query, median, dimension);

    // 1. Search the subtree on the query's side of the plane first.
    if (query_goes_left) {
        best = nearestNeighbor(query, best, left, median_index - 1, next_dimension);
    } else {
        best = nearestNeighbor(query, best, median_index + 1, right, next_dimension);
    }

    // 2. Consider the subtree root itself.
    if (shouldReplace(query, best, median)) {
        best = median;
    }

    // 3. The other subtree can only contain a closer point if the splitting
    //    plane lies within the current best radius. Both sides of this test
    //    are squared distances: calculatedistance does not take a square root.
    const double plane_offset = query[dimension] - median[dimension];
    if (plane_offset * plane_offset <= calculatedistance(query, best)) {
        if (query_goes_left) {
            best = nearestNeighbor(query, best, median_index + 1, right, next_dimension);
        } else {
            best = nearestNeighbor(query, best, left, median_index - 1, next_dimension);
        }
    }

    return best;
}
下面是 nearestNeighbor 函数用到的类中的辅助函数,以及调用它的初始方法。我还附上了 KDTreeNode 结构供参考:
// A single node of the pointer-based KD-tree: one stored point plus links to
// its left and right children (NULL when a child is absent).
struct KDTreeNode
{
    Point<Dim> point;
    KDTreeNode *left, *right;

    // Default-constructed point, no children.
    KDTreeNode() : point(), left(NULL), right(NULL) {}

    // Leaf holding `point`. Both child links must be initialized: traversal
    // code tests `left == NULL`, which is undefined if `left` is left unset.
    KDTreeNode(const Point<Dim> &point) : point(point), left(NULL), right(NULL) {}
};
/**
 * Ordering used to place points on either side of a splitting plane.
 *
 * @param first   left-hand point of the comparison
 * @param second  right-hand point of the comparison
 * @param curDim  coordinate index to compare on
 * @return false when curDim is outside [0, Dim). Otherwise, true when first's
 *         curDim coordinate is smaller than second's. Equal coordinates are
 *         resolved with Point's operator<.
 */
template <int Dim>
bool KDTree<Dim>::smallerDimVal(const Point<Dim>& first,const Point<Dim>& second,int curDim) const
{
    const bool dimension_in_range = (curDim >= 0) && (curDim < Dim);
    if (!dimension_in_range) {
        return false;
    }

    // Ties on the splitting coordinate fall back to a whole-point comparison.
    return (first[curDim] == second[curDim])
        ? (first < second)
        : (first[curDim] < second[curDim]);
}
/**
 * Decides whether `potential` is a better nearest-neighbor candidate for
 * `target` than `currentBest`.
 *
 * Distances are compared as squared Euclidean distances and accumulated in
 * double. An int accumulator would truncate fractional coordinates and could
 * overflow for large ones, which breaks the comparison.
 *
 * @param target       the query point
 * @param currentBest  the best candidate found so far
 * @param potential    the candidate being considered
 * @return true if potential is strictly closer to target. On an exact
 *         distance tie, returns whether potential < currentBest (Point's
 *         operator<).
 */
template <int Dim>
bool KDTree<Dim>::shouldReplace(const Point<Dim>& target,const Point<Dim>& currentBest,const Point<Dim>& potential) const
{
    double target_current_distance = 0;
    double target_potential_distance = 0;
    for (int i = 0; i < Dim; i++) {
        target_current_distance += ((currentBest[i] - target[i]) * (currentBest[i] - target[i]));
        target_potential_distance += ((potential[i] - target[i]) * (potential[i] - target[i]));
    }

    // The closer point wins; ties are broken deterministically by operator<.
    if (target_potential_distance != target_current_distance) {
        return target_potential_distance < target_current_distance;
    }
    return potential < currentBest;
}
/**
 * Squared Euclidean distance between two points.
 *
 * NOTE: no square root is taken. The result is the sum of squared coordinate
 * differences, so compare it only against other squared quantities.
 *
 * @param first   one endpoint
 * @param second  the other endpoint
 * @return sum over all Dim coordinates of (second[k] - first[k])^2
 */
template <int Dim>
double KDTree<Dim>::calculatedistance(const Point<Dim>& first,const Point<Dim>& second) const {
    double sum_of_squares = 0;
    int axis = 0;
    while (axis < Dim) {
        sum_of_squares += (second[axis] - first[axis]) * (second[axis] - first[axis]);
        ++axis;
    }
    return sum_of_squares;
}
/**
 * Public entry point: finds the point stored in the tree that is closest to
 * `query`.
 *
 * @param query  the point to search around
 * @return the nearest stored point, or a default-constructed Point<Dim> if
 *         the tree is empty
 */
template <int Dim>
Point<Dim> KDTree<Dim>::findNearestNeighbor(const Point<Dim>& query) const
{
    // With no points, tree.size() - 1 would wrap around (unsigned) and index
    // far out of bounds, so return early.
    if (tree.size() == 0) {
        return Point<Dim>();
    }

    const int last_index = static_cast<int>(tree.size()) - 1;
    const int median_index = last_index / 2;

    // Seed the search with the root (median) point and search the whole index
    // range [0, last_index], starting with dimension 0.
    return nearestNeighbor(query, tree[median_index], 0, last_index, 0);
}
解决方法
编辑:问题已解决。我同时修改了方法签名,改为直接在 KDTreeNode 指针上递归(参数为查询点、当前维度和子树根节点)。
/**
 * Recursive nearest-neighbor search over the pointer-based KD-tree.
 *
 * @param query      point whose nearest neighbor is sought
 * @param dimension  splitting dimension used at `subroot`
 * @param subroot    root of the subtree to search; may be NULL
 * @return the node in this subtree whose point is nearest to query (ties
 *         broken by shouldReplace), or NULL if subroot is NULL
 */
template<int Dim>
typename KDTree<Dim>::KDTreeNode* KDTree<Dim>::nearestNeighbor(const Point<Dim>& query,int dimension,KDTreeNode* subroot) const {
    // Empty subtree: nothing to return.
    if (subroot == NULL) {
        return NULL;
    }
    // Exact match: distance 0 cannot be beaten.
    if (query == subroot->point) {
        return subroot;
    }
    // Leaf: the only candidate is the node itself.
    if (subroot->left == NULL && subroot->right == NULL) {
        return subroot;
    }

    const int next_dimension = (dimension + 1) % Dim;

    // Visit the child on the query's side of the splitting plane first. If that
    // child is missing, descend into the other one instead; the "far" side is
    // then empty.
    const bool query_on_left = smallerDimVal(query, subroot->point, dimension);
    KDTreeNode* near_child = query_on_left ? subroot->left : subroot->right;
    KDTreeNode* far_child = query_on_left ? subroot->right : subroot->left;
    if (near_child == NULL) {
        near_child = far_child;
        far_child = NULL;
    }

    KDTreeNode* nearest_node = nearestNeighbor(query, next_dimension, near_child);

    // The subtree root itself may be closer than anything found below it.
    if (shouldReplace(query, nearest_node->point, subroot->point)) {
        nearest_node = subroot;
    }

    // The far subtree can only hold a closer point if the splitting plane is
    // within the current best radius.
    // NOTE(review): this compares against a SQUARED plane offset, so it assumes
    // calculateDistance returns a squared distance — confirm.
    if (far_child != NULL) {
        const double radius = calculateDistance(query, nearest_node->point);
        const double plane_offset = subroot->point[dimension] - query[dimension];
        if (radius >= plane_offset * plane_offset) {
            KDTreeNode* candidate = nearestNeighbor(query, next_dimension, far_child);
            if (shouldReplace(query, nearest_node->point, candidate->point)) {
                nearest_node = candidate;
            }
        }
    }

    return nearest_node;
}