[mlpack-git] master, mlpack-1.0.x: R* tree split. Default to using the R* tree in allknn (4c921ed)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:53:56 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 4c921ed15eea8ce697d7b0a6c6cc1484ad611cc1
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Mon Jul 21 14:49:51 2014 +0000

    R* tree split.  Default to using the R* tree in allknn


>---------------------------------------------------------------

4c921ed15eea8ce697d7b0a6c6cc1484ad611cc1
 src/mlpack/core/tree/rectangle_tree.hpp            |   2 +-
 .../core/tree/rectangle_tree/r_star_tree_split.hpp |  10 +
 .../tree/rectangle_tree/r_star_tree_split_impl.hpp | 326 +++++++++++----------
 .../core/tree/rectangle_tree/r_tree_split.hpp      |   2 +-
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp |  14 +-
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    |  38 +--
 src/mlpack/methods/neighbor_search/allknn_main.cpp |  10 +-
 src/mlpack/tests/rectangle_tree_test.cpp           |  58 +++-
 8 files changed, 269 insertions(+), 191 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index 526b600..2270a82 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -18,7 +18,7 @@
 #include "rectangle_tree/dual_tree_traverser.hpp"
 #include "rectangle_tree/dual_tree_traverser_impl.hpp"
 #include "rectangle_tree/r_tree_split.hpp"
-//#include "rectangle_tree/r_star_tree_split.hpp"
+#include "rectangle_tree/r_star_tree_split.hpp"
 #include "rectangle_tree/r_tree_descent_heuristic.hpp"
 #include "rectangle_tree/r_star_tree_descent_heuristic.hpp"
 #include "rectangle_tree/traits.hpp"
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
index 646eba9..c7611e3 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
@@ -43,6 +43,7 @@ private:
  * Class to allow for faster sorting.
  */
 class sortStruct {
+public:
   double d;
   int n;
 };
@@ -54,6 +55,15 @@ static bool structComp(const sortStruct& s1, const sortStruct& s2) {
   return s1.d < s2.d;
 }
 
+/**
+  * Insert a node into another node.
+  */
+static void InsertNodeIntoTree(
+    RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* destTree,
+    RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* srcNode);
+
+};
+
 }; // namespace tree
 }; // namespace mlpack
 
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index b80d8c9..22fc327 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -40,16 +40,16 @@ void RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
     return;
   }
 
-  int bestOverlapIndexOnBestAxis;
-  int bestAreaIndexOnBestAxis;
+  int bestOverlapIndexOnBestAxis = 0;
+  int bestAreaIndexOnBestAxis = 0;
   bool tiedOnOverlap = false;
   int bestAxis = 0;
   double bestAxisScore = DBL_MAX;
-  for(int j = 0; j < tree->Bound().Dim(); j++) {
+  for (int j = 0; j < tree->Bound().Dim(); j++) {
     double axisScore = 0.0;
     // Since we only have points in the leaf nodes, we only need to sort once.
     std::vector<sortStruct> sorted(tree->Count());
-    for(int i = 0; i < sorted.size(); i++) {
+    for (int i = 0; i < sorted.size(); i++) {
       sorted[i].d = tree->LocalDataset().col(i)[j];
       sorted[i].n = i;
     }
@@ -57,16 +57,15 @@ void RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
     std::sort(sorted.begin(), sorted.end(), structComp);
 
     // We'll store each of the three scores for each distribution.
-    std::vector<double> areas(tree->MaxLeafSize() - 2*tree->MinLeafSize() + 2);
-    std::vector<double> margins(tree->MaxLeafSize() - 2*tree->MinLeafSize() + 2);
-    std::vector<double> overlapedAreas(tree->MaxLeafSize() - 2*tree->MinLeafSize() + 2);
-    for(int i = 0; i < areas.size(); i++) {
+    std::vector<double> areas(tree->MaxLeafSize() - 2 * tree->MinLeafSize() + 2);
+    std::vector<double> margins(tree->MaxLeafSize() - 2 * tree->MinLeafSize() + 2);
+    std::vector<double> overlapedAreas(tree->MaxLeafSize() - 2 * tree->MinLeafSize() + 2);
+    for (int i = 0; i < areas.size(); i++) {
       areas[i] = 0.0;
       margins[i] = 0.0;
       overlapedAreas[i] = 0.0;
     }
-  
-    for(int i = 0; i < areas.size(); i++) {
+    for (int i = 0; i < areas.size(); i++) {
       // The ith arrangement is obtained by placing the first tree->MinLeafSize() + i 
       // points in one rectangle and the rest in another.  Then we calculate the three
       // scores for that distribution.
@@ -77,58 +76,57 @@ void RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
       std::vector<double> minG1(maxG1.size());
       std::vector<double> maxG2(maxG1.size());
       std::vector<double> minG2(maxG1.size());
-      for(int k = 0; k < tree->Bound().Dim(); k++) {
-	minG1[k] = maxG1[k] = tree->LocalDataset().col(sorted[0].n)[k];
-	minG2[k] = maxG2[k] = tree->LocalDataset().col(sorted[sorted.size()-1])[k];
-	for(int l = 1; l < tree->Count()-1; l++) {
-          if(l < cutOff) {
-	    if(tree->LocalDataset().col(sorted[l].n)[k] < minG1[k])
-	      minG1[k] = tree->LocalDataset().col(sorted[l].n)[k];
-	    else if(tree->LocalDataset().col(sorted[l].n)[k] > maxG1[k])
-	      maxG1[k] = tree->LocalDataset().col(sorted[l].n)[k];
-	  } else {
-	    if(tree->LocalDataset().col(sorted[l].n)[k] < minG2[k])
-	      minG2[k] = tree->LocalDataset().col(sorted[l].n)[k];
-	    else if(tree->LocalDataset().col(sorted[l].n)[k] > maxG2[k])
-	      maxG2[k] = tree->LocalDataset().col(sorted[l].n)[k];
+      for (int k = 0; k < tree->Bound().Dim(); k++) {
+        minG1[k] = maxG1[k] = tree->LocalDataset().col(sorted[0].n)[k];
+        minG2[k] = maxG2[k] = tree->LocalDataset().col(sorted[sorted.size() - 1].n)[k];
+        for (int l = 1; l < tree->Count() - 1; l++) {
+          if (l < cutOff) {
+            if (tree->LocalDataset().col(sorted[l].n)[k] < minG1[k])
+              minG1[k] = tree->LocalDataset().col(sorted[l].n)[k];
+            else if (tree->LocalDataset().col(sorted[l].n)[k] > maxG1[k])
+              maxG1[k] = tree->LocalDataset().col(sorted[l].n)[k];
+          } else {
+            if (tree->LocalDataset().col(sorted[l].n)[k] < minG2[k])
+              minG2[k] = tree->LocalDataset().col(sorted[l].n)[k];
+            else if (tree->LocalDataset().col(sorted[l].n)[k] > maxG2[k])
+              maxG2[k] = tree->LocalDataset().col(sorted[l].n)[k];
           }
-	}
+        }
       }
       double area1 = 1.0, area2 = 1.0;
       double oArea = 1.0;
-      for(int k = 0; k < maxG1.size(); k++) {
-	margins[i] += maxG1[k] - minG1[k] + maxG2[k] - minG2[k];
-	area1 *= maxG1[k] - minG1[k];
-	area2 *= maxG2[k] - minG2[k];
-	oArea *= maxG1[k] < minG2[k] || maxG2[k] < minG1[k] ? 0.0 : std::min(maxG1[k], maxG2[k]) - std::max(minG1[k], minG2[k]);
+      for (int k = 0; k < maxG1.size(); k++) {
+        margins[i] += maxG1[k] - minG1[k] + maxG2[k] - minG2[k];
+        area1 *= maxG1[k] - minG1[k];
+        area2 *= maxG2[k] - minG2[k];
+        oArea *= maxG1[k] < minG2[k] || maxG2[k] < minG1[k] ? 0.0 : std::min(maxG1[k], maxG2[k]) - std::max(minG1[k], minG2[k]);
       }
       areas[i] += area1 + area2;
       overlapedAreas[i] += oArea;
       axisScore += margins[i];
     }
-    if(axisScore < bestAxisScore) {
+
+    if (axisScore < bestAxisScore) {
       bestAxisScore = axisScore;
       bestAxis = j;
       double bestOverlapIndexOnBestAxis = 0;
       double bestAreaIndexOnBestAxis = 0;
-      for(int i = 1; i < areas.size(); i++) {
-	if(overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis]) {
-	  tiedOnOverlap = false;
-	  bestAreaIndexOnBestAxis = i;
-	  bestOverlapIndexOnBestAxis = i;
-	}
-	else if(overlapedAreas[i] == overlapedAreas[bestOverlapIndexOnBestAxis]) {
-	  tiedOnOverlap = true;
-	  if(areas[i] < areas[bestAreaIndexOnBestAxis])
-	    bestAreaIndexOnBestAxis = i;
-	}
+      for (int i = 1; i < areas.size(); i++) {
+        if (overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis]) {
+          tiedOnOverlap = false;
+          bestAreaIndexOnBestAxis = i;
+          bestOverlapIndexOnBestAxis = i;
+        } else if (overlapedAreas[i] == overlapedAreas[bestOverlapIndexOnBestAxis]) {
+          tiedOnOverlap = true;
+          if (areas[i] < areas[bestAreaIndexOnBestAxis])
+            bestAreaIndexOnBestAxis = i;
+        }
       }
     }
   }
 
-
   std::vector<sortStruct> sorted(tree->Count());
-  for(int i = 0; i < sorted.size(); i++) {
+  for (int i = 0; i < sorted.size(); i++) {
     sorted[i].d = tree->LocalDataset().col(i)[bestAxis];
     sorted[i].n = i;
   }
@@ -140,19 +138,19 @@ void RStarTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
   RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType> *treeTwo = new
           RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(tree->Parent());
 
-  if(tiedOnOverlap) {
-    for(int i = 0; i < tree.Count(); i++) {
-      if(i < bestAreaIndexOnBestAxis)
-	treeOne->InsertPoint(tree->Points()[sorted[i].n]);
+  if (tiedOnOverlap) {
+    for (int i = 0; i < tree->Count(); i++) {
+      if (i < bestAreaIndexOnBestAxis + tree->MinLeafSize())
+        treeOne->InsertPoint(tree->Points()[sorted[i].n]);
       else
-	treeTwo->InsertPoint(tree->Points()[sorted[i].n]);
+        treeTwo->InsertPoint(tree->Points()[sorted[i].n]);
     }
   } else {
-    for(int i = 0; i < tree.Count(); i++) {
-      if(i < bestOverlapIndexOnBestAxis)
-	treeOne->InsertPoint(tree->Points()[sorted[i].n]);
+    for (int i = 0; i < tree->Count(); i++) {
+      if (i < bestOverlapIndexOnBestAxis + tree->MinLeafSize())
+        treeOne->InsertPoint(tree->Points()[sorted[i].n]);
       else
-	treeTwo->InsertPoint(tree->Points()[sorted[i].n]);
+        treeTwo->InsertPoint(tree->Points()[sorted[i].n]);
     }
   }
 
@@ -211,18 +209,18 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
     return true;
   }
 
-  int bestOverlapIndexOnBestAxis;
-  int bestAreaIndexOnBestAxis;
+  int bestOverlapIndexOnBestAxis = 0;
+  int bestAreaIndexOnBestAxis = 0;
   bool tiedOnOverlap = false;
   bool lowIsBest = true;
   int bestAxis = 0;
   double bestAxisScore = DBL_MAX;
-  for(int j = 0; j < tree->Bound().Dim(); j++) {
+  for (int j = 0; j < tree->Bound().Dim(); j++) {
     double axisScore = 0.0;
 
     // We'll do Bound().Lo() now and use Bound().Hi() later.
-    std::vector<sortStruct> sorted(tree->Count());
-    for(int i = 0; i < sorted.size(); i++) {
+    std::vector<sortStruct> sorted(tree->NumChildren());
+    for (int i = 0; i < sorted.size(); i++) {
       sorted[i].d = tree->Child(i)->Bound()[j].Lo();
       sorted[i].n = i;
     }
@@ -230,16 +228,16 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
     std::sort(sorted.begin(), sorted.end(), structComp);
 
     // We'll store each of the three scores for each distribution.
-    std::vector<double> areas(tree->MaxNumChildren() - 2*tree->MinNumChildren() + 2);
-    std::vector<double> margins(tree->MaxNumChildren() - 2*tree->MinNumChildren() + 2);
-    std::vector<double> overlapedAreas(tree->MaxNumChildren() - 2*tree->MinNumChildren() + 2);
-    for(int i = 0; i < areas.size(); i++) {
+    std::vector<double> areas(tree->MaxNumChildren() - 2 * tree->MinNumChildren() + 2);
+    std::vector<double> margins(tree->MaxNumChildren() - 2 * tree->MinNumChildren() + 2);
+    std::vector<double> overlapedAreas(tree->MaxNumChildren() - 2 * tree->MinNumChildren() + 2);
+    for (int i = 0; i < areas.size(); i++) {
       areas[i] = 0.0;
       margins[i] = 0.0;
       overlapedAreas[i] = 0.0;
     }
 
-    for(int i = 0; i < areas.size(); i++) {
+    for (int i = 0; i < areas.size(); i++) {
       // The ith arrangement is obtained by placing the first tree->MinNumChildren() + i 
       // points in one rectangle and the rest in another.  Then we calculate the three
       // scores for that distribution.
@@ -250,63 +248,63 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
       std::vector<double> minG1(maxG1.size());
       std::vector<double> maxG2(maxG1.size());
       std::vector<double> minG2(maxG1.size());
-      for(int k = 0; k < tree->Bound().Dim(); k++) {
-	minG1[k] = tree->Child(sorted[0].n)->Bound()[k].Lo();
+      for (int k = 0; k < tree->Bound().Dim(); k++) {
+        minG1[k] = tree->Child(sorted[0].n)->Bound()[k].Lo();
         maxG1[k] = tree->Child(sorted[0].n)->Bound()[k].Hi();
-	minG2[k] = tree->Child(sorted[sorted.size()-1])->Bound()[k].Lo();
-	maxG2[k] = tree->Child(sorted[sorted.size()-1])->Bound()[k].Hi();
-	for(int l = 1; l < tree->Count()-1; l++) {
-          if(l < cutOff) {
-	    if(tree->Child(sorted[l].n)->Bound()[k].Lo() < minG1[k])
-	      minG1[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
-	    else if(tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG1[k])
-	      maxG1[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
-	  } else {
-	    if(tree->Child(sorted[l].n)->Bound()[k].Lo() < minG2[k])
-	      minG2[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
-	    else if(tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG2[k])
-	      maxG2[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
+        minG2[k] = tree->Child(sorted[sorted.size() - 1].n)->Bound()[k].Lo();
+        maxG2[k] = tree->Child(sorted[sorted.size() - 1].n)->Bound()[k].Hi();
+        for (int l = 1; l < tree->NumChildren() - 1; l++) {
+          if (l < cutOff) {
+            if (tree->Child(sorted[l].n)->Bound()[k].Lo() < minG1[k])
+              minG1[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
+            else if (tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG1[k])
+              maxG1[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
+          } else {
+            if (tree->Child(sorted[l].n)->Bound()[k].Lo() < minG2[k])
+              minG2[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
+            else if (tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG2[k])
+              maxG2[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
           }
-	}
+        }
       }
       double area1 = 1.0, area2 = 1.0;
       double oArea = 1.0;
-      for(int k = 0; k < maxG1.size(); k++) {
-	margins[i] += maxG1[k] - minG1[k] + maxG2[k] - minG2[k];
-	area1 *= maxG1[k] - minG1[k];
-	area2 *= maxG2[k] - minG2[k];
-	oArea *= maxG1[k] < minG2[k] || maxG2[k] < minG1[k] ? 0.0 : std::min(maxG1[k], maxG2[k]) - std::max(minG1[k], minG2[k]);
+      for (int k = 0; k < maxG1.size(); k++) {
+        margins[i] += maxG1[k] - minG1[k] + maxG2[k] - minG2[k];
+        area1 *= maxG1[k] - minG1[k];
+        area2 *= maxG2[k] - minG2[k];
+        oArea *= maxG1[k] < minG2[k] || maxG2[k] < minG1[k] ? 0.0 : std::min(maxG1[k], maxG2[k]) - std::max(minG1[k], minG2[k]);
       }
       areas[i] += area1 + area2;
       overlapedAreas[i] += oArea;
       axisScore += margins[i];
     }
-    if(axisScore < bestAxisScore) {
+    if (axisScore < bestAxisScore) {
       bestAxisScore = axisScore;
       bestAxis = j;
       double bestOverlapIndexOnBestAxis = 0;
       double bestAreaIndexOnBestAxis = 0;
-      for(int i = 1; i < areas.size(); i++) {
-	if(overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis]) {
-	  tiedOnOverlap = false;
-	  bestAreaIndexOnBestAxis = i;
-	  bestOverlapIndexOnBestAxis = i;
-	}
-	else if(overlapedAreas[i] == overlapedAreas[bestOverlapIndexOnBestAxis]) {
-	  tiedOnOverlap = true;
-	  if(areas[i] < areas[bestAreaIndexOnBestAxis])
-	    bestAreaIndexOnBestAxis = i;
-	}
+      for (int i = 1; i < areas.size(); i++) {
+        if (overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis]) {
+          tiedOnOverlap = false;
+          bestAreaIndexOnBestAxis = i;
+          bestOverlapIndexOnBestAxis = i;
+        } else if (overlapedAreas[i] == overlapedAreas[bestOverlapIndexOnBestAxis]) {
+          tiedOnOverlap = true;
+          if (areas[i] < areas[bestAreaIndexOnBestAxis])
+            bestAreaIndexOnBestAxis = i;
+        }
       }
     }
   }
+
   //Now we do the same thing using Bound().Hi() and choose the best of the two.
-  for(int j = 0; j < tree->Bound().Dim(); j++) {
+  for (int j = 0; j < tree->Bound().Dim(); j++) {
     double axisScore = 0.0;
 
     // We'll do Bound().Lo() now and use Bound().Hi() later.
-    std::vector<sortStruct> sorted(tree->Count());
-    for(int i = 0; i < sorted.size(); i++) {
+    std::vector<sortStruct> sorted(tree->NumChildren());
+    for (int i = 0; i < sorted.size(); i++) {
       sorted[i].d = tree->Child(i)->Bound()[j].Hi();
       sorted[i].n = i;
     }
@@ -314,16 +312,16 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
     std::sort(sorted.begin(), sorted.end(), structComp);
 
     // We'll store each of the three scores for each distribution.
-    std::vector<double> areas(tree->MaxNumChildren() - 2*tree->MinNumChildren() + 2);
-    std::vector<double> margins(tree->MaxNumChildren() - 2*tree->MinNumChildren() + 2);
-    std::vector<double> overlapedAreas(tree->MaxNumChildren() - 2*tree->MinNumChildren() + 2);
-    for(int i = 0; i < areas.size(); i++) {
+    std::vector<double> areas(tree->MaxNumChildren() - 2 * tree->MinNumChildren() + 2);
+    std::vector<double> margins(tree->MaxNumChildren() - 2 * tree->MinNumChildren() + 2);
+    std::vector<double> overlapedAreas(tree->MaxNumChildren() - 2 * tree->MinNumChildren() + 2);
+    for (int i = 0; i < areas.size(); i++) {
       areas[i] = 0.0;
       margins[i] = 0.0;
       overlapedAreas[i] = 0.0;
     }
 
-    for(int i = 0; i < areas.size(); i++) {
+    for (int i = 0; i < areas.size(); i++) {
       // The ith arrangement is obtained by placing the first tree->MinNumChildren() + i 
       // points in one rectangle and the rest in another.  Then we calculate the three
       // scores for that distribution.
@@ -334,70 +332,66 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
       std::vector<double> minG1(maxG1.size());
       std::vector<double> maxG2(maxG1.size());
       std::vector<double> minG2(maxG1.size());
-      for(int k = 0; k < tree->Bound().Dim(); k++) {
-	minG1[k] = tree->Child(sorted[0].n)->Bound()[k].Lo();
+      for (int k = 0; k < tree->Bound().Dim(); k++) {
+        minG1[k] = tree->Child(sorted[0].n)->Bound()[k].Lo();
         maxG1[k] = tree->Child(sorted[0].n)->Bound()[k].Hi();
-	minG2[k] = tree->Child(sorted[sorted.size()-1])->Bound()[k].Lo();
-	maxG2[k] = tree->Child(sorted[sorted.size()-1])->Bound()[k].Hi();
-	for(int l = 1; l < tree->Count()-1; l++) {
-          if(l < cutOff) {
-	    if(tree->Child(sorted[l].n)->Bound()[k].Lo() < minG1[k])
-	      minG1[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
-	    else if(tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG1[k])
-	      maxG1[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
-	  } else {
-	    if(tree->Child(sorted[l].n)->Bound()[k].Lo() < minG2[k])
-	      minG2[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
-	    else if(tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG2[k])
-	      maxG2[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
+        minG2[k] = tree->Child(sorted[sorted.size() - 1].n)->Bound()[k].Lo();
+        maxG2[k] = tree->Child(sorted[sorted.size() - 1].n)->Bound()[k].Hi();
+        for (int l = 1; l < tree->NumChildren() - 1; l++) {
+          if (l < cutOff) {
+            if (tree->Child(sorted[l].n)->Bound()[k].Lo() < minG1[k])
+              minG1[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
+            else if (tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG1[k])
+              maxG1[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
+          } else {
+            if (tree->Child(sorted[l].n)->Bound()[k].Lo() < minG2[k])
+              minG2[k] = tree->Child(sorted[l].n)->Bound()[k].Lo();
+            else if (tree->Child(sorted[l].n)->Bound()[k].Hi() > maxG2[k])
+              maxG2[k] = tree->Child(sorted[l].n)->Bound()[k].Hi();
           }
-	}
+        }
       }
       double area1 = 1.0, area2 = 1.0;
       double oArea = 1.0;
-      for(int k = 0; k < maxG1.size(); k++) {
-	margins[i] += maxG1[k] - minG1[k] + maxG2[k] - minG2[k];
-	area1 *= maxG1[k] - minG1[k];
-	area2 *= maxG2[k] - minG2[k];
-	oArea *= maxG1[k] < minG2[k] || maxG2[k] < minG1[k] ? 0.0 : std::min(maxG1[k], maxG2[k]) - std::max(minG1[k], minG2[k]);
+      for (int k = 0; k < maxG1.size(); k++) {
+        margins[i] += maxG1[k] - minG1[k] + maxG2[k] - minG2[k];
+        area1 *= maxG1[k] - minG1[k];
+        area2 *= maxG2[k] - minG2[k];
+        oArea *= maxG1[k] < minG2[k] || maxG2[k] < minG1[k] ? 0.0 : std::min(maxG1[k], maxG2[k]) - std::max(minG1[k], minG2[k]);
       }
       areas[i] += area1 + area2;
       overlapedAreas[i] += oArea;
       axisScore += margins[i];
     }
-    if(axisScore < bestAxisScore) {
+    if (axisScore < bestAxisScore) {
       bestAxisScore = axisScore;
       bestAxis = j;
       lowIsBest = false;
       double bestOverlapIndexOnBestAxis = 0;
       double bestAreaIndexOnBestAxis = 0;
-      for(int i = 1; i < areas.size(); i++) {
-	if(overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis]) {
-	  tiedOnOverlap = false;
-	  bestAreaIndexOnBestAxis = i;
-	  bestOverlapIndexOnBestAxis = i;
-	}
-	else if(overlapedAreas[i] == overlapedAreas[bestOverlapIndexOnBestAxis]) {
-	  tiedOnOverlap = true;
-	  if(areas[i] < areas[bestAreaIndexOnBestAxis])
-	    bestAreaIndexOnBestAxis = i;
-	}
+      for (int i = 1; i < areas.size(); i++) {
+        if (overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis]) {
+          tiedOnOverlap = false;
+          bestAreaIndexOnBestAxis = i;
+          bestOverlapIndexOnBestAxis = i;
+        } else if (overlapedAreas[i] == overlapedAreas[bestOverlapIndexOnBestAxis]) {
+          tiedOnOverlap = true;
+          if (areas[i] < areas[bestAreaIndexOnBestAxis])
+            bestAreaIndexOnBestAxis = i;
+        }
       }
     }
   }
 
-
-
-
   std::vector<sortStruct> sorted(tree->NumChildren());
-  if(lowIsBest) {
-    for(int i = 0; i < sorted.size(); i++) {
-      sorted[i].d = tree->Child()->Bound().Lo()[bestAxis];
+  if (lowIsBest) {
+    for (int i = 0; i < sorted.size(); i++) {
+      sorted[i].d = tree->Child(i)->Bound()[bestAxis].Lo();
       sorted[i].n = i;
     }
   } else {
-    for(int i = 0; i < sorted.size(); i++) {
-      sorted[i].d = tree->Child()->Bound().Hi()[bestAxis];
+    for (int i = 0; i < sorted.size(); i++) {
+      sorted[i].d = tree->Child(i)->Bound()[bestAxis].Hi();
       sorted[i].n = i;
     }
   }
@@ -409,19 +403,19 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
   RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType> *treeTwo = new
           RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>(tree->Parent());
 
-  if(tiedOnOverlap) {
-    for(int i = 0; i < tree.Count(); i++) {
-      if(i < bestAreaIndexOnBestAxis)
-	treeOne->InsertPoint(tree->Points()[sorted[i].n]);
+  if (tiedOnOverlap) {
+    for (int i = 0; i < tree->NumChildren(); i++) {
+      if (i < bestAreaIndexOnBestAxis + tree->MinNumChildren())
+        InsertNodeIntoTree(treeOne, tree->Child(sorted[i].n));
       else
-	treeTwo->InsertPoint(tree->Points()[sorted[i].n]);
+        InsertNodeIntoTree(treeTwo, tree->Child(sorted[i].n));
     }
   } else {
-    for(int i = 0; i < tree.Count(); i++) {
-      if(i < bestOverlapIndexOnBestAxis)
-	treeOne->InsertPoint(tree->Points()[sorted[i].n]);
+    for (int i = 0; i < tree->NumChildren(); i++) {
+      if (i < bestOverlapIndexOnBestAxis + tree->MinNumChildren())
+        InsertNodeIntoTree(treeOne, tree->Child(sorted[i].n));
       else
-	treeTwo->InsertPoint(tree->Points()[sorted[i].n]);
+        InsertNodeIntoTree(treeTwo, tree->Child(sorted[i].n));
     }
   }
 
@@ -444,6 +438,15 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
     SplitNonLeafNode(par);
   }
   
+  // We have to update the children of each of these new nodes so that they record the 
+  // correct parent.
+  for (int i = 0; i < treeOne->NumChildren(); i++) {
+    treeOne->Child(i)->Parent() = treeOne;
+  }
+  for (int i = 0; i < treeTwo->NumChildren(); i++) {
+    treeTwo->Child(i)->Parent() = treeTwo;
+  }
+
   assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
   assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
   assert(treeTwo->Parent()->NumChildren() < treeTwo->MaxNumChildren());
@@ -454,8 +457,19 @@ bool RStarTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
   return false;
 }
 
-
-
+/**
+ * Insert srcNode as the next child of destTree: expand destTree's bound to cover srcNode's bound and increment destTree's child count (srcNode's parent pointer is not changed here).
+ */
+template<typename DescentType,
+typename StatisticType,
+typename MatType>
+void RStarTreeSplit<DescentType, StatisticType, MatType>::InsertNodeIntoTree(
+        RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* destTree,
+        RectangleTree<RStarTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* srcNode)
+{
+  destTree->Bound() |= srcNode->Bound();
+  destTree->Child(destTree->NumChildren()++) = srcNode;
+}
 
 }; // namespace tree
 }; // namespace mlpack
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
index 3b30834..e26475c 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
@@ -72,7 +72,7 @@ static void AssignNodeDestNode(
 /**
   * Insert a node into another node.
   */
-static void insertNodeIntoTree(
+static void InsertNodeIntoTree(
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* destTree,
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* srcNode);
 };
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 2e43bd4..f7f154f 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -387,8 +387,8 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
     }
   }
 
-  insertNodeIntoTree(treeOne, oldTree->Child(intI));
-  insertNodeIntoTree(treeTwo, oldTree->Child(intJ));
+  InsertNodeIntoTree(treeOne, oldTree->Child(intI));
+  InsertNodeIntoTree(treeTwo, oldTree->Child(intJ));
 
   // If intJ is the last node in the tree, we need to switch the order so that we remove the correct nodes.
   if (intI > intJ) {
@@ -470,10 +470,10 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
     // Assign the rectangle that causes the least increase in volume 
     // to the appropriate rectangle.
     if (bestRect == 1) {
-      insertNodeIntoTree(treeOne, oldTree->Child(bestIndex));
+      InsertNodeIntoTree(treeOne, oldTree->Child(bestIndex));
       numAssignTreeOne++;
     } else {
-      insertNodeIntoTree(treeTwo, oldTree->Child(bestIndex));
+      InsertNodeIntoTree(treeTwo, oldTree->Child(bestIndex));
       numAssignTreeTwo++;
     }
 
@@ -483,12 +483,12 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
   if (end > 0) {
     if (numAssignTreeOne < numAssignTreeTwo) {
       for (int i = 0; i < end; i++) {
-        insertNodeIntoTree(treeOne, oldTree->Child(i));
+        InsertNodeIntoTree(treeOne, oldTree->Child(i));
         numAssignTreeOne++;
       }
     } else {
       for (int i = 0; i < end; i++) {
-        insertNodeIntoTree(treeTwo, oldTree->Child(i));
+        InsertNodeIntoTree(treeTwo, oldTree->Child(i));
         numAssignTreeTwo++;
       }
     }
@@ -515,7 +515,7 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
 template<typename DescentType,
 typename StatisticType,
 typename MatType>
-void RTreeSplit<DescentType, StatisticType, MatType>::insertNodeIntoTree(
+void RTreeSplit<DescentType, StatisticType, MatType>::InsertNodeIntoTree(
         RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* destTree,
         RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType, StatisticType, MatType>* srcNode)
 {
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 67cebf3..f58aed9 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -572,25 +572,25 @@ std::string RectangleTree<SplitType, DescentType, StatisticType, MatType>::ToStr
 {
   std::ostringstream convert;
   convert << "RectangleTree [" << this << "]" << std::endl;
-//  convert << "  First point: " << begin << std::endl;
-//  convert << "  Number of descendants: " << numChildren << std::endl;
-//  convert << "  Number of points: " << count << std::endl;
-//  convert << "  Bound: " << std::endl;
-//  convert << mlpack::util::Indent(bound.ToString(), 2);
-//  convert << "  Statistic: " << std::endl;
-//  //convert << mlpack::util::Indent(stat.ToString(), 2);
-//  convert << "  Max leaf size: " << maxLeafSize << std::endl;
-//  convert << "  Min leaf size: " << minLeafSize << std::endl;
-//  convert << "  Max num of children: " << maxNumChildren << std::endl;
-//  convert << "  Min num of children: " << minNumChildren << std::endl;
-//  convert << "  Parent address: " << parent << std::endl;
-//
-//  // How many levels should we print?  This will print 3 levels (counting the root).
-//  if (parent == NULL || parent->Parent() == NULL) {
-//    for (int i = 0; i < numChildren; i++) {
-//      convert << children[i]->ToString();
-//    }
-//  }
+  convert << "  First point: " << begin << std::endl;
+  convert << "  Number of descendants: " << numChildren << std::endl;
+  convert << "  Number of points: " << count << std::endl;
+  convert << "  Bound: " << std::endl;
+  convert << mlpack::util::Indent(bound.ToString(), 2);
+  convert << "  Statistic: " << std::endl;
+  //convert << mlpack::util::Indent(stat.ToString(), 2);
+  convert << "  Max leaf size: " << maxLeafSize << std::endl;
+  convert << "  Min leaf size: " << minLeafSize << std::endl;
+  convert << "  Max num of children: " << maxNumChildren << std::endl;
+  convert << "  Min num of children: " << minNumChildren << std::endl;
+  convert << "  Parent address: " << parent << std::endl;
+
+  // How many levels should we print?  This will print 3 levels (counting the root).
+  if (parent == NULL || parent->Parent() == NULL) {
+    for (int i = 0; i < numChildren; i++) {
+      convert << children[i]->ToString();
+    }
+  }
   return convert.str();
 }
 
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 58de810..7510746 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -272,7 +272,7 @@ int main(int argc, char *argv[])
       
       // Because we may construct it differently, we need a pointer.
       NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-      RectangleTree<tree::RTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
 		    tree::RStarTreeDescentHeuristic,
 		    NeighborSearchStat<NearestNeighborSort>,
 		    arma::mat> >* allknn = NULL;
@@ -282,13 +282,13 @@ int main(int argc, char *argv[])
       Log::Info << "Building reference tree..." << endl;
       Timer::Start("tree_building");
 
-      RectangleTree<tree::RTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
 		    tree::RStarTreeDescentHeuristic,
 		    NeighborSearchStat<NearestNeighborSort>,
 		    arma::mat>
       refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
 
-      RectangleTree<tree::RTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
 		    tree::RStarTreeDescentHeuristic,
 		    NeighborSearchStat<NearestNeighborSort>,
 		    arma::mat>*
@@ -302,7 +302,7 @@ int main(int argc, char *argv[])
 	    << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
 
         allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-        RectangleTree<tree::RTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
 	  	      tree::RStarTreeDescentHeuristic,
   		      NeighborSearchStat<NearestNeighborSort>,
   		      arma::mat> >(&refTree, queryTree,
@@ -310,7 +310,7 @@ int main(int argc, char *argv[])
       } else
       {
 	      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-      RectangleTree<tree::RTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
 		    tree::RStarTreeDescentHeuristic,
 		    NeighborSearchStat<NearestNeighborSort>,
 		    arma::mat> >(&refTree,
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 5bf6208..a9325e4 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -201,6 +201,55 @@ BOOST_AUTO_TEST_CASE(PointDeletion) {
 
 }
 
+// True iff every leaf point lies inside its node's bound and every child's bound
+// lies inside its parent's bound (recursively); each individual failure is logged.
+bool checkContainment(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>& tree) {
+  bool passed = true;
+  if (tree.NumChildren() == 0) {
+    for (size_t i = 0; i < tree.Count(); i++) {
+      const bool ok = tree.Bound().Contains(tree.Dataset().unsafe_col(tree.Points()[i]));
+      if (!ok) std::cout << ".................PointContainmentFailed" << std::endl;
+      passed &= ok;
+    }
+  } else {
+    for (size_t i = 0; i < tree.NumChildren(); i++) {
+      for (size_t j = 0; j < tree.Bound().Dim(); j++) {
+        const bool ok = tree.Bound()[j].Contains(tree.Child(i)->Bound()[j]);
+        if (!ok) std::cout << ".................BoundContainmentFailed" << std::endl;
+        passed &= ok;
+      }
+      passed &= checkContainment(*(tree.Child(i)));
+    }
+  }
+  return passed;
+}
+
+
+// True iff, in every leaf, each column of LocalDataset() equals the column of
+// Dataset() that Points() maps it to (exact comparison; stops at first mismatch).
+bool checkSync(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>& tree) {
+  if (!tree.IsLeaf()) {
+    // Internal node: in sync only if every child subtree is.
+    for (size_t c = 0; c < tree.NumChildren(); c++)
+      if (!checkSync(*tree.Child(c)))
+        return false;
+    return true;
+  }
+  // Leaf: compare every cached coordinate against the full dataset.
+  for (size_t p = 0; p < tree.Count(); p++)
+    for (size_t d = 0; d < tree.LocalDataset().n_rows; d++)
+      if (tree.LocalDataset().col(p)[d] != tree.Dataset().col(tree.Points()[p])[d])
+        return false;
+  return true;
+}
+
+
 BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
   arma::mat dataset;
   dataset.randu(8, 1000); // 1000 points in 8 dimensions.
@@ -209,19 +258,24 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
   arma::Mat<size_t> neighbors2;
   arma::mat distances2;
 
-  RectangleTree<tree::RTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+  RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
           tree::RStarTreeDescentHeuristic,
           NeighborSearchStat<NearestNeighborSort>,
           arma::mat> RTree(dataset, 20, 6, 5, 2, 0);
 
   // nearest neighbor search with the R tree.
   mlpack::neighbor::NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-          RectangleTree<tree::RTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+          RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
           tree::RStarTreeDescentHeuristic,
           NeighborSearchStat<NearestNeighborSort>,
           arma::mat> > allknn1(&RTree,
           dataset, true);
 
+  assert(RTree.NumDescendants() == 1000);
+  assert(checkSync(RTree) == true);
+  assert(checkContainment(RTree) == true);
+
+
   allknn1.Search(5, neighbors1, distances1);
 
   // nearest neighbor search the naive way.



More information about the mlpack-git mailing list