[mlpack-git] master: R++ tree implementation (147617a)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jul 7 17:30:32 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6147ed01bab6eadcd6a5e796e259a6afacae4662...e0fd69006b17a845f066ea4de1e205fc0922739d

>---------------------------------------------------------------

commit 147617add993a5b3d037aa883412c6ec3f672bbb
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Sun Jun 19 16:34:28 2016 +0300

    R++ tree implementation


>---------------------------------------------------------------

147617add993a5b3d037aa883412c6ec3f672bbb
 src/mlpack/core/tree/CMakeLists.txt                |   6 +
 src/mlpack/core/tree/rectangle_tree.hpp            |   4 +
 .../rectangle_tree/no_auxiliary_information.hpp    |   7 ++
 .../r_plus_plus_tree_auxiliary_information.hpp     |  88 ++++++++++++++
 ...r_plus_plus_tree_auxiliary_information_impl.hpp | 131 ++++++++++++++++++++
 ....hpp => r_plus_plus_tree_descent_heuristic.hpp} |  16 +--
 .../r_plus_plus_tree_descent_heuristic_impl.hpp    |  47 ++++++++
 .../r_plus_plus_tree_split_policy.hpp              |  46 +++++++
 .../core/tree/rectangle_tree/r_plus_tree_split.hpp |   4 +
 .../tree/rectangle_tree/r_plus_tree_split_impl.hpp | 133 +++++++++++++++++----
 .../rectangle_tree/r_plus_tree_split_policy.hpp    |  46 +++++++
 src/mlpack/core/tree/rectangle_tree/typedef.hpp    |   9 +-
 src/mlpack/tests/rectangle_tree_test.cpp           | 117 +++++++++++++++++-
 13 files changed, 617 insertions(+), 37 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 527dd59..1cefc80 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -67,6 +67,12 @@ set(SOURCES
   rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
   rectangle_tree/r_plus_tree_split.hpp
   rectangle_tree/r_plus_tree_split_impl.hpp
+  rectangle_tree/r_plus_tree_split_policy.hpp
+  rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
+  rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
+  rectangle_tree/r_plus_plus_tree_split_policy.hpp
+  rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp
+  rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp
   statistic.hpp
   traversal_info.hpp
   tree_traits.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index a28cd9f..d418314 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -31,7 +31,11 @@
 #include "rectangle_tree/recursive_hilbert_value.hpp"
 #include "rectangle_tree/discrete_hilbert_value.hpp"
 #include "rectangle_tree/r_plus_tree_descent_heuristic.hpp"
+#include "rectangle_tree/r_plus_tree_split_policy.hpp"
 #include "rectangle_tree/r_plus_tree_split.hpp"
+#include "rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp"
+#include "rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp"
+#include "rectangle_tree/r_plus_plus_tree_split_policy.hpp"
 #include "rectangle_tree/typedef.hpp"
 
 #endif
diff --git a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp
index ac37908..046075d 100644
--- a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp
@@ -65,6 +65,13 @@ class NoAuxiliaryInformation
   }
 
   /**
+   * Nothing to split: the arguments are ignored and the body is empty.
+   */
+  void SplitAuxiliaryInfo(TreeType* , TreeType* , size_t ,
+      typename TreeType::ElemType)
+  { }
+
+  /**
    * Nothing to copy.
    */
   void Copy(TreeType* , TreeType* )
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp
new file mode 100644
index 0000000..926c300
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp
@@ -0,0 +1,88 @@
+/**
+ * @file r_plus_plus_tree_auxiliary_information.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition of the RPlusPlusTreeAuxiliaryInformation class,
+ * a class that provides some r++-tree specific information
+ * about the nodes.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP
+
+#include <mlpack/core.hpp>
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename TreeType>
+class RPlusPlusTreeAuxiliaryInformation
+{
+ public:
+  typedef typename TreeType::ElemType ElemType;
+  // Constructors: default, from a tree node, and copy.
+  RPlusPlusTreeAuxiliaryInformation();
+  RPlusPlusTreeAuxiliaryInformation(const TreeType* );
+  RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& );
+
+  /**
+   * Some tree types need to save some properties when a point is inserted.
+   * This implementation always returns false (insertion is not handled here).
+   */
+  bool HandlePointInsertion(TreeType* , const size_t);
+
+  /**
+   * Some tree types need to save some properties when a node is inserted.
+   * This implementation hits assert(false) and otherwise returns false.
+   */
+  bool HandleNodeInsertion(TreeType* , TreeType* ,bool);
+
+  /**
+   * Some tree types need to save some properties when a point is deleted.
+   * This implementation always returns false (deletion is not handled here).
+   */
+  bool HandlePointDeletion(TreeType* , const size_t);
+
+  /**
+   * Some tree types need to save some properties when a node is removed.
+   * This implementation always returns false (removal is not handled here).
+   */
+  bool HandleNodeRemoval(TreeType* , const size_t);
+
+
+  /**
+   * Some tree types need to propagate the information downward.
+   * This implementation always returns false (nothing is propagated).
+   */
+  bool UpdateAuxiliaryInfo(TreeType* );
+  //! Give treeOne and treeTwo this outer bound, cut at 'cut' along 'axis'.
+  void SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo,
+      size_t axis, ElemType cut);
+  //! Intended to copy the outer bound from the second node to the first.
+  static void Copy(TreeType* ,const TreeType* );
+  //! The implementation of this method is empty.
+  void NullifyData();
+
+  //! Get a modifiable reference to the outer bound.
+  bound::HRectBound<metric::EuclideanDistance, ElemType>& OuterBound()
+  { return outerBound; }
+  //! Get a read-only reference to the outer bound.
+  const bound::HRectBound<metric::EuclideanDistance, ElemType>& OuterBound() const
+  { return outerBound; }
+ private:
+  //! The outer bound of the node; SplitAuxiliaryInfo() cuts it for new nodes.
+  bound::HRectBound<metric::EuclideanDistance, ElemType> outerBound;
+ public:
+  /**
+   * Serialize the outer bound with the given archive.
+   */
+  template<typename Archive>
+  void Serialize(Archive &, const unsigned int /* version */);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#include "r_plus_plus_tree_auxiliary_information_impl.hpp"
+
+#endif//MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp
new file mode 100644
index 0000000..b6b2aea
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp
@@ -0,0 +1,131 @@
+/**
+ * @file r_plus_plus_tree_auxiliary_information_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of the RPlusPlusTreeAuxiliaryInformation class,
+ * a class that provides some r++-tree specific information
+ * about the nodes.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP
+
+#include "r_plus_plus_tree_auxiliary_information.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename  TreeType>
+RPlusPlusTreeAuxiliaryInformation<TreeType>::
+RPlusPlusTreeAuxiliaryInformation() :
+    outerBound(0)
+{
+  // Nothing else to do: outerBound is constructed from 0 above.
+}
+
+template<typename  TreeType>
+RPlusPlusTreeAuxiliaryInformation<TreeType>::
+RPlusPlusTreeAuxiliaryInformation(const TreeType* tree) :
+    outerBound(tree->Parent() ? // Take the parent's outer bound if there is one,
+               tree->Parent()->AuxiliaryInfo().OuterBound() :
+               tree->Bound().Dim()) // else build one from the bound's Dim().
+{
+  if (!tree->Parent()) // No parent: widen the outer bound to the full range.
+    for (size_t k = 0; k < outerBound.Dim(); k++)
+    {
+      outerBound[k].Lo() = std::numeric_limits<ElemType>::lowest();
+      outerBound[k].Hi() = std::numeric_limits<ElemType>::max();
+    }
+}
+
+template<typename  TreeType>
+RPlusPlusTreeAuxiliaryInformation<TreeType>::
+RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& other) :
+    outerBound(other.OuterBound())
+{
+  // Nothing else to do: the outer bound is copied from 'other' above.
+}
+
+template<typename  TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandlePointInsertion(TreeType* , const size_t )
+{
+  return false; // Point insertion is not handled here.
+}
+
+template<typename  TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandleNodeInsertion(TreeType* , TreeType* ,bool)
+{
+  assert(false); // Presumably never meant to be called; fails when asserts are on.
+  return false; // Reached only when the assert above is compiled out.
+}
+
+template<typename  TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandlePointDeletion(TreeType* , const size_t)
+{
+  return false; // Point deletion is not handled here.
+}
+
+template<typename  TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandleNodeRemoval(TreeType* , const size_t)
+{
+  return false; // Node removal is not handled here.
+}
+
+template<typename  TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+UpdateAuxiliaryInfo(TreeType* )
+{
+  return false; // Nothing is propagated downward.
+}
+
+template<typename  TreeType>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo, size_t axis,
+    typename TreeType::ElemType cut)
+{
+  typedef bound::HRectBound<metric::EuclideanDistance, ElemType> Bound;
+  Bound& treeOneBound = treeOne->AuxiliaryInfo().OuterBound();
+  Bound& treeTwoBound = treeTwo->AuxiliaryInfo().OuterBound();
+  // Both new nodes start with a copy of this object's outer bound...
+  treeOneBound = outerBound;
+  treeTwoBound = outerBound;
+  // ...then 'axis' is cut: treeOne gets Hi() = cut, treeTwo gets Lo() = cut.
+  treeOneBound[axis].Hi() = cut;
+  treeTwoBound[axis].Lo() = cut;
+}
+
+// Copy the outer bound held in src's auxiliary information into dst's.
+template<typename  TreeType>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+Copy(TreeType* dst, const TreeType* src)
+{
+  dst->AuxiliaryInfo().OuterBound() = src->AuxiliaryInfo().OuterBound();
+}
+
+template<typename  TreeType>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+NullifyData()
+{
+  // Empty body: no action is taken.
+}
+
+/**
+ * Serialize the outer bound with the given archive.
+ */
+template<typename  TreeType>
+template<typename Archive>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+  // outerBound is the only member that is archived.
+  ar & CreateNVP(outerBound, "outerBound");
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif//MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
similarity index 66%
copy from src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
copy to src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
index dfe8e0a..18166f4 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
@@ -1,19 +1,19 @@
 /**
- * @file r_plus_tree_descent_heuristic.hpp
+ * @file r_plus_plus_tree_descent_heuristic.hpp
  * @author Mikhail Lozhnikov
  *
- * Definition of RPlusTreeDescentHeuristic, a class that chooses the best child of a
- * node in an R tree when inserting a new point.
+ * Definition of RPlusPlusTreeDescentHeuristic, a class that chooses the best child of a
+ * node in an R++ tree when inserting a new point.
  */
-#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
-#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP
 
 #include <mlpack/core.hpp>
 
 namespace mlpack {
 namespace tree {
 
-class RPlusTreeDescentHeuristic
+class RPlusPlusTreeDescentHeuristic
 {
  public:
   /**
@@ -44,6 +44,6 @@ class RPlusTreeDescentHeuristic
 } //  namespace tree
 } //  namespace mlpack
 
-#include "r_plus_tree_descent_heuristic_impl.hpp"
+#include "r_plus_plus_tree_descent_heuristic_impl.hpp"
 
-#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
new file mode 100644
index 0000000..566f686
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
@@ -0,0 +1,47 @@
+/**
+ * @file r_plus_plus_tree_descent_heuristic_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of RPlusPlusTreeDescentHeuristic, a class that chooses the best child
+ * of a node in an R++ tree when inserting a new point.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+
+#include "r_plus_plus_tree_descent_heuristic.hpp"
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename TreeType>
+size_t RPlusPlusTreeDescentHeuristic::
+ChooseDescentNode(TreeType* node, const size_t point)
+{
+  for (size_t bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+  {
+    if (node->Children()[bestIndex]->AuxiliaryInfo().OuterBound().Contains(node->Dataset().col(point)))
+      return bestIndex;
+  }
+  // No child's outer bound contains the point; presumably this cannot happen.
+  assert(false);
+  // Fallback result for builds in which the assert above is compiled out.
+  return 0;
+}
+
+template<typename TreeType>
+size_t RPlusPlusTreeDescentHeuristic::
+ChooseDescentNode(const TreeType* , const TreeType* )
+{
+  size_t bestIndex = 0;
+  // Presumably never meant to be called: both arguments are ignored.
+  assert(false);
+
+  return bestIndex;
+}
+
+
+} //  namespace tree
+} //  namespace mlpack
+
+#endif  //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp
new file mode 100644
index 0000000..e0484af
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp
@@ -0,0 +1,46 @@
+/**
+ * @file r_plus_plus_tree_split_policy.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition and implementation of the RPlusPlusTreeSplitPolicy class, a class that
+ * helps to determine the node into which we should insert an intermediate node.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+
+namespace mlpack {
+namespace tree {
+
+class RPlusPlusTreeSplitPolicy
+{
+ public:
+  static const int SplitRequired = 0; // The cut falls inside the outer bound.
+  static const int AssignToFirstTree = 1; // Outer bound ends at/before the cut.
+  static const int AssignToSecondTree = 2; // Outer bound starts at/after the cut.
+  //! Classify 'child' against a cut on 'axis' using the child's outer bound.
+  template<typename TreeType>
+  static int GetSplitPolicy(const TreeType* child, size_t axis,
+      typename TreeType::ElemType cut)
+  {
+    if (child->AuxiliaryInfo().OuterBound()[axis].Hi() <= cut)
+      return AssignToFirstTree;
+    else if (child->AuxiliaryInfo().OuterBound()[axis].Lo() >= cut)
+      return AssignToSecondTree;
+    // Here Lo() < cut < Hi(): the cut crosses the child's outer bound.
+    return SplitRequired;
+  }
+  //! Return the bound this policy works with: the node's outer bound.
+  template<typename TreeType>
+  static const
+      bound::HRectBound<metric::EuclideanDistance, typename TreeType::ElemType>&
+          Bound(const TreeType* node)
+  {
+    return node->AuxiliaryInfo().OuterBound();
+  }
+};
+
+} //  namespace tree
+} //  namespace mlpack
+#endif //  MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
index f06b813..8c2d056 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
@@ -15,6 +15,7 @@ const double fillFactorFraction = 0.5;
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
 
+template<typename SplitPolicyType>
 class RPlusTreeSplit
 {
  public:
@@ -63,6 +64,9 @@ class RPlusTreeSplit
       TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
 
   template<typename TreeType>
+  static void AddFakeNodes(const TreeType* tree, TreeType* emptyTree);
+
+  template<typename TreeType>
   static bool PartitionNode(const TreeType* node, size_t fillFactor,
       size_t& minCutAxis, double& minCut);
 
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
index fca51fc..e484872 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
@@ -9,12 +9,16 @@
 
 #include "r_plus_tree_split.hpp"
 #include "rectangle_tree.hpp"
+#include "r_plus_plus_tree_auxiliary_information.hpp"
+#include "r_plus_tree_split_policy.hpp"
+#include "r_plus_plus_tree_split_policy.hpp"
 
 namespace mlpack {
 namespace tree {
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
 {
   if (tree->Count() == 1)
   {
@@ -88,8 +92,9 @@ void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
   tree->SoftDelete();
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree,
+bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
     std::vector<bool>& relevels)
 {
   // If we are splitting the root node, we need will do things differently so
@@ -148,10 +153,13 @@ bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree,
   return false;
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree,
+void RPlusTreeSplit<SplitPolicyType>::SplitLeafNodeAlongPartition(TreeType* tree,
   TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
 {
+  tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
+
   for (size_t i = 0; i < tree->NumPoints(); i++)
   {
     if (tree->Dataset().col(tree->Point(i))[cutAxis] <= cut)
@@ -172,19 +180,24 @@ void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree,
   assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo());
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree,
+void RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNodeAlongPartition(TreeType* tree,
   TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
 {
+  tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
+
   for (size_t i = 0; i < tree->NumChildren(); i++)
   {
     TreeType* child = tree->Children()[i];
-    if (child->Bound()[cutAxis].Hi() <= cut)
+    int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut);
+
+    if (policy == SplitPolicyType::AssignToFirstTree)
     {
       InsertNodeIntoTree(treeOne, child);
       child->Parent() = treeOne;
     }
-    else if (child->Bound()[cutAxis].Lo() >= cut)
+    else if (policy == SplitPolicyType::AssignToSecondTree)
     {
       InsertNodeIntoTree(treeTwo, child);
       child->Parent() = treeTwo;
@@ -209,12 +222,46 @@ void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree,
       child->SoftDelete();
     }
   }
+
+  assert(treeOne->NumChildren() + treeTwo->NumChildren() != 0);
+
+  if (treeOne->NumChildren() == 0)
+    AddFakeNodes(treeTwo, treeOne);
+  else if (treeTwo->NumChildren() == 0)
+    AddFakeNodes(treeOne, treeTwo);
+
   assert(treeOne->NumChildren() <= treeOne->MaxNumChildren());
   assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
+void RPlusTreeSplit<SplitPolicyType>::
+AddFakeNodes(const TreeType* tree, TreeType* emptyTree)
+{
+  size_t numDescendantNodes = 1;
+  // Count the nodes on the path from tree's first child down to a leaf.
+  TreeType* node = tree->Children()[0];
+
+  while (!node->IsLeaf())
+  {
+    numDescendantNodes++;
+    node = node->Children()[0];
+  }
+  // Create that many new nodes as a chain, starting from emptyTree.
+  node = emptyTree;
+  for (size_t i = 0; i < numDescendantNodes; i++)
+  {
+    TreeType* child = new TreeType(node);
+    // NOTE(review): child is not stored in node's child list here - confirm.
+    node = child;
+  }
+}
+
+
+template<typename SplitPolicyType>
+template<typename TreeType>
+bool RPlusTreeSplit<SplitPolicyType>::CheckNonLeafSweep(const TreeType* node,
     size_t cutAxis, double cut)
 {
   size_t numTreeOneChildren = 0;
@@ -223,9 +270,10 @@ bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
   for (size_t i = 0; i < node->NumChildren(); i++)
   {
     TreeType* child = node->Children()[i];
-    if (child->Bound()[cutAxis].Hi() <= cut)
+    int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut);
+    if (policy == SplitPolicyType::AssignToFirstTree)
       numTreeOneChildren++;
-    else if (child->Bound()[cutAxis].Lo() >= cut)
+    else if (policy == SplitPolicyType::AssignToSecondTree)
       numTreeTwoChildren++;
     else
     {
@@ -240,8 +288,9 @@ bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
   return false;
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node,
+bool RPlusTreeSplit<SplitPolicyType>::CheckLeafSweep(const TreeType* node,
     size_t cutAxis, double cut)
 {
   size_t numTreeOnePoints = 0;
@@ -261,8 +310,9 @@ bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node,
   return false;
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor,
+bool RPlusTreeSplit<SplitPolicyType>::PartitionNode(const TreeType* node, size_t fillFactor,
     size_t& minCutAxis, double& minCut)
 {
   if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) ||
@@ -293,8 +343,9 @@ bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor,
   return true;
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
+double RPlusTreeSplit<SplitPolicyType>::SweepNonLeafNode(size_t axis, const TreeType* node,
     size_t fillFactor, double& axisCut)
 {
   typedef typename TreeType::ElemType ElemType;
@@ -303,15 +354,27 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
 
   for (size_t i = 0; i < node->NumChildren(); i++)
   {
-    sorted[i].d = node->Children()[i]->Bound()[axis].Hi();
+    sorted[i].d = SplitPolicyType::Bound(node->Children()[i])[axis].Hi();
     sorted[i].n = i;
   }
   std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
 
-  axisCut = sorted[fillFactor - 1].d;
+  size_t splitPointer = fillFactor;
+
+  axisCut = sorted[splitPointer - 1].d;
 
   if (!CheckNonLeafSweep(node, axis, axisCut))
-    return std::numeric_limits<double>::max();
+  {
+    for (splitPointer = 1; splitPointer < node->NumChildren(); splitPointer++)
+    {
+      axisCut = sorted[splitPointer - 1].d;
+      if (CheckNonLeafSweep(node, axis, axisCut))
+        break;
+    }
+
+    if (splitPointer == node->NumChildren())
+      return std::numeric_limits<double>::max();
+  }
 
   std::vector<ElemType> lowerBound1(node->Bound().Dim());
   std::vector<ElemType> highBound1(node->Bound().Dim());
@@ -323,7 +386,7 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
     lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
     highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
 
-    for (size_t i = 1; i < fillFactor; i++)
+    for (size_t i = 1; i < splitPointer; i++)
     {
       if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
         lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
@@ -331,10 +394,10 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
         highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
     }
 
-    lowerBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Lo();
-    highBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Hi();
+    lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo();
+    highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi();
 
-    for (size_t i = fillFactor + 1; i < node->NumChildren(); i++)
+    for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
     {
       if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
         lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
@@ -348,21 +411,38 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
 
   for (size_t k = 0; k < node->Bound().Dim(); k++)
   {
-    area1 *= highBound1[k] - lowerBound1[k];
-    area2 *= highBound2[k] - lowerBound2[k];
+    if (lowerBound1[k] >= highBound1[k])
+    {
+      overlappedArea *= 0;
+      area1 *= 0;
+    }
+    else
+      area1 *= highBound1[k] - lowerBound1[k];
 
-    if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k])
+    if (lowerBound2[k] >= highBound2[k])
+    {
       overlappedArea *= 0;
+      area1 *= 0;
+    }
     else
-      overlappedArea *= std::min(highBound1[k], highBound2[k]) -
-          std::max(lowerBound1[k], lowerBound2[k]);
+      area2 *= highBound2[k] - lowerBound2[k];
+
+    if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k])
+    {
+      if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k])
+        overlappedArea *= 0;
+      else
+        overlappedArea *= std::min(highBound1[k], highBound2[k]) -
+            std::max(lowerBound1[k], lowerBound2[k]);
+    }
   }
 
   return area1 + area2 - overlappedArea;
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node,
+double RPlusTreeSplit<SplitPolicyType>::SweepLeafNode(size_t axis, const TreeType* node,
     size_t fillFactor, double& axisCut)
 {
   typedef typename TreeType::ElemType ElemType;
@@ -426,8 +506,9 @@ double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node,
   return area1 + area2 - overlappedArea;
 }
 
+template<typename SplitPolicyType>
 template<typename TreeType>
-void RPlusTreeSplit::
+void RPlusTreeSplit<SplitPolicyType>::
 InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp
new file mode 100644
index 0000000..219eb1a
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp
@@ -0,0 +1,46 @@
+/**
+ * @file r_plus_tree_split_policy.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition and implementation of the RPlusTreeSplitPolicy class, a class that
+ * helps to determine the node into which we should insert an intermediate node.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+
+namespace mlpack {
+namespace tree {
+
+class RPlusTreeSplitPolicy
+{
+ public:
+  static const int SplitRequired = 0; // The cut falls inside the child's bound.
+  static const int AssignToFirstTree = 1; // Bound ends at or before the cut.
+  static const int AssignToSecondTree = 2; // Bound starts at or after the cut.
+  //! Classify 'child' against a cut on 'axis' using the child's bound.
+  template<typename TreeType>
+  static int GetSplitPolicy(const TreeType* child, size_t axis,
+      typename TreeType::ElemType cut)
+  {
+    if (child->Bound()[axis].Hi() <= cut)
+      return AssignToFirstTree;
+    else if (child->Bound()[axis].Lo() >= cut)
+      return AssignToSecondTree;
+    // Here Lo() < cut < Hi(): the cut crosses the child's bound.
+    return SplitRequired;
+  }
+  //! Return the bound this policy works with: the node's own bound.
+  template<typename TreeType>
+  static const
+      bound::HRectBound<metric::EuclideanDistance, typename TreeType::ElemType>&
+          Bound(const TreeType* node)
+  {
+    return node->Bound();
+  }
+};
+
+} //  namespace tree
+} //  namespace mlpack
+#endif //  MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index 4737d5a..b22c3b1 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -129,10 +129,17 @@ template<typename MetricType, typename StatisticType, typename MatType>
 using RPlusTree = RectangleTree<MetricType,
                             StatisticType,
                             MatType,
-                            RPlusTreeSplit,
+                            RPlusTreeSplit<RPlusTreeSplitPolicy>,
                             RPlusTreeDescentHeuristic,
                             NoAuxiliaryInformation>;
 
+template<typename MetricType, typename StatisticType, typename MatType>
+using RPlusPlusTree = RectangleTree<MetricType,
+                            StatisticType,
+                            MatType,
+                            RPlusTreeSplit<RPlusPlusTreeSplitPolicy>,
+                            RPlusPlusTreeDescentHeuristic,
+                            RPlusPlusTreeAuxiliaryInformation>;
 } // namespace tree
 } // namespace mlpack
 
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 0845147..3568a78 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -789,10 +789,10 @@ void CheckOverlap(TreeType* tree)
       if (!success)
         break;
     }
-    if (success)
+    if (!success)
       break;
   }
-  assert(success == true);
+  BOOST_REQUIRE_EQUAL(success, true);
 
   for (size_t i = 0; i < tree->NumChildren(); i++)
     CheckOverlap(tree->Children()[i]);
@@ -853,6 +853,119 @@ BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest)
   }
 }
 
+template<typename TreeType>
+void CheckRPlusPlusTreeBound(const TreeType* tree)
+{
+  typedef bound::HRectBound<metric::EuclideanDistance,
+      typename TreeType::ElemType> Bound;
+
+  bool success = true;
+  // The node's bound must lie inside its outer bound in every dimension.
+  for (size_t k = 0; k < tree->Bound().Dim(); k++)
+  {
+    BOOST_REQUIRE_LE(tree->Bound()[k].Hi(),
+        tree->AuxiliaryInfo().OuterBound()[k].Hi());
+    BOOST_REQUIRE_LE(tree->AuxiliaryInfo().OuterBound()[k].Lo(),
+        tree->Bound()[k].Lo());
+  }
+  // In a leaf, every held point must fall inside the leaf's bound.
+  if (tree->IsLeaf())
+  {
+    for (size_t i = 0; i < tree->Count(); i++)
+      BOOST_REQUIRE_EQUAL(true,
+          tree->Bound().Contains(tree->Dataset().col(tree->Points()[i])));
+
+    return;
+  }
+  // The outer bounds of any two children must be separated on some axis.
+  for (size_t i = 0; i < tree->NumChildren(); i++)
+  {
+    const Bound& bound1 = tree->Children()[i]->AuxiliaryInfo().OuterBound();
+    success = true;
+
+    for (size_t j = 0; j < tree->NumChildren(); j++)
+    {
+      if (j == i)
+        continue;
+      const Bound& bound2 = tree->Children()[j]->AuxiliaryInfo().OuterBound();
+      // Look for an axis on which the two outer bounds do not overlap.
+      success = false;
+      for (size_t k = 0; k < tree->Bound().Dim(); k++)
+      {
+        if (bound1[k].Lo() >= bound2[k].Hi() ||
+            bound2[k].Lo() >= bound1[k].Hi())
+        {
+          success = true;
+          break;
+        }
+      }
+      if (!success)
+        break;
+    }
+    if (!success)
+      break;
+  }
+  BOOST_REQUIRE_EQUAL(success, true);
+  // Repeat the same checks for every child.
+  for (size_t i = 0; i < tree->NumChildren(); i++)
+    CheckRPlusPlusTreeBound(tree->Children()[i]);
+}
+
+BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest)
+{
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+  // Build an R++ tree on the random dataset.
+  typedef RPlusPlusTree<EuclideanDistance,
+      NeighborSearchStat<NearestNeighborSort>,arma::mat> TreeType;
+  TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
+  // Check the bound and outer-bound conditions, starting from the root.
+  CheckRPlusPlusTreeBound(&rPlusPlusTree);
+}
+
+BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest)
+{
+  arma::mat dataset;
+
+  const int numP = 1000;
+
+  dataset.randu(8, numP); // 1000 points in 8 dimensions.
+  arma::Mat<size_t> neighbors1;
+  arma::mat distances1;
+  arma::Mat<size_t> neighbors2;
+  arma::mat distances2;
+
+  typedef RPlusPlusTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
+      arma::mat> TreeType;
+  TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
+
+  // Nearest neighbor search with the R++ tree.
+
+  NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      arma::mat, RPlusPlusTree > knn1(&rPlusPlusTree, true);
+
+  BOOST_REQUIRE_EQUAL(rPlusPlusTree.NumDescendants(), numP);
+  // Structural checks on the built tree before searching.
+  CheckContainment(rPlusPlusTree);
+  CheckExactContainment(rPlusPlusTree);
+  CheckHierarchy(rPlusPlusTree);
+  CheckRPlusPlusTreeBound(&rPlusPlusTree);
+
+  knn1.Search(5, neighbors1, distances1);
+
+  // Nearest neighbor search the naive way.
+  KNN knn2(dataset, true, true);
+
+  knn2.Search(5, neighbors2, distances2);
+  // Both searches must give identical neighbors and distances.
+  for (size_t i = 0; i < neighbors1.size(); i++)
+  {
+    BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+    BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+  }
+}
+
+
 // Test the tree splitting.  We set MaxLeafSize and MaxNumChildren rather low
 // to allow us to test by hand without adding hundreds of points.
 BOOST_AUTO_TEST_CASE(RTreeSplitTest)




More information about the mlpack-git mailing list