[mlpack-git] master: R+ tree implementation (b8da9a9)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jul 7 17:30:32 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6147ed01bab6eadcd6a5e796e259a6afacae4662...e0fd69006b17a845f066ea4de1e205fc0922739d

>---------------------------------------------------------------

commit b8da9a9c01630b5455ac31347f69865b85bce370
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Thu Jun 16 05:45:59 2016 +0300

    R+ tree implementation


>---------------------------------------------------------------

b8da9a9c01630b5455ac31347f69865b85bce370
 src/mlpack/core/tree/CMakeLists.txt                |   4 +
 src/mlpack/core/tree/rectangle_tree.hpp            |   2 +
 .../rectangle_tree/hilbert_r_tree_split_impl.hpp   |   2 +
 ...istic.hpp => r_plus_tree_descent_heuristic.hpp} |  17 +-
 .../r_plus_tree_descent_heuristic_impl.hpp         |  96 +++++
 .../core/tree/rectangle_tree/r_plus_tree_split.hpp |  95 +++++
 .../tree/rectangle_tree/r_plus_tree_split_impl.hpp | 441 +++++++++++++++++++++
 .../tree/rectangle_tree/r_star_tree_split_impl.hpp |   3 +
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp |   2 +
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    |   4 +-
 src/mlpack/core/tree/rectangle_tree/typedef.hpp    |   7 +
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp |   3 +
 src/mlpack/tests/rectangle_tree_test.cpp           |  90 +++++
 13 files changed, 756 insertions(+), 10 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 28415d5..527dd59 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -63,6 +63,10 @@ set(SOURCES
   rectangle_tree/recursive_hilbert_value_impl.hpp
   rectangle_tree/discrete_hilbert_value.hpp
   rectangle_tree/discrete_hilbert_value_impl.hpp
+  rectangle_tree/r_plus_tree_descent_heuristic.hpp
+  rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
+  rectangle_tree/r_plus_tree_split.hpp
+  rectangle_tree/r_plus_tree_split_impl.hpp
   statistic.hpp
   traversal_info.hpp
   tree_traits.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index de236ad..a28cd9f 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -30,6 +30,8 @@
 #include "rectangle_tree/hilbert_r_tree_auxiliary_information.hpp"
 #include "rectangle_tree/recursive_hilbert_value.hpp"
 #include "rectangle_tree/discrete_hilbert_value.hpp"
+#include "rectangle_tree/r_plus_tree_descent_heuristic.hpp"
+#include "rectangle_tree/r_plus_tree_split.hpp"
 #include "rectangle_tree/typedef.hpp"
 
 #endif
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
index fd39961..0d4ed5f 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
@@ -18,6 +18,8 @@ template<typename TreeType>
 void HilbertRTreeSplit::
 SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
 {
+  if (tree->Count() <= tree->MaxLeafSize())
+    return;
   // If we are splitting the root node, we need will do things differently so
   // that the constructor and other methods don't confuse the end user by giving
   // an address of another node.
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
similarity index 65%
copy from src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_descent_heuristic.hpp
copy to src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
index c83532a..dfe8e0a 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_descent_heuristic.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
@@ -1,19 +1,19 @@
 /**
- * @file hilbert_r_tree_descent_heuristic.hpp
+ * @file r_plus_tree_descent_heuristic.hpp
  * @author Mikhail Lozhnikov
  *
- * Definition of HilbertRTreeDescentHeuristic, a class that chooses the best child of a
+ * Definition of RPlusTreeDescentHeuristic, a class that chooses the best child of a
  * node in an R tree when inserting a new point.
  */
-#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_HILBERT_R_TREE_DESCENT_HEURISTIC_HPP
-#define MLPACK_CORE_TREE_RECTANGLE_TREE_HILBERT_R_TREE_DESCENT_HEURISTIC_HPP
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
 
 #include <mlpack/core.hpp>
 
 namespace mlpack {
 namespace tree {
 
-class HilbertRTreeDescentHeuristic
+class RPlusTreeDescentHeuristic
 {
  public:
   /**
@@ -25,7 +25,7 @@ class HilbertRTreeDescentHeuristic
    * @param point The number of the point that is being inserted.
    */
   template<typename TreeType>
-  static size_t ChooseDescentNode(const TreeType* node, const size_t point);
+  static size_t ChooseDescentNode(TreeType* node, const size_t point);
 
   /**
    * Evaluate the node using a heuristic. Returns the number of the node
@@ -40,9 +40,10 @@ class HilbertRTreeDescentHeuristic
                                   const TreeType* insertedNode);
 
 };
+
 } //  namespace tree
 } //  namespace mlpack
 
-#include "hilbert_r_tree_descent_heuristic_impl.hpp"
+#include "r_plus_tree_descent_heuristic_impl.hpp"
 
-#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_HILBERT_R_TREE_DESCENT_HEURISTIC_HPP
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
new file mode 100644
index 0000000..265b739
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
@@ -0,0 +1,96 @@
+/**
+ * @file r_plus_tree_descent_heuristic_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of RPlusTreeDescentHeuristic, a class that chooses the best
+ * child of a node in an R+ tree when inserting a new point.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+
+#include "r_plus_tree_descent_heuristic.hpp"
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename TreeType>
+size_t RPlusTreeDescentHeuristic::
+ChooseDescentNode(TreeType* node, const size_t point)
+{
+  typedef typename TreeType::ElemType ElemType;
+  size_t bestIndex = 0;
+  bool success = false; // No suitable child found yet.
+
+  for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+  {
+    if (node->Children()[bestIndex]->Bound().Contains(node->Dataset().col(point)))
+      return bestIndex;
+  }
+  // No child contains the point: find one that can grow without overlap.
+  for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+  {
+    bound::HRectBound<metric::EuclideanDistance, ElemType> bound =
+        node->Children()[bestIndex]->Bound();
+    bound |=  node->Dataset().col(point);
+
+    success = true;
+    // Accept the child only if the grown bound is separated from every sibling.
+    for (size_t j = 0; j < node->NumChildren(); j++)
+    {
+      if (j == bestIndex)
+        continue;
+      success = false;
+      for (size_t k = 0; k < node->Bound().Dim(); k++)
+      {
+        if (bound[k].Lo() >= node->Children()[j]->Bound()[k].Hi() ||
+            node->Children()[j]->Bound()[k].Lo() >= bound[k].Hi())
+        {
+          success = true;
+          break;
+        }
+      }
+      if (!success)
+        break;
+    }
+    if (success)
+      break;
+  }
+  // No child fits: hang a chain of new nodes under node and descend into it.
+  if (!success)
+  {
+    size_t depth = node->TreeDepth();
+
+    TreeType* tree = node;
+    while (depth > 1)
+    {
+      TreeType* child = new TreeType(tree); // Parent is the chain's tail.
+
+      tree->Children()[tree->NumChildren()++] = child;
+      tree = child;
+      depth--;
+    }
+    return node->NumChildren()-1;
+  }
+
+  assert(bestIndex < node->NumChildren());
+
+  return bestIndex;
+}
+
+template<typename TreeType>
+size_t RPlusTreeDescentHeuristic::
+ChooseDescentNode(const TreeType* node, const TreeType* insertedNode)
+{
+  size_t bestIndex = 0;
+
+  assert(false);
+
+  return bestIndex;
+}
+
+
+} //  namespace tree
+} //  namespace mlpack
+
+#endif  //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
new file mode 100644
index 0000000..f06b813
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
@@ -0,0 +1,95 @@
+/**
+ * @file r_plus_tree_split.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition of the RPlusTreeSplit class, a class that splits the nodes of an R
+ * tree, starting at a leaf node and moving upwards if necessary.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree /** Trees and tree-building procedures. */ {
+
+const double fillFactorFraction = 0.5; // Kept out of the global namespace.
+
+class RPlusTreeSplit
+{
+ public:
+  /**
+   * Split a leaf node using the "default" algorithm.  If necessary, this split
+   * will propagate upwards through the tree.
+   * @param tree The node that is being split.
+   * @param relevels Not used.
+   */
+  template<typename TreeType>
+  static void SplitLeafNode(TreeType *tree,std::vector<bool>& relevels);
+
+  /**
+   * Split a non-leaf node using the "default" algorithm.  If this is a root
+   * node, the tree increases in depth.
+   * @param tree The node that is being split.
+   * @param relevels Not used.
+   */
+  template<typename TreeType>
+  static bool SplitNonLeafNode(TreeType *tree,std::vector<bool>& relevels);
+
+
+
+ private:
+
+  template<typename ElemType>
+  struct SortStruct
+  {
+    ElemType d;
+    int n;
+  };
+
+  template<typename ElemType>
+  static bool StructComp(const SortStruct<ElemType>& s1,
+                         const SortStruct<ElemType>& s2)
+  {
+    return s1.d < s2.d;
+  }
+
+  template<typename TreeType>
+  static void SplitLeafNodeAlongPartition(TreeType* tree,
+      TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+
+  template<typename TreeType>
+  static void SplitNonLeafNodeAlongPartition(TreeType* tree,
+      TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+
+  template<typename TreeType>
+  static bool PartitionNode(const TreeType* node, size_t fillFactor,
+      size_t& minCutAxis, double& minCut);
+
+  template<typename TreeType>
+  static double SweepLeafNode(size_t axis, const TreeType* node,
+      size_t fillFactor, double& axisCut);
+
+  template<typename TreeType>
+  static double SweepNonLeafNode(size_t axis, const TreeType* node,
+      size_t fillFactor, double& axisCut);
+
+  template<typename TreeType>
+  static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
+
+  template<typename TreeType>
+  static bool CheckNonLeafSweep(const TreeType* node,
+      size_t cutAxis, double cut);
+
+  template<typename TreeType>
+  static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, double cut);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation
+#include "r_plus_tree_split_impl.hpp"
+
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP
+
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
new file mode 100644
index 0000000..fca51fc
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
@@ -0,0 +1,441 @@
+/**
+ * @file r_plus_tree_split_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of class (RPlusTreeSplit) to split a RectangleTree.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
+
+#include "r_plus_tree_split.hpp"
+#include "rectangle_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename TreeType>
+void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+{
+  if (tree->Count() == 1)
+  {
+    TreeType* node = tree->Parent();
+
+    while (node != NULL)
+    {
+      if (node->NumChildren() == node->MaxNumChildren() + 1)
+      {
+        RPlusTreeSplit::SplitNonLeafNode(node,relevels);
+        return;
+      }
+      node = node->Parent();
+    }
+    return;
+  }
+  else if (tree->Count() <= tree->MaxLeafSize())
+    return;
+  // If we are splitting the root node, we will need to do things differently so
+  // that the constructor and other methods don't confuse the end user by giving
+  // an address of another node.
+  if (tree->Parent() == NULL)
+  {
+    // We actually want to copy this way.  Pointers and everything.
+    TreeType* copy = new TreeType(*tree, false);
+    copy->Parent() = tree;
+    tree->Count() = 0;
+    tree->NullifyData();
+    // Because this was a leaf node, numChildren must be 0.
+    tree->Children()[(tree->NumChildren())++] = copy;
+    assert(tree->NumChildren() == 1);
+
+    RPlusTreeSplit::SplitLeafNode(copy,relevels);
+    return;
+  }
+
+  const size_t fillFactor = tree->MaxLeafSize() * fillFactorFraction;
+  size_t cutAxis;
+  double cut;
+
+  if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+    return;
+
+  assert(cutAxis < tree->Bound().Dim());
+
+  TreeType* treeOne = new TreeType(tree->Parent());
+  TreeType* treeTwo = new TreeType(tree->Parent());
+  treeOne->MinLeafSize() = 0;
+  treeOne->MinNumChildren() = 0;
+  treeTwo->MinLeafSize() = 0;
+  treeTwo->MinNumChildren() = 0;
+
+  SplitLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut);
+
+  TreeType* parent = tree->Parent();
+  size_t i = 0;
+  while (parent->Children()[i] != tree)
+    i++;
+
+  assert(i < parent->NumChildren());
+
+  parent->Children()[i] = parent->Children()[--parent->NumChildren()];
+
+  InsertNodeIntoTree(parent, treeOne);
+  InsertNodeIntoTree(parent, treeTwo);
+
+  assert(parent->NumChildren() <= parent->MaxNumChildren() + 1);
+  if (parent->NumChildren() == parent->MaxNumChildren() + 1)
+    RPlusTreeSplit::SplitNonLeafNode(parent, relevels);
+
+  tree->SoftDelete();
+}
+
+template<typename TreeType>
+bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree,
+    std::vector<bool>& relevels)
+{
+  // If we are splitting the root node, we will need to do things differently so
+  // that the constructor and other methods don't confuse the end user by giving
+  // an address of another node.
+  if (tree->Parent() == NULL)
+  {
+    // We actually want to copy this way.  Pointers and everything.
+    TreeType* copy = new TreeType(*tree, false);
+
+    copy->Parent() = tree;
+    tree->NumChildren() = 0;
+    tree->NullifyData();
+    tree->Children()[(tree->NumChildren())++] = copy;
+
+    RPlusTreeSplit::SplitNonLeafNode(copy,relevels);
+    return true;
+  }
+  const size_t fillFactor = tree->MaxNumChildren() * fillFactorFraction;
+  size_t cutAxis;
+  double cut;
+
+  if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+    return false;
+
+  assert(cutAxis < tree->Bound().Dim());
+
+  TreeType* treeOne = new TreeType(tree->Parent());
+  TreeType* treeTwo = new TreeType(tree->Parent());
+  treeOne->MinLeafSize() = 0;
+  treeOne->MinNumChildren() = 0;
+  treeTwo->MinLeafSize() = 0;
+  treeTwo->MinNumChildren() = 0;
+
+  SplitNonLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut);
+
+  TreeType* parent = tree->Parent();
+  size_t i = 0;
+  while (parent->Children()[i] != tree)
+    i++;
+
+  assert(i < parent->NumChildren());
+
+  parent->Children()[i] = parent->Children()[--parent->NumChildren()];
+
+  InsertNodeIntoTree(parent, treeOne);
+  InsertNodeIntoTree(parent, treeTwo);
+
+  tree->SoftDelete();
+
+  assert(parent->NumChildren() <= parent->MaxNumChildren() + 1);
+  
+  if (parent->NumChildren() == parent->MaxNumChildren() + 1)
+    RPlusTreeSplit::SplitNonLeafNode(parent, relevels);
+
+  return false;
+}
+
+template<typename TreeType>
+void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree,
+  TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+{
+  for (size_t i = 0; i < tree->NumPoints(); i++)
+  {
+    if (tree->Dataset().col(tree->Point(i))[cutAxis] <= cut)
+    {
+      treeOne->Points()[treeOne->Count()++] = tree->Point(i);
+      treeOne->Bound() |= tree->Dataset().col(tree->Point(i));
+    }
+    else
+    {
+      treeTwo->Points()[treeTwo->Count()++] = tree->Point(i);
+      treeTwo->Bound() |= tree->Dataset().col(tree->Point(i));
+    }
+  }
+  assert(treeOne->Count() <= treeOne->MaxLeafSize());
+  assert(treeTwo->Count() <= treeTwo->MaxLeafSize());
+  
+  assert(tree->Count() == treeOne->Count() + treeTwo->Count());
+  assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo());
+}
+
+template<typename TreeType>
+void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree,
+  TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+{
+  for (size_t i = 0; i < tree->NumChildren(); i++)
+  {
+    TreeType* child = tree->Children()[i];
+    if (child->Bound()[cutAxis].Hi() <= cut)
+    {
+      InsertNodeIntoTree(treeOne, child);
+      child->Parent() = treeOne;
+    }
+    else if (child->Bound()[cutAxis].Lo() >= cut)
+    {
+      InsertNodeIntoTree(treeTwo, child);
+      child->Parent() = treeTwo;
+    }
+    else
+    {
+      TreeType* childOne = new TreeType(treeOne);
+      TreeType* childTwo = new TreeType(treeTwo);
+      treeOne->MinLeafSize() = 0;
+      treeOne->MinNumChildren() = 0;
+      treeTwo->MinLeafSize() = 0;
+      treeTwo->MinNumChildren() = 0;
+
+      if (child->IsLeaf())
+        SplitLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut);
+      else
+        SplitNonLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut);
+
+      InsertNodeIntoTree(treeOne, childOne);
+      InsertNodeIntoTree(treeTwo, childTwo);
+
+      child->SoftDelete();
+    }
+  }
+  assert(treeOne->NumChildren() <= treeOne->MaxNumChildren());
+  assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
+}
+
+template<typename TreeType>
+bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
+    size_t cutAxis, double cut)
+{
+  size_t numTreeOneChildren = 0;
+  size_t numTreeTwoChildren = 0;
+
+  for (size_t i = 0; i < node->NumChildren(); i++)
+  {
+    TreeType* child = node->Children()[i];
+    if (child->Bound()[cutAxis].Hi() <= cut)
+      numTreeOneChildren++;
+    else if (child->Bound()[cutAxis].Lo() >= cut)
+      numTreeTwoChildren++;
+    else
+    {
+      numTreeOneChildren++;
+      numTreeTwoChildren++;
+    }
+  }
+
+  if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
+      numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
+    return true;
+  return false;
+}
+
+template<typename TreeType>
+bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node,
+    size_t cutAxis, double cut)
+{
+  size_t numTreeOnePoints = 0;
+  size_t numTreeTwoPoints = 0;
+
+  for (size_t i = 0; i < node->NumPoints(); i++)
+  {
+    if (node->Dataset().col(node->Point(i))[cutAxis] <= cut)
+      numTreeOnePoints++;
+    else
+      numTreeTwoPoints++;
+  }
+
+  if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 &&
+      numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0)
+    return true;
+  return false;
+}
+
+template<typename TreeType>
+bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor,
+    size_t& minCutAxis, double& minCut)
+{
+  if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) ||
+      (node->Count() <= fillFactor && node->IsLeaf()))
+    return false;
+
+  double minCost = std::numeric_limits<double>::max();
+  minCutAxis = node->Bound().Dim();
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    double cut;
+    double cost;
+
+    if (node->IsLeaf())
+      cost = SweepLeafNode(k, node, fillFactor, cut);
+    else
+      cost = SweepNonLeafNode(k, node, fillFactor, cut);
+    
+
+    if (cost < minCost)
+    {
+      minCost = cost;
+      minCutAxis = k;
+      minCut = cut;      
+    }
+  }
+  return true;
+}
+
+template<typename TreeType>
+double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
+    size_t fillFactor, double& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
+  // Order the children by the upper edge of their bound along the axis.
+  for (size_t i = 0; i < node->NumChildren(); i++)
+  {
+    sorted[i].d = node->Children()[i]->Bound()[axis].Hi();
+    sorted[i].n = i;
+  }
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+  // The candidate cut is the upper edge of the fillFactor-th child.
+  axisCut = sorted[fillFactor - 1].d;
+  // A rejected cut is reported as the maximum possible cost.
+  if (!CheckNonLeafSweep(node, axis, axisCut))
+    return std::numeric_limits<double>::max();
+  // Boxes of the first fillFactor children (1) and of the rest (2).
+  std::vector<ElemType> lowerBound1(node->Bound().Dim());
+  std::vector<ElemType> highBound1(node->Bound().Dim());
+  std::vector<ElemType> lowerBound2(node->Bound().Dim());
+  std::vector<ElemType> highBound2(node->Bound().Dim());
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
+    highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
+
+    for (size_t i = 1; i < fillFactor; i++)
+    {
+      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
+        lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
+      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k])
+        highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
+    }
+
+    lowerBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Lo();
+    highBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Hi();
+
+    for (size_t i = fillFactor + 1; i < node->NumChildren(); i++)
+    {
+      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
+        lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
+      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k])
+        highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
+    }
+  }
+
+  ElemType area1 = 1.0, area2 = 1.0;
+  ElemType overlappedArea = 1.0;
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    area1 *= highBound1[k] - lowerBound1[k];
+    area2 *= highBound2[k] - lowerBound2[k];
+    // The boxes are disjoint along k when one starts past the other's end.
+    if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound1[k])
+      overlappedArea *= 0;
+    else
+      overlappedArea *= std::min(highBound1[k], highBound2[k]) -
+          std::max(lowerBound1[k], lowerBound2[k]);
+  }
+
+  return area1 + area2 - overlappedArea;
+}
+
+template<typename TreeType>
+double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node,
+    size_t fillFactor, double& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->Count());
+
+  // Pair each point's coordinate along the axis with its index in the node,
+  // then order the pairs by that coordinate.
+  for (size_t i = 0; i < node->NumPoints(); i++)
+  {
+    sorted[i].d = node->Dataset().col(node->Point(i))[axis];
+    sorted[i].n = i;
+  }
+
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  axisCut = sorted[fillFactor - 1].d;
+
+  if (!CheckLeafSweep(node, axis, axisCut))
+    return std::numeric_limits<double>::max();
+
+  std::vector<ElemType> lowerBound1(node->Bound().Dim());
+  std::vector<ElemType> highBound1(node->Bound().Dim());
+  std::vector<ElemType> lowerBound2(node->Bound().Dim());
+  std::vector<ElemType> highBound2(node->Bound().Dim());
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+    highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+
+    for (size_t i = 1; i < fillFactor; i++)
+    {
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k])
+        lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k])
+        highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+    }
+
+    lowerBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
+    highBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
+    // Fold in the remaining points; this loop runs over points, not children.
+    for (size_t i = fillFactor + 1; i < node->NumPoints(); i++)
+    {
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k])
+        lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k])
+        highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+    }
+  }
+
+  ElemType area1 = 1.0, area2 = 1.0;
+  ElemType overlappedArea = 1.0;
+  // NOTE(review): overlappedArea is never updated here, so it stays 1.0.
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    area1 *= highBound1[k] - lowerBound1[k];
+    area2 *= highBound2[k] - lowerBound2[k];
+  }
+
+  return area1 + area2 - overlappedArea;
+}
+
+template<typename TreeType>
+void RPlusTreeSplit::
+InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+{
+  destTree->Bound() |= srcNode->Bound();
+  destTree->Children()[destTree->NumChildren()++] = srcNode;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index 0ec5c51..6b8e73d 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -27,6 +27,9 @@ void RStarTreeSplit::SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
   // Convenience typedef.
   typedef typename TreeType::ElemType ElemType;
 
+  if (tree->Count() <= tree->MaxLeafSize())
+    return;
+
   // If we are splitting the root node, we need will do things differently so
   // that the constructor and other methods don't confuse the end user by giving
   // an address of another node.
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index db67fbe..3ad3629 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -23,6 +23,8 @@ namespace tree {
 template<typename TreeType>
 void RTreeSplit::SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
 {
+  if (tree->Count() <= tree->MaxLeafSize())
+    return;
   // If we are splitting the root node, we need will do things differently so
   // that the constructor and other methods don't confuse the end user by giving
   // an address of another node.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index b087f79..c6d8ac2 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -694,8 +694,8 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   if (numChildren == 0)
   {
     // Check to see if we are full.
-    if (count <= maxLeafSize)
-      return; // We don't need to split.
+//    if (count <= maxLeafSize)
+//      return; // We don't need to split.
 
     // If we are full, then we need to split (or at least try).  The SplitType
     // takes care of this and of moving up the tree if necessary.
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index 6622099..4737d5a 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -125,6 +125,13 @@ using DiscreteHilbertRTree = RectangleTree<MetricType,
                             HilbertRTreeDescentHeuristic,
                             DiscreteHilbertRTreeAuxiliaryInformation>;
 
+template<typename MetricType, typename StatisticType, typename MatType>
+using RPlusTree = RectangleTree<MetricType,
+                            StatisticType,
+                            MatType,
+                            RPlusTreeSplit,
+                            RPlusTreeDescentHeuristic,
+                            NoAuxiliaryInformation>;
 
 } // namespace tree
 } // namespace mlpack
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 89372a4..58591d7 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -26,6 +26,9 @@ void XTreeSplit::SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
   // Convenience typedef.
   typedef typename TreeType::ElemType ElemType;
 
+  if (tree->Count() <= tree->MaxLeafSize())
+    return;
+
   // If we are splitting the root node, we need will do things differently so
   // that the constructor and other methods don't confuse the end user by giving
   // an address of another node.
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 2af8ec3..0845147 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -763,6 +763,96 @@ BOOST_AUTO_TEST_CASE(DiscreteHilbertValueTest)
   BOOST_REQUIRE_EQUAL(DiscreteHilbertValue<double>::ComparePoints(point1,point2), 1);
 }
 
+template<typename TreeType>
+void CheckOverlap(TreeType* tree)
+{
+  bool success = true;
+  // Every child must be separated from each sibling along some dimension.
+  for (size_t i = 0; i < tree->NumChildren(); i++)
+  {
+    success = true;
+
+    for (size_t j = 0; j < tree->NumChildren(); j++)
+    {
+      if (j == i)
+        continue;
+      success = false;
+      for (size_t k = 0; k < tree->Bound().Dim(); k++)
+      {
+        if (tree->Children()[i]->Bound()[k].Lo() >= tree->Children()[j]->Bound()[k].Hi() ||
+            tree->Children()[j]->Bound()[k].Lo() >= tree->Children()[i]->Bound()[k].Hi())
+        {
+          success = true;
+          break;
+        }
+      }
+      if (!success)
+        break;
+    }
+    if (!success) // Stop at the first overlapping pair, not the first clean child.
+      break;
+  }
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  for (size_t i = 0; i < tree->NumChildren(); i++)
+    CheckOverlap(tree->Children()[i]);
+}
+
+BOOST_AUTO_TEST_CASE(RPlusTreeOverlapTest)
+{
+  arma::mat dataset;
+  dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  typedef RPlusTree<EuclideanDistance,
+      NeighborSearchStat<NearestNeighborSort>,arma::mat> TreeType;
+  TreeType rPlusTree(dataset, 20, 6, 5, 2, 0);
+
+  CheckOverlap(&rPlusTree);
+}
+
+
+BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest)
+{
+  arma::mat dataset;
+
+  const int numP = 1000;
+
+  dataset.randu(8, numP); // 1000 points in 8 dimensions.
+  arma::Mat<size_t> neighbors1;
+  arma::mat distances1;
+  arma::Mat<size_t> neighbors2;
+  arma::mat distances2;
+
+  typedef RPlusTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
+      arma::mat> TreeType;
+  TreeType rPlusTree(dataset, 20, 6, 5, 2, 0);
+
+  // Nearest neighbor search with the R+ tree.
+
+  NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, arma::mat, RPlusTree >
+      knn1(&rPlusTree, true);
+
+  BOOST_REQUIRE_EQUAL(rPlusTree.NumDescendants(), numP);
+
+  CheckContainment(rPlusTree);
+  CheckExactContainment(rPlusTree);
+  CheckHierarchy(rPlusTree);
+  CheckOverlap(&rPlusTree);
+
+  knn1.Search(5, neighbors1, distances1);
+
+  // Nearest neighbor search the naive way.
+  KNN knn2(dataset, true, true);
+
+  knn2.Search(5, neighbors2, distances2);
+
+  for (size_t i = 0; i < neighbors1.size(); i++)
+  {
+    BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+    BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+  }
+}
+
 // Test the tree splitting.  We set MaxLeafSize and MaxNumChildren rather low
 // to allow us to test by hand without adding hundreds of points.
 BOOST_AUTO_TEST_CASE(RTreeSplitTest)




More information about the mlpack-git mailing list