[mlpack-git] master: RectangleTree::NumDescendants() optimization (fbded36)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 29 07:10:49 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eea2aa43b9b914b4d64f45bdf1f5358faefe2522...809ed4bf33cef9de8412fc167cb0e356a369e3b6

>---------------------------------------------------------------

commit fbded362ff8fb33b35faf4fafbbdb7326334bff2
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Wed Jun 29 14:10:49 2016 +0300

    RectangleTree::NumDescendants() optimization


>---------------------------------------------------------------

fbded362ff8fb33b35faf4fafbbdb7326334bff2
 .../rectangle_tree/hilbert_r_tree_split_impl.hpp   |  5 ++
 .../tree/rectangle_tree/r_star_tree_split_impl.hpp |  1 +
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp |  1 +
 .../core/tree/rectangle_tree/rectangle_tree.hpp    |  2 +
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 54 +++++++++++++++++-----
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp |  1 +
 src/mlpack/tests/rectangle_tree_test.cpp           | 27 +++++++++++
 7 files changed, 80 insertions(+), 11 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
index 1fc6b28..7742e9d 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
@@ -231,10 +231,12 @@ RedistributeNodesEvenly(const TreeType *parent,
     // Since we redistribute children of a sibling we should recalculate the
     // bound.
     parent->Child(i).Bound().Clear();
+    parent->Child(i).numDescendants = 0;
 
     for (size_t j = 0; j < numChildrenPerNode; j++)
     {
       parent->Child(i).Bound() |= children[iChild]->Bound();
+      parent->Child(i).numDescendants += children[iChild]->numDescendants;
       parent->Child(i).children[j] = children[iChild];
       children[iChild]->Parent() = parent->children[i];
       iChild++;
@@ -242,6 +244,7 @@ RedistributeNodesEvenly(const TreeType *parent,
     if (numRestChildren > 0)
     {
       parent->Child(i).Bound() |= children[iChild]->Bound();
+      parent->Child(i).numDescendants += children[iChild]->numDescendants;
       parent->Child(i).children[numChildrenPerNode] = children[iChild];
       children[iChild]->Parent() = parent->children[i];
       parent->Child(i).NumChildren() = numChildrenPerNode + 1;
@@ -313,6 +316,8 @@ RedistributePointsEvenly(TreeType* parent,
     {
       parent->Child(i).Count() = numPointsPerNode;
     }
+    parent->Child(i).numDescendants = parent->Child(i).Count();
+
     assert(parent->Child(i).NumPoints() <=
            parent->Child(i).MaxLeafSize());
   }
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index 12a7b4a..9347b87 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -675,6 +675,7 @@ template<typename TreeType>
 void RStarTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
+  destTree->numDescendants += srcNode->numDescendants;
   destTree->children[destTree->NumChildren()++] = srcNode;
 }
 
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 98701ff..faf75f3 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -521,6 +521,7 @@ template<typename TreeType>
 void RTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
+  destTree->numDescendants += srcNode->numDescendants;
   destTree->children[destTree->NumChildren()++] = srcNode;
 }
 
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 46bccfd..90ac546 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -76,6 +76,8 @@ class RectangleTree
   //! The number of points in the dataset contained in this node (and its
   //! children).
   size_t count;
+  //! The number of descendants of this node.
+  size_t numDescendants;
   //! The max leaf size.
   size_t maxLeafSize;
   //! The minimum leaf size.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 0da0fd7..6c4b42f 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -37,6 +37,7 @@ RectangleTree(const MatType& data,
     parent(NULL),
     begin(0),
     count(0),
+    numDescendants(0),
     maxLeafSize(maxLeafSize),
     minLeafSize(minLeafSize),
     bound(data.n_rows),
@@ -76,6 +77,7 @@ RectangleTree(MatType&& data,
     parent(NULL),
     begin(0),
     count(0),
+    numDescendants(0),
     maxLeafSize(maxLeafSize),
     minLeafSize(minLeafSize),
     bound(data.n_rows),
@@ -114,6 +116,7 @@ RectangleTree(
     parent(parentNode),
     begin(0),
     count(0),
+    numDescendants(0),
     maxLeafSize(parentNode->MaxLeafSize()),
     minLeafSize(parentNode->MinLeafSize()),
     bound(parentNode->Bound().Dim()),
@@ -148,6 +151,7 @@ RectangleTree(
     parent(other.Parent()),
     begin(other.Begin()),
     count(other.Count()),
+    numDescendants(other.numDescendants),
     maxLeafSize(other.MaxLeafSize()),
     minLeafSize(other.MinLeafSize()),
     bound(other.bound),
@@ -269,6 +273,8 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   // Expand the bound regardless of whether it is a leaf node.
   bound |= dataset->col(point);
 
+  numDescendants++;
+
   std::vector<bool> lvls(TreeDepth());
   for (size_t i = 0; i < lvls.size(); i++)
     lvls[i] = true;
@@ -306,6 +312,8 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   // Expand the bound regardless of whether it is a leaf node.
   bound |= dataset->col(point);
 
+  numDescendants++;
+
   // If this is a leaf node, we stop here and add the point.
   if (numChildren == 0)
   {
@@ -345,6 +353,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
 {
   // Expand the bound regardless of the level.
   bound |= node->Bound();
+  numDescendants += node->numDescendants;
   if (level == TreeDepth())
   {
     if (!auxiliaryInfo.HandleNodeInsertion(this, node, true))
@@ -395,6 +404,12 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
         if (!auxiliaryInfo.HandlePointDeletion(this, i))
           points[i] = points[--count];
 
+        RectangleTree* tree = this;
+        while (tree != NULL)
+        {
+          tree->numDescendants--;
+          tree = tree->Parent();
+        }
         // This function wil ensure that minFill is satisfied.
         CondenseTree(dataset->col(point), lvls, true);
         return true;
@@ -433,6 +448,12 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
         if (!auxiliaryInfo.HandlePointDeletion(this, i))
           points[i] = points[--count];
 
+        RectangleTree* tree = this;
+        while (tree != NULL)
+        {
+          tree->numDescendants--;
+          tree = tree->Parent();
+        }
         // This function will ensure that minFill is satisfied.
         CondenseTree(dataset->col(point), relevels, true);
         return true;
@@ -471,6 +492,12 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
       {
         children[i] = children[--numChildren]; // Decrement numChildren.
       }
+      RectangleTree* tree = this;
+      while (tree != NULL)
+      {
+        tree->numDescendants -= node->numDescendants;
+        tree = tree->Parent();
+      }
       CondenseTree(arma::vec(), relevels, false);
       return true;
     }
@@ -613,17 +640,7 @@ template<typename MetricType,
 inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                   DescentType, AuxiliaryInformationType>::NumDescendants() const
 {
-  if (numChildren == 0)
-  {
-    return count;
-  }
-  else
-  {
-    size_t n = 0;
-    for (size_t i = 0; i < numChildren; i++)
-      n += children[i]->NumDescendants();
-    return n;
-  }
+  return numDescendants;
 }
 
 /**
@@ -763,6 +780,13 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
         if (stillShrinking)
           stillShrinking = root->ShrinkBoundForBound(bound);
 
+        root = parent;
+        while (root != NULL)
+        {
+          root->numDescendants -= numDescendants;
+          root = root->Parent();
+        }
+
         stillShrinking = true;
         root = parent;
         while (root->Parent() != NULL)
@@ -817,6 +841,13 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
           if (stillShrinking)
             stillShrinking = root->ShrinkBoundForBound(bound);
 
+          root = parent;
+          while (root != NULL)
+          {
+            root->numDescendants -= numDescendants;
+            root = root->Parent();
+          }
+
           stillShrinking = true;
           root = parent;
           while (root->Parent() != NULL)
@@ -1068,6 +1099,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
 
   ar & CreateNVP(begin, "begin");
   ar & CreateNVP(count, "count");
+  ar & CreateNVP(numDescendants, "numDescendants");
   ar & CreateNVP(maxLeafSize, "maxLeafSize");
   ar & CreateNVP(minLeafSize, "minLeafSize");
   ar & CreateNVP(bound, "bound");
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 8f8d144..8d8fc6d 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -840,6 +840,7 @@ template<typename TreeType>
 void XTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
+  destTree->numDescendants += srcNode->numDescendants;
   destTree->children[destTree->NumChildren()] = srcNode;
   destTree->NumChildren()++;
 }
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 5522a03..e6aedd9 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -316,6 +316,28 @@ int GetMinLevel(const TreeType& tree)
   return min;
 }
 
+/**
+ * Recursively verify NumDescendants(): Count() at leaves, child sums elsewhere.
+ */
+template<typename TreeType>
+size_t CheckNumDescendants(const TreeType& tree)
+{
+  if (tree.IsLeaf())
+  {
+    BOOST_REQUIRE_EQUAL(tree.NumDescendants(), tree.Count());
+    return tree.Count();
+  }
+  // Internal node: its cached value must equal the sum over its children.
+  size_t numDescendants = 0;
+
+  for (size_t i = 0; i < tree.NumChildren(); i++)
+    numDescendants += CheckNumDescendants(tree.Child(i));
+
+  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), numDescendants);
+  // Return the verified total so the caller can check its own count.
+  return numDescendants;
+}
+
 // A test to ensure that all leaf nodes are stored on the same level of the
 // tree.
 BOOST_AUTO_TEST_CASE(TreeBalance)
@@ -378,6 +400,7 @@ BOOST_AUTO_TEST_CASE(PointDeletion)
 
   CheckContainment(tree);
   CheckExactContainment(tree);
+  CheckNumDescendants(tree);
 
   // Single-tree search.
   NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, arma::mat,
@@ -460,6 +483,7 @@ BOOST_AUTO_TEST_CASE(PointDynamicAdd)
   BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000 + numIter);
   CheckContainment(tree);
   CheckExactContainment(tree);
+  CheckNumDescendants(tree);
 
   // Now we will compare the output of the R Tree vs the output of a naive
   // search.
@@ -510,6 +534,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
   CheckContainment(rTree);
   CheckExactContainment(rTree);
   CheckHierarchy(rTree);
+  CheckNumDescendants(rTree);
 
   knn1.Search(5, neighbors1, distances1);
 
@@ -552,6 +577,7 @@ BOOST_AUTO_TEST_CASE(XTreeTraverserTest)
   CheckContainment(xTree);
   CheckExactContainment(xTree);
   CheckHierarchy(xTree);
+  CheckNumDescendants(xTree);
 
   knn1.Search(5, neighbors1, distances1);
 
@@ -592,6 +618,7 @@ BOOST_AUTO_TEST_CASE(HilbertRTreeTraverserTest)
   CheckContainment(hilbertRTree);
   CheckExactContainment(hilbertRTree);
   CheckHierarchy(hilbertRTree);
+  CheckNumDescendants(hilbertRTree);
 
   knn1.Search(5, neighbors1, distances1);
 




More information about the mlpack-git mailing list