[mlpack-git] master: Replace static SplitType in the RectangleTree by an instantiated object. (03a34ed)

gitdub at mlpack.org gitdub at mlpack.org
Fri Apr 22 08:10:45 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4c8a8d1ccbc33916794fe0f6142fa5378ff8503d...93eb34cc4daf08cc331d1a22fccf44950b276daa

>---------------------------------------------------------------

commit 03a34ed2e59c640f29d40483f02d3175f32abe78
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Fri Apr 22 15:10:45 2016 +0300

    Replace the static SplitType in RectangleTree with an instantiated object.


>---------------------------------------------------------------

03a34ed2e59c640f29d40483f02d3175f32abe78
 .../tree/rectangle_tree/dual_tree_traverser.hpp    |  2 +-
 .../rectangle_tree/dual_tree_traverser_impl.hpp    |  4 +-
 .../core/tree/rectangle_tree/r_star_tree_split.hpp | 17 ++++--
 .../tree/rectangle_tree/r_star_tree_split_impl.hpp | 31 +++++++---
 .../core/tree/rectangle_tree/r_tree_split.hpp      | 27 ++++-----
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp | 60 +++++++++++--------
 .../core/tree/rectangle_tree/rectangle_tree.hpp    |  6 +-
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 68 ++++++++++++----------
 .../tree/rectangle_tree/single_tree_traverser.hpp  |  2 +-
 .../rectangle_tree/single_tree_traverser_impl.hpp  |  4 +-
 src/mlpack/core/tree/rectangle_tree/traits.hpp     |  2 +-
 .../core/tree/rectangle_tree/x_tree_split.hpp      | 17 ++++--
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp | 31 +++++++---
 13 files changed, 169 insertions(+), 102 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 397a1c2..75a0501 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -19,7 +19,7 @@ namespace tree {
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename RuleType>
 class RectangleTree<MetricType, StatisticType, MatType, SplitType,
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 42379d2..7f02ad9 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -20,7 +20,7 @@ namespace tree {
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename RuleType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
@@ -35,7 +35,7 @@ DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename RuleType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
index 54d1bba..3970f3a 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
@@ -18,25 +18,33 @@ namespace tree /** Trees and tree-building procedures. */ {
  * nodes overflow, we split them, moving up the tree and splitting nodes
  * as necessary.
  */
+template <typename TreeType>
 class RStarTreeSplit
 {
  public:
+  //! Default constructor
+  RStarTreeSplit();
+
+  //! Construct this with specified node.
+  RStarTreeSplit(TreeType *node);
+
   /**
    * Split a leaf node using the algorithm described in "The R*-tree: An
    * Efficient and Robust Access method for Points and Rectangles."  If
    * necessary, this split will propagate upwards through the tree.
    */
-  template<typename TreeType>
-  static void SplitLeafNode(TreeType* tree, std::vector<bool>& relevels);
+  void SplitLeafNode(std::vector<bool>& relevels);
 
   /**
    * Split a non-leaf node using the "default" algorithm.  If this is a root
    * node, the tree increases in depth.
    */
-  template<typename TreeType>
-  static bool SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels);
+  bool SplitNonLeafNode(std::vector<bool>& relevels);
 
  private:
+  //! The node which has to be split.
+  TreeType *tree;
+ 
   /**
    * Class to allow for faster sorting.
    */
@@ -60,7 +68,6 @@ class RStarTreeSplit
   /**
    * Insert a node into another node.
    */
-  template<typename TreeType>
   static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
 };
 
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index 9c0e95c..a2af8aa 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -15,6 +15,20 @@
 namespace mlpack {
 namespace tree {
 
+template<typename TreeType>
+RStarTreeSplit<TreeType>::RStarTreeSplit() :
+    tree(NULL)
+{
+  // No node yet: calling the Split*Node() methods now would dereference NULL.
+}
+
+template<typename TreeType>
+RStarTreeSplit<TreeType>::RStarTreeSplit(TreeType *node) :
+    tree(node)
+{
+  // Only the pointer is stored; it must stay valid while splits are run.
+}
+
 /**
  * We call GetPointSeeds to get the two points which will be the initial points
  * in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -22,7 +36,7 @@ namespace tree {
  * new nodes into the tree, spliting the parent if necessary.
  */
 template<typename TreeType>
-void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RStarTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
 {
   // Convenience typedef.
   typedef typename TreeType::ElemType ElemType;
@@ -41,7 +55,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
     tree->Children()[(tree->NumChildren())++] = copy;
     assert(tree->NumChildren() == 1);
 
-    SplitLeafNode(copy, relevels);
+    copy->SplitNode(relevels);
     return;
   }
 
@@ -58,7 +72,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
     size_t p = tree->MaxLeafSize() * 0.3; // The paper says this works the best.
     if (p == 0)
     {
-      SplitLeafNode(tree, relevels);
+      tree->SplitNode(relevels);
       return;
     }
 
@@ -251,7 +265,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
   // just in case, we use an assert.
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    SplitNonLeafNode(par, relevels);
+    par->SplitNode(relevels);
 
   assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
   assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
@@ -269,8 +283,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
  * higher up the tree because they were already updated if necessary.
  */
 template<typename TreeType>
-bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
-                                      std::vector<bool>& relevels)
+bool RStarTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
 {
   // Convenience typedef.
   typedef typename TreeType::ElemType ElemType;
@@ -288,7 +301,7 @@ bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
     tree->NullifyData();
     tree->Children()[(tree->NumChildren())++] = copy;
 
-    SplitNonLeafNode(copy, relevels);
+    copy->SplitNode(relevels);
     return true;
   }
 
@@ -644,7 +657,7 @@ bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
   {
-    SplitNonLeafNode(par, relevels);
+    par->SplitNode(relevels);
   }
 
   // We have to update the children of each of these new nodes so that they
@@ -673,7 +686,7 @@ bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
  * numberOfChildren.
  */
 template<typename TreeType>
-void RStarTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+void RStarTreeSplit<TreeType>::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
   destTree->Children()[destTree->NumChildren()++] = srcNode;
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
index 1bf5745..8a412ad 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
@@ -18,42 +18,45 @@ namespace tree /** Trees and tree-building procedures. */ {
  * nodes overflow, we split them, moving up the tree and splitting nodes
  * as necessary.
  */
+template<typename TreeType>
 class RTreeSplit
 {
  public:
+  //! Default constructor
+  RTreeSplit();
+
+  //! Construct this with specified node.
+  RTreeSplit(TreeType *node);
+
   /**
    * Split a leaf node using the "default" algorithm.  If necessary, this split
    * will propagate upwards through the tree.
    */
-  template<typename TreeType>
-  static void SplitLeafNode(TreeType* tree,
-                            std::vector<bool>& relevels);
+  void SplitLeafNode(std::vector<bool>& relevels);
 
   /**
    * Split a non-leaf node using the "default" algorithm.  If this is a root
    * node, the tree increases in depth.
    */
-  template<typename TreeType>
-  static bool SplitNonLeafNode(TreeType* tree,
-                               std::vector<bool>& relevels);
+  bool SplitNonLeafNode(std::vector<bool>& relevels);
 
  private:
+  //! The node which has to be split.
+  TreeType *tree;
+
   /**
    * Get the seeds for splitting a leaf node.
    */
-  template<typename TreeType>
-  static void GetPointSeeds(const TreeType& tree, int& i, int& j);
+  void GetPointSeeds(int& i, int& j);
 
   /**
    * Get the seeds for splitting a non-leaf node.
    */
-  template<typename TreeType>
-  static void GetBoundSeeds(const TreeType& tree, int& i, int& j);
+  void GetBoundSeeds(int& i, int& j);
 
   /**
    * Assign points to the two new nodes.
    */
-  template<typename TreeType>
   static void AssignPointDestNode(TreeType* oldTree,
                                   TreeType* treeOne,
                                   TreeType* treeTwo,
@@ -63,7 +66,6 @@ class RTreeSplit
   /**
    * Assign nodes to the two new nodes.
    */
-  template<typename TreeType>
   static void AssignNodeDestNode(TreeType* oldTree,
                                  TreeType* treeOne,
                                  TreeType* treeTwo,
@@ -73,7 +75,6 @@ class RTreeSplit
   /**
    * Insert a node into another node.
    */
-  template<typename TreeType>
   static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
 };
 
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 15d70bd..5bc0461 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -14,6 +14,20 @@
 namespace mlpack {
 namespace tree {
 
+template<typename TreeType>
+RTreeSplit<TreeType>::RTreeSplit() :
+    tree(NULL)
+{
+  // No node yet: calling the Split*Node() methods now would dereference NULL.
+}
+
+template<typename TreeType>
+RTreeSplit<TreeType>::RTreeSplit(TreeType *node) :
+    tree(node)
+{
+  // Only the pointer is stored; it must stay valid while splits are run.
+}
+
 /**
  * We call GetPointSeeds to get the two points which will be the initial points
  * in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -21,7 +35,7 @@ namespace tree {
  * new nodes into the tree, spliting the parent if necessary.
  */
 template<typename TreeType>
-void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
 {
   // If we are splitting the root node, we need will do things differently so
   // that the constructor and other methods don't confuse the end user by giving
@@ -35,7 +49,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
     tree->NullifyData();
     // Because this was a leaf node, numChildren must be 0.
     tree->Children()[(tree->NumChildren())++] = copy;
-    SplitLeafNode(copy, relevels);
+    copy->SplitNode(relevels);
     return;
   }
 
@@ -46,7 +60,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
   // rectangles, only points.  We assume that the tree uses Euclidean Distance.
   int i = 0;
   int j = 0;
-  GetPointSeeds(*tree, i, j);
+  GetPointSeeds(i, j);
 
   TreeType* treeOne = new TreeType(tree->Parent());
   TreeType* treeTwo = new TreeType(tree->Parent());
@@ -66,7 +80,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
   // just in case, we use an assert.
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    SplitNonLeafNode(par, relevels);
+    par->SplitNode(relevels);
 
   assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
   assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
@@ -85,7 +99,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
  * higher up the tree because they were already updated if necessary.
  */
 template<typename TreeType>
-bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
+bool RTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
 {
   // If we are splitting the root node, we need will do things differently so
   // that the constructor and other methods don't confuse the end user by giving
@@ -98,13 +112,13 @@ bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
     tree->NumChildren() = 0;
     tree->NullifyData();
     tree->Children()[(tree->NumChildren())++] = copy;
-    SplitNonLeafNode(copy, relevels);
+    copy->SplitNode(relevels);
     return true;
   }
 
   int i = 0;
   int j = 0;
-  GetBoundSeeds(*tree, i, j);
+  GetBoundSeeds(i, j);
 
   assert(i != j);
 
@@ -131,7 +145,7 @@ bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
 
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    SplitNonLeafNode(par, relevels);
+    par->SplitNode(relevels);
 
   // We have to update the children of each of these new nodes so that they
   // record the correct parent.
@@ -157,18 +171,18 @@ bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
  * The indices of these points will be stored in iRet and jRet.
  */
 template<typename TreeType>
-void RTreeSplit::GetPointSeeds(const TreeType& tree, int& iRet, int& jRet)
+void RTreeSplit<TreeType>::GetPointSeeds(int& iRet, int& jRet)
 {
   // Here we want to find the pair of points that it is worst to place in the
   // same node.  Because we are just using points, we will simply choose the two
   // that would create the most voluminous hyperrectangle.
   typename TreeType::ElemType worstPairScore = -1.0;
-  for (size_t i = 0; i < tree.Count(); i++)
+  for (size_t i = 0; i < tree->Count(); i++)
   {
-    for (size_t j = i + 1; j < tree.Count(); j++)
+    for (size_t j = i + 1; j < tree->Count(); j++)
     {
       const typename TreeType::ElemType score = arma::prod(arma::abs(
-          tree.LocalDataset().col(i) - tree.LocalDataset().col(j)));
+          tree->LocalDataset().col(i) - tree->LocalDataset().col(j)));
 
       if (score > worstPairScore)
       {
@@ -185,23 +199,23 @@ void RTreeSplit::GetPointSeeds(const TreeType& tree, int& iRet, int& jRet)
  * indices of the bounds will be stored in iRet and jRet.
  */
 template<typename TreeType>
-void RTreeSplit::GetBoundSeeds(const TreeType& tree, int& iRet, int& jRet)
+void RTreeSplit<TreeType>::GetBoundSeeds(int& iRet, int& jRet)
 {
   // Convenience typedef.
   typedef typename TreeType::ElemType ElemType;
 
   ElemType worstPairScore = -1.0;
-  for (size_t i = 0; i < tree.NumChildren(); i++)
+  for (size_t i = 0; i < tree->NumChildren(); i++)
   {
-    for (size_t j = i + 1; j < tree.NumChildren(); j++)
+    for (size_t j = i + 1; j < tree->NumChildren(); j++)
     {
       ElemType score = 1.0;
-      for (size_t k = 0; k < tree.Bound().Dim(); k++)
+      for (size_t k = 0; k < tree->Bound().Dim(); k++)
       {
-        const ElemType hiMax = std::max(tree.Children()[i]->Bound()[k].Hi(),
-                                        tree.Children()[j]->Bound()[k].Hi());
-        const ElemType loMin = std::min(tree.Children()[i]->Bound()[k].Lo(),
-                                        tree.Children()[j]->Bound()[k].Lo());
+        const ElemType hiMax = std::max(tree->Children()[i]->Bound()[k].Hi(),
+                                        tree->Children()[j]->Bound()[k].Hi());
+        const ElemType loMin = std::min(tree->Children()[i]->Bound()[k].Lo(),
+                                        tree->Children()[j]->Bound()[k].Lo());
         score *= (hiMax - loMin);
       }
 
@@ -216,7 +230,7 @@ void RTreeSplit::GetBoundSeeds(const TreeType& tree, int& iRet, int& jRet)
 }
 
 template<typename TreeType>
-void RTreeSplit::AssignPointDestNode(TreeType* oldTree,
+void RTreeSplit<TreeType>::AssignPointDestNode(TreeType* oldTree,
                                      TreeType* treeOne,
                                      TreeType* treeTwo,
                                      const int intI,
@@ -357,7 +371,7 @@ void RTreeSplit::AssignPointDestNode(TreeType* oldTree,
 }
 
 template<typename TreeType>
-void RTreeSplit::AssignNodeDestNode(TreeType* oldTree,
+void RTreeSplit<TreeType>::AssignNodeDestNode(TreeType* oldTree,
                                     TreeType* treeOne,
                                     TreeType* treeTwo,
                                     const int intI,
@@ -522,7 +536,7 @@ void RTreeSplit::AssignNodeDestNode(TreeType* oldTree,
  * numberOfChildren.
  */
 template<typename TreeType>
-void RTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+void RTreeSplit<TreeType>::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
   destTree->Children()[destTree->NumChildren()++] = srcNode;
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index e4d69d3..090a832 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -35,13 +35,15 @@ namespace tree /** Trees and tree-building procedures. */ {
  * @tparam DescentType The heuristic to use when descending the tree to insert
  *    points.
  */
+
 template<typename MetricType = metric::EuclideanDistance,
          typename StatisticType = EmptyStatistic,
          typename MatType = arma::mat,
-         typename SplitType = RTreeSplit,
+         template<typename> class SplitType = RTreeSplit,
          typename DescentType = RTreeDescentHeuristic>
 class RectangleTree
 {
+  friend class SplitType<RectangleTree>;
   // The metric *must* be the euclidean distance.
   static_assert(boost::is_same<MetricType, metric::EuclideanDistance>::value,
       "RectangleTree: MetricType must be metric::EuclideanDistance.");
@@ -118,6 +120,8 @@ class RectangleTree
   std::vector<size_t> points;
   //! The local dataset
   MatType* localDataset;
+  //! The class that performs the split of the node.
+  SplitType<RectangleTree> split;
 
  public:
   //! A single traverser for rectangle type trees.  See
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 2c747ca..cca78de 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -20,7 +20,7 @@ namespace tree {
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 RectangleTree(const MatType& data,
@@ -50,6 +50,8 @@ RectangleTree(const MatType& data,
 {
   stat = StatisticType(*this);
 
+  split = SplitType<RectangleTree>(this);
+
   // For now, just insert the points in order.
   RectangleTree* root = this;
 
@@ -60,7 +62,7 @@ RectangleTree(const MatType& data,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 RectangleTree(MatType&& data,
@@ -90,6 +92,8 @@ RectangleTree(MatType&& data,
 {
   stat = StatisticType(*this);
 
+  split = SplitType<RectangleTree>(this);
+
   // For now, just insert the points in order.
   RectangleTree* root = this;
 
@@ -100,7 +104,7 @@ RectangleTree(MatType&& data,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 RectangleTree(
@@ -126,6 +130,7 @@ RectangleTree(
                                                   maxLeafSize + 1)))
 {
   stat = StatisticType(*this);
+  split = SplitType<RectangleTree>(this);
 }
 
 /**
@@ -135,7 +140,7 @@ RectangleTree(
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 RectangleTree(
@@ -159,6 +164,7 @@ RectangleTree(
     points(other.Points()),
     localDataset(NULL)
 {
+  split = SplitType<RectangleTree>(this);
   if (deepCopy)
   {
     if (numChildren > 0)
@@ -187,7 +193,7 @@ RectangleTree(
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename Archive>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
@@ -208,7 +214,7 @@ RectangleTree(
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 ~RectangleTree()
@@ -229,7 +235,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     SoftDelete()
@@ -249,7 +255,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     NullifyData()
@@ -264,7 +270,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     InsertPoint(const size_t point)
@@ -301,7 +307,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     InsertPoint(const size_t point, std::vector<bool>& relevels)
@@ -336,7 +342,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     InsertNode(RectangleTree* node,
@@ -365,7 +371,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     DeletePoint(const size_t point)
@@ -410,7 +416,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     DeletePoint(const size_t point, std::vector<bool>& relevels)
@@ -445,7 +451,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     RemoveNode(const RectangleTree* node, std::vector<bool>& relevels)
@@ -474,7 +480,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                      DescentType>::TreeSize() const
@@ -489,7 +495,7 @@ size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                      DescentType>::TreeDepth() const
@@ -509,7 +515,7 @@ size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 inline bool RectangleTree<MetricType, StatisticType, MatType, SplitType,
                           DescentType>::IsLeaf() const
@@ -524,7 +530,7 @@ inline bool RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 inline
 typename RectangleTree<MetricType, StatisticType, MatType, SplitType,
@@ -549,7 +555,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 inline
 typename RectangleTree<MetricType, StatisticType, MatType, SplitType,
@@ -568,7 +574,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                             DescentType>::NumPoints() const
@@ -585,7 +591,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                             DescentType>::NumDescendants() const
@@ -609,7 +615,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                             DescentType>::Descendant(const size_t index) const
@@ -641,7 +647,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                             DescentType>::Point(const size_t index) const
@@ -656,7 +662,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     SplitNode(std::vector<bool>& relevels)
@@ -669,7 +675,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 
     // If we are full, then we need to split (or at least try).  The SplitType
     // takes care of this and of moving up the tree if necessary.
-    SplitType::SplitLeafNode(this, relevels);
+    split.SplitLeafNode(relevels);
   }
   else
   {
@@ -679,7 +685,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 
     // If we are full, then we need to split (or at least try).  The SplitType
     // takes care of this and of moving up the tree if necessary.
-    SplitType::SplitNonLeafNode(this, relevels);
+    split.SplitNonLeafNode(relevels);
   }
 }
 
@@ -687,7 +693,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 RectangleTree() :
@@ -715,7 +721,7 @@ RectangleTree() :
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     CondenseTree(const arma::vec& point,
@@ -844,7 +850,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     ShrinkBoundForPoint(const arma::vec& point)
@@ -940,7 +946,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     ShrinkBoundForBound(const bound::HRectBound<MetricType>& /* b */)
@@ -974,7 +980,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename Archive>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
diff --git a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
index 720a06e..7c4c938 100644
--- a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
@@ -19,7 +19,7 @@ namespace tree {
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename RuleType>
 class RectangleTree<MetricType, StatisticType, MatType, SplitType,
diff --git a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
index a5d9792..b67ae72 100644
--- a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
@@ -20,7 +20,7 @@ namespace tree {
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename RuleType>
 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
@@ -32,7 +32,7 @@ SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 template<typename RuleType>
 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
diff --git a/src/mlpack/core/tree/rectangle_tree/traits.hpp b/src/mlpack/core/tree/rectangle_tree/traits.hpp
index 811d9ac..49b357f 100644
--- a/src/mlpack/core/tree/rectangle_tree/traits.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/traits.hpp
@@ -21,7 +21,7 @@ namespace tree {
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
-         typename SplitType,
+         template<typename> class SplitType,
          typename DescentType>
 class TreeTraits<RectangleTree<MetricType, StatisticType, MatType, SplitType,
                                DescentType>>
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
index ef9224c..3264353 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
@@ -28,25 +28,33 @@ const double MAX_OVERLAP = 0.2;
  * nodes overflow, we split them, moving up the tree and splitting nodes
  * as necessary.
  */
+template<typename TreeType>
 class XTreeSplit
 {
  public:
+  //! Default constructor
+  XTreeSplit();
+
+  //! Construct this with specified node.
+  XTreeSplit(TreeType *node);
+
   /**
    * Split a leaf node using the algorithm described in "The R*-tree: An
    * Efficient and Robust Access method for Points and Rectangles."  If
    * necessary, this split will propagate upwards through the tree.
    */
-  template<typename TreeType>
-  static void SplitLeafNode(TreeType* tree, std::vector<bool>& relevels);
+  void SplitLeafNode(std::vector<bool>& relevels);
 
   /**
    * Split a non-leaf node using the "default" algorithm.  If this is a root
    * node, the tree increases in depth.
    */
-  template<typename TreeType>
-  static bool SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels);
+  bool SplitNonLeafNode(std::vector<bool>& relevels);
 
  private:
+  //! The node which has to be split.
+  TreeType *tree;
+
   /**
    * Class to allow for faster sorting.
    */
@@ -71,7 +79,6 @@ class XTreeSplit
   /**
    * Insert a node into another node.
    */
-  template<typename TreeType>
   static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
 };
 
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 21d3da3..685b26a 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -14,6 +14,21 @@
 namespace mlpack {
 namespace tree {
 
+template<typename TreeType>
+XTreeSplit<TreeType>::XTreeSplit() :
+    tree(NULL)
+{
+
+}
+
+template<typename TreeType>
+XTreeSplit<TreeType>::XTreeSplit(TreeType *node) :
+    tree(node)
+{
+
+}
+
+
 /**
  * We call GetPointSeeds to get the two points which will be the initial points
  * in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -21,7 +36,7 @@ namespace tree {
  * new nodes into the tree, spliting the parent if necessary.
  */
 template<typename TreeType>
-void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void XTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
 {
   // Convenience typedef.
   typedef typename TreeType::ElemType ElemType;
@@ -39,7 +54,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
     // Because this was a leaf node, numChildren must be 0.
     tree->Children()[(tree->NumChildren())++] = copy;
     assert(tree->NumChildren() == 1);
-    XTreeSplit::SplitLeafNode(copy, relevels);
+    copy->SplitNode(relevels);
     return;
   }
 
@@ -57,7 +72,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
     size_t p = tree->MaxLeafSize() * 0.3;
     if (p == 0)
     {
-      SplitLeafNode(tree, relevels);
+      tree->SplitNode(relevels);
       return;
     }
 
@@ -270,7 +285,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
   // in case, we use an assert.
   assert(par->NumChildren() <= par->MaxNumChildren() + 1);
   if (par->NumChildren() == par->MaxNumChildren() + 1)
-    SplitNonLeafNode(par, relevels);
+    par->SplitNode(relevels);
 
   assert(treeOne->Parent()->NumChildren() <=
       treeOne->Parent()->MaxNumChildren());
@@ -292,7 +307,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
  * higher up the tree because they were already updated if necessary.
  */
 template<typename TreeType>
-bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
+bool XTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
 {
   // Convenience typedef.
   typedef typename TreeType::ElemType ElemType;
@@ -309,7 +324,7 @@ bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
     tree->NumChildren() = 0;
     tree->NullifyData();
     tree->Children()[(tree->NumChildren())++] = copy;
-    XTreeSplit::SplitNonLeafNode(copy, relevels);
+    copy->SplitNode(relevels);
     return true;
   }
 
@@ -803,7 +818,7 @@ bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
 
   if (par->NumChildren() == par->MaxNumChildren() + 1)
   {
-    SplitNonLeafNode(par, relevels);
+    par->SplitNode(relevels);
   }
 
   // We have to update the children of each of these new nodes so that they
@@ -832,7 +847,7 @@ bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
  * numberOfChildren.
  */
 template<typename TreeType>
-void XTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+void XTreeSplit<TreeType>::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
   destTree->Children()[destTree->NumChildren()] = srcNode;




More information about the mlpack-git mailing list