[mlpack-git] master: Use the RectangleTree::Split() method instead of the friend class. Move the normalNodeMaxNumChildren to the XTreeSplit. (6008c15)

gitdub at mlpack.org gitdub at mlpack.org
Fri Apr 22 13:21:23 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4c8a8d1ccbc33916794fe0f6142fa5378ff8503d...93eb34cc4daf08cc331d1a22fccf44950b276daa

>---------------------------------------------------------------

commit 6008c15f62263229eec240cec5487fc59062b993
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Fri Apr 22 20:21:23 2016 +0300

    Use the RectangleTree::Split() method instead of the friend class.
    Move the normalNodeMaxNumChildren to the XTreeSplit.


>---------------------------------------------------------------

6008c15f62263229eec240cec5487fc59062b993
 .../core/tree/rectangle_tree/r_star_tree_split.hpp |  8 +++-
 .../tree/rectangle_tree/r_star_tree_split_impl.hpp | 25 ++++++++---
 .../core/tree/rectangle_tree/r_tree_split.hpp      |  8 +++-
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp | 22 ++++++++--
 .../core/tree/rectangle_tree/rectangle_tree.hpp    | 14 +++----
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    |  8 +---
 .../core/tree/rectangle_tree/x_tree_split.hpp      | 20 ++++++++-
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp | 49 +++++++++++++++++-----
 8 files changed, 116 insertions(+), 38 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
index 3970f3a..8ef11ec 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
@@ -25,9 +25,15 @@ class RStarTreeSplit
   //! Default constructor
   RStarTreeSplit();
 
-  //! Construct this with specified node.
+  //! Construct this with the specified node.
   RStarTreeSplit(TreeType *node);
 
+  //! Construct this with the specified node and the parent of the node.
+  RStarTreeSplit(TreeType *node,const TreeType *parentNode);
+
+  //! Construct this with the specified node, copying settings from other.Split().
+  RStarTreeSplit(TreeType *node,const TreeType &other);
+
   /**
    * Split a leaf node using the algorithm described in "The R*-tree: An
    * Efficient and Robust Access method for Points and Rectangles."  If
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index a2af8aa..41994db 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -29,6 +29,21 @@ RStarTreeSplit<TreeType>::RStarTreeSplit(TreeType *node) :
 
 }
 
+template<typename TreeType>
+RStarTreeSplit<TreeType>::RStarTreeSplit(TreeType *node,const TreeType *) :
+    tree(node)
+{
+
+}
+
+template<typename TreeType>
+RStarTreeSplit<TreeType>::RStarTreeSplit(TreeType *node,const TreeType &) :
+    tree(node)
+{
+
+}
+
+
 /**
  * We call GetPointSeeds to get the two points which will be the initial points
  * in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -55,7 +70,7 @@ void RStarTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
     tree->Children()[(tree->NumChildren())++] = copy;
     assert(tree->NumChildren() == 1);
 
-    copy->SplitNode(relevels);
+    copy->Split().SplitLeafNode(relevels);
     return;
   }
 
@@ -72,7 +87,7 @@ void RStarTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
     size_t p = tree->MaxLeafSize() * 0.3; // The paper says this works the best.
     if (p == 0)
     {
-      tree->SplitNode(relevels);
+      tree->Split().SplitLeafNode(relevels);
       return;
     }
 
@@ -265,7 +280,7 @@ void RStarTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
   // just in case, we use an assert.
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    par->SplitNode(relevels);
+    par->Split().SplitNonLeafNode(relevels);
 
   assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
   assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
@@ -301,7 +316,7 @@ bool RStarTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
     tree->NullifyData();
     tree->Children()[(tree->NumChildren())++] = copy;
 
-    copy->SplitNode(relevels);
+    copy->Split().SplitNonLeafNode(relevels);
     return true;
   }
 
@@ -657,7 +672,7 @@ bool RStarTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
   {
-    par->SplitNode(relevels);
+    par->Split().SplitNonLeafNode(relevels);
   }
 
   // We have to update the children of each of these new nodes so that they
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
index 8a412ad..c72a2e2 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
@@ -25,9 +25,15 @@ class RTreeSplit
   //! Default constructor
   RTreeSplit();
 
-  //! Construct this with specified node.
+  //! Construct this with the specified node.
   RTreeSplit(TreeType *node);
 
+  //! Construct this with the specified node and the parent of the node.
+  RTreeSplit(TreeType *node,const TreeType *parentNode);
+
+  //! Construct this with the specified node, copying settings from other.Split().
+  RTreeSplit(TreeType *node,const TreeType &other);
+
   /**
    * Split a leaf node using the "default" algorithm.  If necessary, this split
    * will propagate upwards through the tree.
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 5bc0461..6cb925d 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -28,6 +28,20 @@ RTreeSplit<TreeType>::RTreeSplit(TreeType *node) :
 
 }
 
+template<typename TreeType>
+RTreeSplit<TreeType>::RTreeSplit(TreeType *node,const TreeType *) :
+    tree(node)
+{
+
+}
+
+template<typename TreeType>
+RTreeSplit<TreeType>::RTreeSplit(TreeType *node,const TreeType &) :
+    tree(node)
+{
+
+}
+
 /**
  * We call GetPointSeeds to get the two points which will be the initial points
  * in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -49,7 +63,7 @@ void RTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
     tree->NullifyData();
     // Because this was a leaf node, numChildren must be 0.
     tree->Children()[(tree->NumChildren())++] = copy;
-    copy->SplitNode(relevels);
+    copy->Split().SplitLeafNode(relevels);
     return;
   }
 
@@ -80,7 +94,7 @@ void RTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
   // just in case, we use an assert.
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    par->SplitNode(relevels);
+    par->Split().SplitNonLeafNode(relevels);
 
   assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
   assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
@@ -112,7 +126,7 @@ bool RTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
     tree->NumChildren() = 0;
     tree->NullifyData();
     tree->Children()[(tree->NumChildren())++] = copy;
-    copy->SplitNode(relevels);
+    copy->Split().SplitNonLeafNode(relevels);
     return true;
   }
 
@@ -145,7 +159,7 @@ bool RTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
 
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    par->SplitNode(relevels);
+    par->Split().SplitNonLeafNode(relevels);
 
   // We have to update the children of each of these new nodes so that they
   // record the correct parent.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 090a832..8930998 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -43,7 +43,6 @@ template<typename MetricType = metric::EuclideanDistance,
          typename DescentType = RTreeDescentHeuristic>
 class RectangleTree
 {
-  friend class SplitType<RectangleTree>;
   // The metric *must* be the euclidean distance.
   static_assert(boost::is_same<MetricType, metric::EuclideanDistance>::value,
       "RectangleTree: MetricType must be metric::EuclideanDistance.");
@@ -78,9 +77,6 @@ class RectangleTree
   } SplitHistoryStruct;
 
  private:
-  //! The max number of child nodes a non-leaf normal node can have. 
-  //! (used in x-trees)
-  size_t normalNodeMaxNumChildren;
   //! The max number of child nodes a non-leaf node can have.
   size_t maxNumChildren;
   //! The minimum number of child nodes a non-leaf node can have.
@@ -325,6 +321,11 @@ class RectangleTree
   //! Modify the split history object of this node.
   SplitHistoryStruct& SplitHistory() { return splitHistory; }
 
+  //! Return the split object of this node.
+  const SplitType<RectangleTree>& Split() const { return split; }
+  //! Modify the split object of this node.
+  SplitType<RectangleTree>& Split() { return split; }
+
   //! Return whether or not this node is a leaf (true if it has no children).
   bool IsLeaf() const;
 
@@ -338,11 +339,6 @@ class RectangleTree
   //! Modify the minimum leaf size.
   size_t& MinLeafSize() { return minLeafSize; }
 
-  //! Return the maximum number of a normal node's children (used in x-trees).
-  size_t NormalNodeMaxNumChildren() const { return normalNodeMaxNumChildren; }
-  //! Modify the maximum number of a normal node's children (used in x-trees).
-  size_t& NormalNodeMaxNumChildren() { return normalNodeMaxNumChildren; }
-
   //! Return the maximum number of children (in a non-leaf node).
   size_t MaxNumChildren() const { return maxNumChildren; }
   //! Modify the maximum number of children (in a non-leaf node).
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index cca78de..57916c8 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -29,7 +29,6 @@ RectangleTree(const MatType& data,
               const size_t maxNumChildren,
               const size_t minNumChildren,
               const size_t firstDataIndex) :
-    normalNodeMaxNumChildren(maxNumChildren),
     maxNumChildren(maxNumChildren),
     minNumChildren(minNumChildren),
     numChildren(0),
@@ -71,7 +70,6 @@ RectangleTree(MatType&& data,
               const size_t maxNumChildren,
               const size_t minNumChildren,
               const size_t firstDataIndex) :
-    normalNodeMaxNumChildren(maxNumChildren),
     maxNumChildren(maxNumChildren),
     minNumChildren(minNumChildren),
     numChildren(0),
@@ -110,7 +108,6 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 RectangleTree(
     RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>*
         parentNode,const size_t numMaxChildren) :
-    normalNodeMaxNumChildren(parentNode->NormalNodeMaxNumChildren()),
     maxNumChildren(numMaxChildren > 0 ? numMaxChildren : parentNode->MaxNumChildren()),
     minNumChildren(parentNode->MinNumChildren()),
     numChildren(0),
@@ -130,7 +127,7 @@ RectangleTree(
                                                   maxLeafSize + 1)))
 {
   stat = StatisticType(*this);
-  split = SplitType<RectangleTree>(this);
+  split = SplitType<RectangleTree>(this,parentNode);
 }
 
 /**
@@ -146,7 +143,6 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 RectangleTree(
     const RectangleTree& other,
     const bool deepCopy) :
-    normalNodeMaxNumChildren(other.NormalNodeMaxNumChildren()),
     maxNumChildren(other.MaxNumChildren()),
     minNumChildren(other.MinNumChildren()),
     numChildren(other.NumChildren()),
@@ -164,7 +160,7 @@ RectangleTree(
     points(other.Points()),
     localDataset(NULL)
 {
-  split = SplitType<RectangleTree>(this);
+  split = SplitType<RectangleTree>(this,other);
   if (deepCopy)
   {
     if (numChildren > 0)
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
index 3264353..884034c 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
@@ -35,9 +35,18 @@ class XTreeSplit
   //! Default constructor
   XTreeSplit();
 
-  //! Construct this with specified node.
+  //! Construct this with the specified node.
   XTreeSplit(TreeType *node);
 
+  //! Construct this with the specified node and the specified normalNodeMaxNumChildren.
+  XTreeSplit(TreeType *node,const size_t normalNodeMaxNumChildren);
+
+  //! Construct this with the specified node and the parent of the node.
+  XTreeSplit(TreeType *node,const TreeType *parentNode);
+
+  //! Construct this with the specified node, copying settings from other.Split().
+  XTreeSplit(TreeType *node,const TreeType &other);
+
   /**
    * Split a leaf node using the algorithm described in "The R*-tree: An
    * Efficient and Robust Access method for Points and Rectangles."  If
@@ -55,6 +64,9 @@ class XTreeSplit
   //! The node which has to be split.
   TreeType *tree;
 
+  //! The max number of child nodes a non-leaf normal node can have.
+  size_t normalNodeMaxNumChildren;
+
   /**
    * Class to allow for faster sorting.
    */
@@ -80,6 +92,12 @@ class XTreeSplit
    * Insert a node into another node.
    */
   static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
+
+ public:
+  //! Return the maximum number of a normal node's children.
+  size_t NormalNodeMaxNumChildren() const { return normalNodeMaxNumChildren; }
+  //! Modify the maximum number of a normal node's children.
+  size_t& NormalNodeMaxNumChildren() { return normalNodeMaxNumChildren; }
 };
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 685b26a..a8f803b 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -16,14 +16,41 @@ namespace tree {
 
 template<typename TreeType>
 XTreeSplit<TreeType>::XTreeSplit() :
-    tree(NULL)
+    tree(NULL),
+    normalNodeMaxNumChildren(0)
 {
 
 }
 
 template<typename TreeType>
 XTreeSplit<TreeType>::XTreeSplit(TreeType *node) :
-    tree(node)
+    tree(node),
+    normalNodeMaxNumChildren(node->MaxNumChildren())
+{
+
+}
+
+
+template<typename TreeType>
+XTreeSplit<TreeType>::XTreeSplit(TreeType *node,const size_t normalNodeMaxNumChildren) :
+    tree(node),
+    normalNodeMaxNumChildren(normalNodeMaxNumChildren)
+{
+
+}
+
+template<typename TreeType>
+XTreeSplit<TreeType>::XTreeSplit(TreeType *node,const TreeType *parent) :
+    tree(node),
+    normalNodeMaxNumChildren(parent->Split().NormalNodeMaxNumChildren())
+{
+
+}
+
+template<typename TreeType>
+XTreeSplit<TreeType>::XTreeSplit(TreeType *node,const TreeType &other) :
+    tree(node),
+    normalNodeMaxNumChildren(other.Split().NormalNodeMaxNumChildren())
 {
 
 }
@@ -54,7 +81,7 @@ void XTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
     // Because this was a leaf node, numChildren must be 0.
     tree->Children()[(tree->NumChildren())++] = copy;
     assert(tree->NumChildren() == 1);
-    copy->SplitNode(relevels);
+    copy->Split().SplitLeafNode(relevels);
     return;
   }
 
@@ -72,7 +99,7 @@ void XTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
     size_t p = tree->MaxLeafSize() * 0.3;
     if (p == 0)
     {
-      tree->SplitNode(relevels);
+      tree->Split().SplitLeafNode(relevels);
       return;
     }
 
@@ -233,8 +260,8 @@ void XTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
 
   std::sort(sorted.begin(), sorted.end(), structComp<ElemType>);
 
-  TreeType* treeOne = new TreeType(tree->Parent(),tree->NormalNodeMaxNumChildren());
-  TreeType* treeTwo = new TreeType(tree->Parent(),tree->NormalNodeMaxNumChildren());
+  TreeType* treeOne = new TreeType(tree->Parent(),NormalNodeMaxNumChildren());
+  TreeType* treeTwo = new TreeType(tree->Parent(),NormalNodeMaxNumChildren());
 
   // The leaf nodes should never have any overlap introduced by the above method
   // since a split axis is chosen and then points are assigned based on their
@@ -285,7 +312,7 @@ void XTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
   // in case, we use an assert.
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    par->SplitNode(relevels);
+    par->Split().SplitNonLeafNode(relevels);
 
   assert(treeOne->Parent()->NumChildren() <=
       treeOne->Parent()->MaxNumChildren());
@@ -324,7 +351,7 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
     tree->NumChildren() = 0;
     tree->NullifyData();
     tree->Children()[(tree->NumChildren())++] = copy;
-    copy->SplitNode(relevels);
+    copy->Split().SplitNonLeafNode(relevels);
     return true;
   }
 
@@ -758,7 +785,7 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
           (tree->Parent()->NumChildren() == 1))
       {
         // We make the root a supernode instead.
-        tree->Parent()->MaxNumChildren() = tree->MaxNumChildren() + tree->NormalNodeMaxNumChildren();
+        tree->Parent()->MaxNumChildren() = tree->MaxNumChildren() + NormalNodeMaxNumChildren();
         tree->Parent()->Children().resize(tree->Parent()->MaxNumChildren() + 1);
         tree->Parent()->NumChildren() = tree->NumChildren();
         for (size_t i = 0; i < tree->NumChildren(); i++)
@@ -775,7 +802,7 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
       }
 
       // If we don't have to worry about the root, we just enlarge this node.
-      tree->MaxNumChildren() += tree->NormalNodeMaxNumChildren();
+      tree->MaxNumChildren() += NormalNodeMaxNumChildren();
       tree->Children().resize(tree->MaxNumChildren() + 1);
       for (size_t i = 0; i < tree->NumChildren(); i++)
         tree->Child(i).Parent() = tree;
@@ -818,7 +845,7 @@ bool XTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
 
   if (par->NumChildren() == par->MaxNumChildren() + 1)
   {
-    par->SplitNode(relevels);
+    par->Split().SplitNonLeafNode(relevels);
   }
 
   // We have to update the children of each of these new nodes so that they




More information about the mlpack-git mailing list