[mlpack-git] master: X tree (3317c56)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:57:56 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 3317c56c5ba1c5e21ef309e7c8a5f83bc3eee802
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Sat Aug 16 15:12:06 2014 +0000

    X tree


>---------------------------------------------------------------

3317c56c5ba1c5e21ef309e7c8a5f83bc3eee802
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    |   6 +-
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp |  57 +++++---
 src/mlpack/tests/rectangle_tree_test.cpp           | 147 ++++++++++++++++++++-
 3 files changed, 186 insertions(+), 24 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 9000f48..afbb5dc 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -665,8 +665,11 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
       // If there are multiple children, we can't do anything to the root.
       RectangleTree<SplitType, DescentType, StatisticType, MatType>* child =
           children[0];
-      for (size_t i = 0; i < child->NumChildren(); i++)
+      for (size_t i = 0; i < child->NumChildren(); i++) {
         children[i] = child->Children()[i];
+        children[i]->Parent() = this;
+      }
+      
       numChildren = child->NumChildren();
 
       for (size_t i = 0; i < child->Count(); i++)
@@ -677,6 +680,7 @@ void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
       }
 
       count = child->Count();
+      maxNumChildren = child->MaxNumChildren(); // Required for the X tree.
       child->SoftDelete();
       return;
     }
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 1ea4018..f1530a2 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -233,10 +233,10 @@ void XTreeSplit<DescentType, StatisticType, MatType>::SplitLeafNode(
     SplitNonLeafNode(par, relevels);
   }
   
-  assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
-  assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
-  assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
-  assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
+  assert(treeOne->Parent()->NumChildren() <= treeOne->Parent()->MaxNumChildren());
+  assert(treeOne->Parent()->NumChildren() >= treeOne->Parent()->MinNumChildren());
+  assert(treeTwo->Parent()->NumChildren() <= treeTwo->Parent()->MaxNumChildren());
+  assert(treeTwo->Parent()->NumChildren() >= treeTwo->Parent()->MinNumChildren());
 
   tree->SoftDelete();
 
@@ -600,17 +600,41 @@ bool XTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
       
       
       
+      // We don't create a supernode that would be the only child of the root.
+      // (Note that if you did try to do so, you would need to update the parent field on
+      // each child of this new node, as creating a supernode causes the function to return
+      // before that is done.)
       
+      // I thought commenting out the code below would make the tree less efficient but would still work.
+      // It doesn't.  I should look into that to see if there is another bug.
       
-      // The min overlap split failed so we create a supernode instead.
+      
+      if(tree->Parent()->Parent() == NULL && tree->Parent()->NumChildren() == 1) {
+        // We make the root a supernode instead.
+        tree->Parent()->MaxNumChildren() *= 2;
+        tree->Parent()->Children().resize(tree->Parent()->MaxNumChildren()+1);
+        tree->Parent()->NumChildren() = tree->NumChildren();
+        for(int i = 0; i < tree->NumChildren(); i++) {
+          tree->Parent()->Children()[i] = tree->Children()[i];
+        }
+        delete treeOne;
+        delete treeTwo;
+        tree->NullifyData();
+        tree->SoftDelete();
+        return false;
+      }
+      
+      
+      
+      // If we don't have to worry about the root, we just enlarge this node.
       tree->MaxNumChildren() *= 2;
-      tree->MaxLeafSize() *= 2;
-      tree->LocalDataset().resize(tree->LocalDataset().n_rows, 2*tree->LocalDataset().n_cols);
       tree->Children().resize(tree->MaxNumChildren()+1);
-      tree->Points().resize(tree->MaxLeafSize()+1);
+      for(int i = 0; i < tree->NumChildren(); i++)
+        tree->Child(i).Parent() = tree;
+      delete treeOne;
+      delete treeTwo;
       
       return false;     
-      
     }
   }
 
@@ -637,11 +661,15 @@ bool XTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
       break;
     }
   }
+  
   par->Children()[index] = treeOne;
   par->Children()[par->NumChildren()++] = treeTwo;
 
   // we only add one at a time, so we should only need to test for equality
   // just in case, we use an assert.
+  
+  if(!(par->NumChildren() <= par->MaxNumChildren()+1))
+    std::cout<<"error " << par->NumChildren() << ", "<<par->MaxNumChildren()+1<<std::endl;
   assert(par->NumChildren() <= par->MaxNumChildren()+1);
   if (par->NumChildren() == par->MaxNumChildren()+1) {
     SplitNonLeafNode(par, relevels);
@@ -657,13 +685,10 @@ bool XTreeSplit<DescentType, StatisticType, MatType>::SplitNonLeafNode(
   }
   
   
-  assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
-  assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
-  assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
-  assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
-  
-  assert(treeOne->MaxNumChildren() < 7);
-  assert(treeTwo->MaxNumChildren() < 7);
+  assert(treeOne->Parent()->NumChildren() <= treeOne->Parent()->MaxNumChildren());
+  assert(treeOne->Parent()->NumChildren() >= treeOne->Parent()->MinNumChildren());
+  assert(treeTwo->Parent()->NumChildren() <= treeTwo->Parent()->MaxNumChildren());
+  assert(treeTwo->Parent()->NumChildren() >= treeTwo->Parent()->MinNumChildren());
 
   tree->SoftDelete();
   
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index d1210db..fd2175f 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -224,6 +224,86 @@ void checkExactContainment(const RectangleTree<tree::RStarTreeSplit<tree::RStarT
   }
 }
 
+/**
+ * A function to check that containment is as tight as possible.
+ */
+void checkExactContainment(const RectangleTree<tree::XTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>& tree) {
+  if(tree.NumChildren() == 0) {
+    for(size_t i = 0; i < tree.Bound().Dim(); i++) {
+      double min = DBL_MAX;
+      double max = -1.0 * DBL_MAX;
+      for(size_t j = 0; j < tree.Count(); j++) {
+	if(tree.LocalDataset().col(j)[i] < min)
+	  min = tree.LocalDataset().col(j)[i];
+	if(tree.LocalDataset().col(j)[i] > max)
+	  max = tree.LocalDataset().col(j)[i];
+      }
+      BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi());
+      BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo());
+    }
+  } else {
+    for(size_t i = 0; i < tree.Bound().Dim(); i++) {
+      double min = DBL_MAX;
+      double max = -1.0 * DBL_MAX;
+      for(size_t j = 0; j < tree.NumChildren(); j++) {
+	if(tree.Child(j).Bound()[i].Lo() < min)
+	  min = tree.Child(j).Bound()[i].Lo();
+	if(tree.Child(j).Bound()[i].Hi() > max)
+	  max = tree.Child(j).Bound()[i].Hi();
+      }
+      BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi());
+      BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo());
+    }
+    for(size_t i = 0; i < tree.NumChildren(); i++)
+      checkExactContainment(tree.Child(i));
+  }
+}
+
+/**
+ * A function to check that parents and children are set correctly.
+ */
+void checkHierarchy(const RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        tree::RTreeDescentHeuristic,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>& tree) {
+    for(size_t i = 0; i < tree.NumChildren(); i++) {
+      BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
+      checkHierarchy(tree.Child(i));
+    }
+}
+
+/**
+ * A function to check that parents and children are set correctly.
+ */
+void checkHierarchy(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>& tree) {
+    for(size_t i = 0; i < tree.NumChildren(); i++) {
+      BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
+      checkHierarchy(tree.Child(i));
+    }
+}
+
+/**
+ * A function to check that parents and children are set correctly.
+ */
+void checkHierarchy(const RectangleTree<tree::XTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>& tree) {
+    for(size_t i = 0; i < tree.NumChildren(); i++) {
+      BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
+      checkHierarchy(tree.Child(i));
+    }
+}
+
+
+
+
 // Test to see if the bounds of the tree are correct. (Cover all bounds and points
 // beneath this node of the tree).
 BOOST_AUTO_TEST_CASE(RectangleTreeContainmentTest) {
@@ -568,6 +648,29 @@ void checkSync(const RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHe
   return;
 }
 
+/**
+ * A function to ensure that the dataset for the tree, and the datasets stored
+ * in each leaf node are in sync.
+ * @param tree The tree to check.
+ */
+void checkSync(const RectangleTree<tree::XTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>& tree) {
+  if (tree.IsLeaf()) {
+    for (size_t i = 0; i < tree.Count(); i++) {
+      for (size_t j = 0; j < tree.LocalDataset().n_rows; j++) {
+        BOOST_REQUIRE_EQUAL(tree.LocalDataset().col(i)[j], tree.Dataset().col(tree.Points()[i])[j]);
+      }
+    }
+  } else {
+    for (size_t i = 0; i < tree.NumChildren(); i++) {
+      checkSync(*tree.Children()[i]);
+    }
+  }
+  return;
+}
+
 // A test to ensure that the SingleTreeTraverser is working correctly by comparing
 // its results to the results of a naive search.
 BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
@@ -595,6 +698,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
   checkSync(RTree);
   checkContainment(RTree);
   checkExactContainment(RTree);
+  checkHierarchy(RTree);
 
   allknn1.Search(5, neighbors1, distances1);
 
@@ -631,14 +735,25 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest) {
 
 
 
-/*
+
+
+
+
+
+
+
+
+
 
 
 // A test to ensure that the SingleTreeTraverser is working correctly by comparing
 // its results to the results of a naive search.
 BOOST_AUTO_TEST_CASE(XTreeTraverserTest) {
   arma::mat dataset;
-  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  
+  const int numP = 1000;
+  
+  dataset.randu(8, numP); // 1000 points in 8 dimensions.
   arma::Mat<size_t> neighbors1;
   arma::mat distances1;
   arma::Mat<size_t> neighbors2;
@@ -657,10 +772,11 @@ BOOST_AUTO_TEST_CASE(XTreeTraverserTest) {
           arma::mat> > allknn1(&RTree,
           dataset, true);
 
-  BOOST_REQUIRE_EQUAL(RTree.NumDescendants(), 1000);
-//   checkSync(RTree);
-//   checkContainment(RTree);
-//   checkExactContainment(RTree);
+  BOOST_REQUIRE_EQUAL(RTree.NumDescendants(), numP);
+   checkSync(RTree);
+   //checkContainment(RTree);
+   checkExactContainment(RTree);
+   checkHierarchy(RTree);
 
   allknn1.Search(5, neighbors1, distances1);
 
@@ -674,13 +790,30 @@ BOOST_AUTO_TEST_CASE(XTreeTraverserTest) {
     BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
     BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
   }
+  
+  //std::cout<<""<<RTree.ToString()<<std::endl;
 }
 
 
 
 
 
-*/
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
 
 
 



More information about the mlpack-git mailing list