[mlpack-git] master, mlpack-1.0.x: Rectangle tree and tests. Construction seems to work. (432fdc0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:50:05 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 432fdc011d48aa1013428fcba97d33f827ed958a
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Fri Jun 27 15:30:54 2014 +0000

    Rectangle tree and tests.  Construction seems to work.


>---------------------------------------------------------------

432fdc011d48aa1013428fcba97d33f827ed958a
 .../r_tree_descent_heuristic_impl.hpp              |   1 -
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp | 139 ++++++++++++--------
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    |  31 ++---
 src/mlpack/methods/neighbor_search/allknn_main.cpp |   2 +-
 src/mlpack/tests/CMakeLists.txt                    |   1 +
 src/mlpack/tests/rectangle_tree_test.cpp           | 142 +++++++++++++++++++++
 src/mlpack/tests/tree_test.cpp                     |   1 +
 src/mlpack/tests/tree_traits_test.cpp              |   1 +
 8 files changed, 244 insertions(+), 74 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
index bc4e701..7c15a22 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_descent_heuristic_impl.hpp
@@ -15,7 +15,6 @@ namespace tree {
 
 inline double RTreeDescentHeuristic::EvalNode(const HRectBound<>& bound, const arma::vec& point)
 {
-  std::cout << "eval node called" << std::endl;
   return bound.Contains(point) ? 0 : bound.MinDistance(point);
 }
 
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 79f75b4..32dbee0 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -13,6 +13,9 @@
 namespace mlpack {
 namespace tree {
 
+  // Test invocation arguments (presumably for allknn -- confirm): -r ../test_data_3_1000.csv -n neighbors_out.csv -d distances_out.csv -k 3 -v --r_tree
+  
+  
 /**
  * We call GetPointSeeds to get the two points which will be the initial points in the new nodes
  * We then call AssignPointDestNode to assign the remaining points to the two new nodes.
@@ -25,23 +28,19 @@ template<typename DescentType,
 void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* tree)
 {
-  
-  std::cout << "splitting a leaf node." << std::endl;
-
   // If we are splitting the root node, we need will do things differently so that the constructor
   // and other methods don't confuse the end user by giving an address of another node.
   if(tree->Parent() == NULL) {
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* copy =
       new RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>(*tree); // We actually want to copy this way.  Pointers and everything.
-      std::cout << "copy made ." << std::endl;
-
     copy->Parent() = tree;
     tree->Count() = 0;
     tree->Children()[(tree->NumChildren())++] = copy; // Because this was a leaf node, numChildren must be 0.
+    assert(tree->NumChildren() == 1);
     RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(copy);
-    std::cout << "finished split" << std::endl;
     return;
   }
+  assert(tree->Parent()->NumChildren() < tree->Parent()->MaxNumChildren()); 
   
   // Use the quadratic split method from: Guttman "R-Trees: A Dynamic Index Structure for
   // Spatial Searching"  It is simplified since we don't handle rectangles, only points.
@@ -50,22 +49,14 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
   int j = 0;
   GetPointSeeds(*tree, &i, &j);
   
-  std::cout << "point seeds found." << std::endl;
-  
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType> *treeOne = new 
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>(tree->Parent());
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType> *treeTwo = new 
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>(tree->Parent());
     
-  std::cout << "new trees made." << std::endl;
- 
-    
   // This will assign the ith and jth point appropriately.
   AssignPointDestNode(tree, treeOne, treeTwo, i, j);
   
-    std::cout << "assignments made." << std::endl;
-
-  
   //Remove this node and insert treeOne and treeTwo
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* par = tree->Parent();
   int index = 0;
@@ -78,22 +69,20 @@ void RTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
   par->Children()[index] = treeOne;
   par->Children()[par->NumChildren()++] = treeTwo;
      
-  
-  std::cout << "points copied." << std::endl;
-
-      
-  //because we copied the points to treeOne and treeTwo, we can just delete this node
-  // I THOUGHT?
-  //delete tree;
+  // We need to delete this carefully since references to points are used.
+  tree->softDelete();
 
   // we only add one at a time, so we should only need to test for equality
   // just in case, we use an assert.
   assert(par->NumChildren() <= par->MaxNumChildren());
   if(par->NumChildren() == par->MaxNumChildren()) {
-    std::cout << "leaf split calls non-leaf split" << std::endl;
     SplitNonLeafNode(par);
   }
-  std::cout << "about to end leaf split." << std::endl;
+  
+  assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren());
+  assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
+  assert(treeTwo->Parent()->NumChildren() < treeTwo->MaxNumChildren());
+  assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
   return;
 }
 
@@ -110,56 +99,51 @@ template<typename DescentType,
 bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* tree)
 {  
-  std::cout << "splitting non-leaf node." << std::endl;
-
   // If we are splitting the root node, we need will do things differently so that the constructor
   // and other methods don't confuse the end user by giving an address of another node.
   if(tree->Parent() == NULL) {
-    std::cout << "root node" << std::endl;
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* copy =
       new RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>(*tree); // We actually want to copy this way.  Pointers and everything.
     copy->Parent() = tree;
     tree->NumChildren() = 0;
     tree->Children()[(tree->NumChildren())++] = copy;
     RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(copy);
-    
-    std::cout << tree->ToString() << std::endl;
-    std::cout << "root split finished" << std::endl;
-    
     return true;
   }
 
-  std::cout << "about to get bound seeds" << std::endl;
   int i = 0;
   int j = 0;
   GetBoundSeeds(*tree, &i, &j);
   
-  std::cout << "bound seeds" << std::endl;
+  assert(i != j);
   
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* treeOne = new 
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>(tree->Parent());
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* treeTwo = new 
     RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>(tree->Parent());
 
-  std::cout << "new nodes created" << std::endl;
-
   // This will assign the ith and jth rectangles appropriately.
   AssignNodeDestNode(tree, treeOne, treeTwo, i, j);
 
-  std::cout << "nodes assigned" << std::endl;
-  
   //Remove this node and insert treeOne and treeTwo
   RectangleTree<RTreeSplit<DescentType, StatisticType, MatType>, DescentType,  StatisticType, MatType>* par = tree->Parent();
-  int index = 0;
+  int index = -1;
   for(int i = 0; i < par->NumChildren(); i++) {
     if(par->Children()[i] == tree) {
       index = i;
       break;
     }
   }
+  assert(index != -1);
   par->Children()[index] = treeOne;
   par->Children()[par->NumChildren()++] = treeTwo;
   
+  for(int i = 0; i < par->NumChildren(); i++) {
+    if(par->Children()[i] == tree) {
+      assert(par->Children()[i] != tree);
+    }
+  }
+
   // Because we now have pointers to the information stored under this tree,
   // we need to delete this node carefully.
   tree->softDelete(); //currently does nothing but leak memory.
@@ -181,7 +165,10 @@ bool RTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
     treeTwo->Children()[i]->Parent() = treeTwo;
   }
   
-  std::cout << "about to end split non-leaf" << std::endl;
+  assert(treeOne->NumChildren() < treeOne->MaxNumChildren());
+  assert(treeTwo->NumChildren() < treeTwo->MaxNumChildren());
+  assert(treeOne->Parent()->NumChildren() < treeOne->MaxNumChildren()); 
+
   return false;
 }
 
@@ -269,25 +256,26 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
 {
   
   int end = oldTree->Count();
-  assert(end > 1); // If this isn't true, the tree is really weird.
   
+  assert(end > 1); // If this isn't true, the tree is really weird.
 
   // Restart the point counts since we are going to move them.
   oldTree->Count() = 0;
   treeOne->Count() = 0;
   treeTwo->Count() = 0;
 
-  std::cout << " about to assign i and j" << std::endl;
-  
   treeOne->InsertPoint(oldTree->Dataset().col(intI));
-      std::cout << "assignment of i made." << std::endl;
-
-  oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
   treeTwo->InsertPoint(oldTree->Dataset().col(intJ));
-  oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
   
+  // Remove the seed with the larger index first, so that refilling its slot from the end cannot move the other seed (e.g. when intJ is the last point).
+  if(intI > intJ) {
+    oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+    oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
+  } else {
+    oldTree->Dataset().col(intJ) = oldTree->Dataset().col(--end); // decrement end
+    oldTree->Dataset().col(intI) = oldTree->Dataset().col(--end); // decrement end
+  }
     
-  std::cout << "i and j assigned" << std::endl;
     
   int numAssignedOne = 1;
   int numAssignedTwo = 1;
@@ -301,8 +289,6 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignPointDestNode(
   // on the same iteration, we added the point to the node with fewer points anyways.
   while(end > 0 && end > oldTree->MinLeafSize() - std::min(numAssignedOne, numAssignedTwo)) {
 
-    std::cout << "while loop entered with end = "<< end << std::endl;
-    
     int bestIndex = 0;
     double bestScore = DBL_MAX;
     int bestRect = 1;
@@ -388,10 +374,43 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
   int end = oldTree->NumChildren();
   assert(end > 1); // If this isn't true, the tree is really weird.
 
-  treeOne->Children()[0] = oldTree->Children()[intI];
-  oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
-  treeTwo->Children()[0] = oldTree->Children()[intJ];
-  oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
+  assert(intI != intJ);
+  
+  for(int i = 0; i < oldTree->NumChildren(); i++) {
+    for(int j = i+1; j < oldTree->NumChildren(); j++) {
+      assert(oldTree->Children()[i] != oldTree->Children()[j]);
+    }
+  }
+  
+  insertNodeIntoTree(treeOne, oldTree->Children()[intI]);
+  insertNodeIntoTree(treeTwo, oldTree->Children()[intJ]);
+  
+  // Remove the seed with the larger index first, so that refilling its slot from the end cannot move the other seed (e.g. when intJ is the last node).
+  if(intI > intJ) {
+    oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
+    oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
+  } else {
+    oldTree->Children()[intJ] = oldTree->Children()[--end]; // decrement end
+    oldTree->Children()[intI] = oldTree->Children()[--end]; // decrement end
+  }
+
+  assert(treeOne->NumChildren() == 1);
+  assert(treeTwo->NumChildren() == 1);
+  
+  for(int i = 0; i < end; i++) {
+    for(int j = i+1; j < end; j++) {
+      assert(oldTree->Children()[i] != oldTree->Children()[j]);
+    }
+  }
+  
+  for(int i = 0; i < end; i++) {
+      assert(oldTree->Children()[i] != treeOne->Children()[0]);
+  }
+  
+  for(int i = 0; i < end; i++) {
+      assert(oldTree->Children()[i] != treeTwo->Children()[0]);
+  }
+  
   
   int numAssignTreeOne = 1;
   int numAssignTreeTwo = 1;
@@ -461,13 +480,29 @@ void RTreeSplit<DescentType, StatisticType, MatType>::AssignNodeDestNode(
     if(numAssignTreeOne < numAssignTreeTwo) {
       for(int i = 0; i < end; i++) {
         insertNodeIntoTree(treeOne, oldTree->Children()[i]);
+	numAssignTreeOne++;
       }
     } else {
       for(int i = 0; i < end; i++) {
         insertNodeIntoTree(treeTwo, oldTree->Children()[i]);
+	numAssignTreeTwo++;
       }
     }
   }
+  
+  for(int i = 0; i < treeOne->NumChildren(); i++) {
+    for(int j = i+1; j < treeOne->NumChildren(); j++) {
+      assert(treeOne->Children()[i] != treeOne->Children()[j]);
+    }
+  }
+  for(int i = 0; i < treeTwo->NumChildren(); i++) {
+    for(int j = i+1; j < treeTwo->NumChildren(); j++) {
+      assert(treeTwo->Children()[i] != treeTwo->Children()[j]);
+    }
+  }
+  assert(treeOne->NumChildren() == numAssignTreeOne);
+  assert(treeTwo->NumChildren() == numAssignTreeTwo);
+  assert(numAssignTreeOne+numAssignTreeTwo == 5);
 }
 
 /**
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index c5a0ed7..6cd63ae 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -43,20 +43,13 @@ RectangleTree<SplitType, DescentType, StatisticType, MatType>::RectangleTree(
 {
   stat = StatisticType(*this);
   
-  std::cout << ToString() << std::endl;
-  
-  
   // For now, just insert the points in order.
   RectangleTree* root = this;
   
   //for(int i = firstDataIndex; i < 57; i++) { // 56,57 are the bound for where it works/breaks
   for(int i = firstDataIndex; i < data.n_cols; i++) {
-    std::cout << "inserting point number: " << i << std::endl;
     root->InsertPoint(data.col(i));
-    std::cout << "finished inserting point number: " << i << std::endl;
-    std::cout << ToString() << std::endl;
   }
-  
 }
 
 template<typename SplitType,
@@ -93,8 +86,6 @@ template<typename SplitType,
 RectangleTree<SplitType, DescentType, StatisticType, MatType>::
   ~RectangleTree()
 {
-  //LEAK MEMORY
-  
   for(int i = 0; i < numChildren; i++) {
     delete children[i];
   }
@@ -126,14 +117,11 @@ template<typename SplitType,
 void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     InsertPoint(const arma::vec& point)
 {
-  
-  std::cout << "insert point called" << std::endl;
   // Expand the bound regardless of whether it is a leaf node.
   bound |= point;
 
   // If this is a leaf node, we stop here and add the point.
   if(numChildren == 0) {
-    std::cout << "count = " << count << std::endl;
     dataset->col(count++) = point;
     SplitNode();
     return;
@@ -254,8 +242,7 @@ inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
 }
 
 /**
- * Return the number of descendants contained in this node.  MEANINIGLESS AS IT CURRENTLY STANDS.
- * USE NumPoints() INSTEAD.
+ * Return the number of descendants under or in this node.
  */
 template<typename SplitType,
 	 typename DescentType,
@@ -264,7 +251,15 @@ template<typename SplitType,
 inline size_t RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     NumDescendants() const
 {
-  return count;
+  if(numChildren == 0)
+    return count;
+  else {
+    size_t n = 0;
+    for(int i = 0; i < numChildren; i++) {
+      n += children[i]->NumDescendants();
+    }
+    return n;
+  }
 }
 
 /**
@@ -328,13 +323,9 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::SplitNode()
   if(count < maxLeafSize)
     return; // We don't need to split.
   
-  std::cout << "we are actually splitting the node." << std::endl;
   // If we are full, then we need to split (or at least try).  The SplitType takes
   // care of this and of moving up the tree if necessary.
   SplitType::SplitLeafNode(this);
-  std::cout << "we finished actually splitting the node." << std::endl;
-  
-  std::cout << ToString() << std::endl;
 }
 
 
@@ -362,7 +353,7 @@ std::string RectangleTree<SplitType, DescentType, StatisticType, MatType>::ToStr
   convert << "  Min num of children: " << minNumChildren << std::endl;
   convert << "  Parent address: " << parent << std::endl;
 
-  // How many levels should we print?  This will print the root and it's children.
+  // How many levels should we print?  This will print 3 levels (counting the root).
   if(parent == NULL || parent->Parent() == NULL) {
     for(int i = 0; i < numChildren; i++) {
       convert << children[i]->ToString();
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 0541470..432c23e 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -278,7 +278,7 @@ int main(int argc, char *argv[])
                       arma::mat>
         refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
       Timer::Stop("tree_building");
-      std::cout << "completed tree building" << std::endl;
+      std::cout << "completed tree building " << refTree.NumDescendants() << std::endl;
     }
   }
   else // Cover trees.
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index ec36031..c49b70f 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -38,6 +38,7 @@ add_executable(mlpack_test
   perceptron_test.cpp
   radical_test.cpp
   range_search_test.cpp
+  rectangle_tree_test.cpp
   save_restore_utility_test.cpp
   sgd_test.cpp
   sort_policy_test.cpp
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
new file mode 100644
index 0000000..bbfa4ee
--- /dev/null
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -0,0 +1,142 @@
+ 
+/**
+ * @file rectangle_tree_test.cpp
+ * @author Andrew Wells
+ *
+ * Tests for the RectangleTree class.  This should ensure that the class works correctly
+ * and that subsequent changes don't break anything.  Because it's only used to test the trees,
+ * it is slow.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/tree_traits.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+
+BOOST_AUTO_TEST_SUITE(RectangleTreeTest);
+
+// Be careful!  When writing new tests, always get the boolean value and store
+// it in a temporary, because the Boost unit test macros do weird things and
+// will cause bizarre problems.
+
+// Test the traits on RectangleTrees.
+BOOST_AUTO_TEST_CASE(RectangeTreeTraitsTest)
+{
+  // Children may be overlapping.
+  bool b = TreeTraits<RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> >::HasOverlappingChildren;
+  BOOST_REQUIRE_EQUAL(b, true);
+  
+  // Points are not contained in multiple levels.
+  b = TreeTraits<RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> >::HasSelfChildren;
+  BOOST_REQUIRE_EQUAL(b, false);
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeConstructionCountTest)
+{
+  arma::mat dataset;
+  dataset.randu(3, 1000); // 1000 points in 3 dimensions.
+  
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000); 
+}
+
+std::vector<arma::vec*> getAllPointsInTree(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat>& tree)
+{
+  std::vector<arma::vec*> vec;
+  if(tree.NumChildren() > 0) {
+    for(size_t i = 0; i < tree.NumChildren(); i++) {
+      std::vector<arma::vec*> tmp = getAllPointsInTree(*(tree.Children()[i]));
+      vec.insert(vec.begin(), tmp.begin(), tmp.end());
+    }
+  } else {
+    for(size_t i = 0; i < tree.Count(); i++) {
+      arma::vec* c = new arma::vec(tree.Dataset().col(i)); 
+      vec.push_back(c);
+    }
+  }
+  return vec;
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeConstructionRepeatTest)
+{
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+
+  std::vector<arma::vec*> allPoints = getAllPointsInTree(tree);
+  for(size_t i = 0; i < allPoints.size(); i++) {
+    for(size_t j = i+1; j < allPoints.size(); j++) {
+      arma::vec v1 = *(allPoints[i]);
+      arma::vec v2 = *(allPoints[j]);
+      bool same = true;
+      for(size_t k = 0; k < v1.n_rows; k++) {
+	same &= (v1[k] == v2[k]);
+      }
+      assert(same != true);
+    }
+  }
+  for(size_t i = 0; i < allPoints.size(); i++) {
+    delete allPoints[i];
+  }  
+}
+
+bool checkContainment(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat>& tree)
+{
+  bool passed = true;
+  if(tree.NumChildren() == 0) {
+    for(size_t i = 0; i < tree.Count(); i++) {
+      passed &= tree.Bound().Contains(tree.Dataset().unsafe_col(i));
+    }
+  } else {
+    for(size_t i = 0; i < tree.NumChildren(); i++) {
+      bool p1 = true;
+      for(size_t j = 0; j < tree.Bound().Dim(); j++) {
+	p1 &= tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]);
+      }
+      passed &= p1;
+      passed &= checkContainment(*(tree.Children()[i]));
+    }
+  }  
+  return passed;
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest)
+{
+    arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  
+  RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+                      tree::RTreeDescentHeuristic,
+                      NeighborSearchStat<NearestNeighborSort>,
+                      arma::mat> tree(dataset, 20, 6, 5, 2, 0);
+  assert(checkContainment(tree) == true);
+}
+
+
+BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 7c442ba..a5bc474 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -8,6 +8,7 @@
 #include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
 #include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
 
 #include <queue>
 #include <stack>
diff --git a/src/mlpack/tests/tree_traits_test.cpp b/src/mlpack/tests/tree_traits_test.cpp
index df220fc..f4182b0 100644
--- a/src/mlpack/tests/tree_traits_test.cpp
+++ b/src/mlpack/tests/tree_traits_test.cpp
@@ -12,6 +12,7 @@
 #include <mlpack/core/tree/tree_traits.hpp>
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"



More information about the mlpack-git mailing list