[mlpack-git] master: R+ and R++ trees refactoring. Added a template parameter SweepType. Implemented MinimalSplitsNumberSweep. (da4b598)

gitdub at mlpack.org gitdub at mlpack.org
Thu Jun 23 15:59:29 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6147ed01bab6eadcd6a5e796e259a6afacae4662...e0fd69006b17a845f066ea4de1e205fc0922739d

>---------------------------------------------------------------

commit da4b598cc0eacec8807611824887feea1fd397ad
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Thu Jun 23 22:59:29 2016 +0300

    R+ and R++ trees refactoring.
    Added a template parameter SweepType.
    Implemented MinimalSplitsNumberSweep.


>---------------------------------------------------------------

da4b598cc0eacec8807611824887feea1fd397ad
 src/mlpack/core/tree/CMakeLists.txt                |   4 +
 src/mlpack/core/tree/rectangle_tree.hpp            |   2 +
 .../tree/rectangle_tree/minimal_coverage_sweep.hpp |  64 +++++
 .../rectangle_tree/minimal_coverage_sweep_impl.hpp | 236 +++++++++++++++++
 .../rectangle_tree/minimal_splits_number_sweep.hpp |  53 ++++
 .../minimal_splits_number_sweep_impl.hpp           |  89 +++++++
 .../r_plus_tree_descent_heuristic_impl.hpp         |   2 +-
 .../core/tree/rectangle_tree/r_plus_tree_split.hpp |  39 +--
 .../tree/rectangle_tree/r_plus_tree_split_impl.hpp | 285 +++------------------
 src/mlpack/core/tree/rectangle_tree/typedef.hpp    |   6 +-
 src/mlpack/tests/rectangle_tree_test.cpp           |  23 +-
 11 files changed, 525 insertions(+), 278 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 1cefc80..0a46a1e 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -65,6 +65,10 @@ set(SOURCES
   rectangle_tree/discrete_hilbert_value_impl.hpp
   rectangle_tree/r_plus_tree_descent_heuristic.hpp
   rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
+  rectangle_tree/minimal_coverage_sweep.hpp
+  rectangle_tree/minimal_coverage_sweep_impl.hpp
+  rectangle_tree/minimal_splits_number_sweep.hpp
+  rectangle_tree/minimal_splits_number_sweep_impl.hpp
   rectangle_tree/r_plus_tree_split.hpp
   rectangle_tree/r_plus_tree_split_impl.hpp
   rectangle_tree/r_plus_tree_split_policy.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index d418314..a31b75b 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -32,6 +32,8 @@
 #include "rectangle_tree/discrete_hilbert_value.hpp"
 #include "rectangle_tree/r_plus_tree_descent_heuristic.hpp"
 #include "rectangle_tree/r_plus_tree_split_policy.hpp"
+#include "rectangle_tree/minimal_coverage_sweep.hpp"
+#include "rectangle_tree/minimal_splits_number_sweep.hpp"
 #include "rectangle_tree/r_plus_tree_split.hpp"
 #include "rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp"
 #include "rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp"
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp
new file mode 100644
index 0000000..4e9ce5a
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp
@@ -0,0 +1,64 @@
+/**
+ * @file minimal_coverage_sweep.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
+
+namespace mlpack {
+namespace tree {
+
+constexpr double fillFactor = 0.5;
+
+template<typename SplitPolicy>
+class MinimalCoverageSweep
+{
+ private:
+  template<typename ElemType>
+  struct SortStruct
+  {
+    ElemType d;
+    int n;
+  };
+
+  template<typename ElemType>
+  static bool StructComp(const SortStruct<ElemType>& s1,
+                         const SortStruct<ElemType>& s2)
+  {
+    return s1.d < s2.d;
+  }
+
+ public:
+
+  template<typename TreeType>
+  struct SweepCost
+  {
+    typedef typename TreeType::ElemType type;
+  };
+
+  template<typename TreeType>
+  static typename TreeType::ElemType SweepNonLeafNode(size_t axis,
+      const TreeType* node, typename TreeType::ElemType& axisCut);
+
+  template<typename TreeType>
+  static typename TreeType::ElemType SweepLeafNode(size_t axis,
+      const TreeType* node, typename TreeType::ElemType& axisCut);
+
+  template<typename TreeType, typename ElemType>
+  static bool CheckNonLeafSweep(const TreeType* node, size_t cutAxis,
+      ElemType cut);
+
+  template<typename TreeType, typename ElemType>
+  static bool CheckLeafSweep(const TreeType* node, size_t cutAxis,
+      ElemType cut);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation
+#include "minimal_coverage_sweep_impl.hpp"
+
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
+
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp
new file mode 100644
index 0000000..b300cfa
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp
@@ -0,0 +1,236 @@
+/**
+ * @file minimal_coverage_sweep_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+
+#include "minimal_coverage_sweep.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitPolicy>
+template<typename TreeType>
+typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
+SweepNonLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
+
+  for (size_t i = 0; i < node->NumChildren(); i++)
+  {
+    sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi();
+    sorted[i].n = i;
+  }
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  size_t splitPointer = fillFactor * node->NumChildren();
+
+  axisCut = sorted[splitPointer - 1].d;
+
+  if (!CheckNonLeafSweep(node, axis, axisCut))
+  {
+    for (splitPointer = 1; splitPointer < sorted.size(); splitPointer++)
+    {
+      axisCut = sorted[splitPointer - 1].d;
+      if (CheckNonLeafSweep(node, axis, axisCut))
+        break;
+    }
+
+    if (splitPointer == node->NumChildren())
+      return std::numeric_limits<ElemType>::max();
+  }
+
+  std::vector<ElemType> lowerBound1(node->Bound().Dim());
+  std::vector<ElemType> highBound1(node->Bound().Dim());
+  std::vector<ElemType> lowerBound2(node->Bound().Dim());
+  std::vector<ElemType> highBound2(node->Bound().Dim());
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
+    highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
+
+    for (size_t i = 1; i < splitPointer; i++)
+    {
+      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
+        lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
+      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k])
+        highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
+    }
+
+    lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo();
+    highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi();
+
+    for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
+    {
+      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
+        lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
+      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k])
+        highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
+    }
+  }
+
+  ElemType area1 = 1.0, area2 = 1.0;
+  ElemType overlappedArea = 1.0;
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    if (lowerBound1[k] >= highBound1[k])
+    {
+      overlappedArea *= 0;
+      area1 *= 0;
+    }
+    else
+      area1 *= highBound1[k] - lowerBound1[k];
+
+    if (lowerBound2[k] >= highBound2[k])
+    {
+      overlappedArea *= 0;
+      area1 *= 0;
+    }
+    else
+      area2 *= highBound2[k] - lowerBound2[k];
+
+    if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k])
+    {
+      if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k])
+        overlappedArea *= 0;
+      else
+        overlappedArea *= std::min(highBound1[k], highBound2[k]) -
+            std::max(lowerBound1[k], lowerBound2[k]);
+    }
+  }
+
+  return area1 + area2 - overlappedArea;
+}
+
+template<typename SplitPolicy>
+template<typename TreeType>
+typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
+SweepLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->Count());
+
+  sorted.resize(node->Count());
+
+  for (size_t i = 0; i < node->NumPoints(); i++)
+  {
+    sorted[i].d = node->Dataset().col(node->Point(i))[axis];
+    sorted[i].n = i;
+  }
+
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  size_t splitPointer = fillFactor * node->Count();
+
+  axisCut = sorted[splitPointer - 1].d;
+
+  if (!CheckLeafSweep(node, axis, axisCut))
+    return std::numeric_limits<ElemType>::max();
+
+  std::vector<ElemType> lowerBound1(node->Bound().Dim());
+  std::vector<ElemType> highBound1(node->Bound().Dim());
+  std::vector<ElemType> lowerBound2(node->Bound().Dim());
+  std::vector<ElemType> highBound2(node->Bound().Dim());
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+    highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+
+    for (size_t i = 1; i < splitPointer; i++)
+    {
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k])
+        lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k])
+        highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+    }
+
+    lowerBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k];
+    highBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k];
+
+    for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
+    {
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k])
+        lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k])
+        highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+    }
+  }
+
+  ElemType area1 = 1.0, area2 = 1.0;
+  ElemType overlappedArea = 1.0;
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    area1 *= highBound1[k] - lowerBound1[k];
+    area2 *= highBound2[k] - lowerBound2[k];
+  }
+
+  return area1 + area2 - overlappedArea;
+}
+
+template<typename SplitPolicy>
+template<typename TreeType, typename ElemType>
+bool MinimalCoverageSweep<SplitPolicy>::
+CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
+{
+  size_t numTreeOneChildren = 0;
+  size_t numTreeTwoChildren = 0;
+
+  for (size_t i = 0; i < node->NumChildren(); i++)
+  {
+    TreeType* child = node->Children()[i];
+    int policy = SplitPolicy::GetSplitPolicy(child, cutAxis, cut);
+    if (policy == SplitPolicy::AssignToFirstTree)
+      numTreeOneChildren++;
+    else if (policy == SplitPolicy::AssignToSecondTree)
+      numTreeTwoChildren++;
+    else
+    {
+      numTreeOneChildren++;
+      numTreeTwoChildren++;
+    }
+  }
+
+  if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
+      numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
+    return true;
+  return false;
+}
+
+template<typename SplitPolicy>
+template<typename TreeType, typename ElemType>
+bool MinimalCoverageSweep<SplitPolicy>::
+CheckLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
+{
+  size_t numTreeOnePoints = 0;
+  size_t numTreeTwoPoints = 0;
+
+  for (size_t i = 0; i < node->NumPoints(); i++)
+  {
+    if (node->Dataset().col(node->Point(i))[cutAxis] <= cut)
+      numTreeOnePoints++;
+    else
+      numTreeTwoPoints++;
+  }
+
+  if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 &&
+      numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0)
+    return true;
+  return false;
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp
new file mode 100644
index 0000000..60cbe2d
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp
@@ -0,0 +1,53 @@
+/**
+ * @file minimal_splits_number_sweep.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitPolicy>
+class MinimalSplitsNumberSweep
+{
+ private:
+  template<typename ElemType>
+  struct SortStruct
+  {
+    ElemType d;
+    int n;
+  };
+
+  template<typename ElemType>
+  static bool StructComp(const SortStruct<ElemType>& s1,
+                         const SortStruct<ElemType>& s2)
+  {
+    return s1.d < s2.d;
+  }
+ public:
+  template<typename>
+  struct SweepCost
+  {
+    typedef size_t type;
+  };
+
+  template<typename TreeType>
+  static size_t SweepNonLeafNode(size_t axis, const TreeType* node,
+      typename TreeType::ElemType& axisCut);
+
+  template<typename TreeType>
+  static size_t SweepLeafNode(size_t axis, const TreeType* node,
+      typename TreeType::ElemType& axisCut);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation
+#include "minimal_splits_number_sweep_impl.hpp"
+
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp
new file mode 100644
index 0000000..21c0dcb
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp
@@ -0,0 +1,89 @@
+/**
+ * @file minimal_splits_number_sweep_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+
+#include "minimal_splits_number_sweep.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitPolicy>
+template<typename TreeType>
+size_t MinimalSplitsNumberSweep<SplitPolicy>::
+SweepNonLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
+
+  for (size_t i = 0; i < node->NumChildren(); i++)
+  {
+    sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi();
+    sorted[i].n = i;
+  }
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  size_t minCost = SIZE_MAX;
+
+  for (size_t i = 0; i < sorted.size(); i++)
+  {
+    size_t numTreeOneChildren = 0;
+    size_t numTreeTwoChildren = 0;
+    size_t numSplits = 0;
+
+    for (size_t j = 0; j < node->NumChildren(); j++)
+    {
+      TreeType* child = node->Children()[j];
+      int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].d);
+      if (policy == SplitPolicy::AssignToFirstTree)
+        numTreeOneChildren++;
+      else if (policy == SplitPolicy::AssignToSecondTree)
+        numTreeTwoChildren++;
+      else
+      {
+        numTreeOneChildren++;
+        numTreeTwoChildren++;
+        numSplits++;
+      }
+    }
+
+    if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
+        numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
+    {
+      size_t cost = numSplits * (std::abs(sorted.size() / 2 - i));
+      if (cost < minCost)
+      {
+        minCost = cost;
+        axisCut = sorted[i].d;
+      }
+    }
+  }
+  return minCost;
+}
+
+template<typename SplitPolicy>
+template<typename TreeType>
+size_t MinimalSplitsNumberSweep<SplitPolicy>::
+SweepLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  axisCut = (node->Bound()[axis].Lo() + node->Bound()[axis].Hi()) * 0.5;
+
+  if (node->Bound()[axis].Lo() == axisCut)
+    return SIZE_MAX;
+
+  return 0;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif  //  MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
index 265b739..3d495f9 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
@@ -64,7 +64,7 @@ ChooseDescentNode(TreeType* node, const size_t point)
     TreeType* tree = node;
     while (depth > 1)
     {
-      TreeType* child = new TreeType(node);
+      TreeType* child = new TreeType(tree);
 
       tree->Children()[tree->NumChildren()++] = child;
       tree = child;
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
index 8c2d056..5d5e273 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
@@ -10,15 +10,15 @@
 
 #include <mlpack/core.hpp>
 
-const double fillFactorFraction = 0.5;
-
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
 
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 class RPlusTreeSplit
 {
  public:
+  typedef SplitPolicyType SplitPolicy;
   /**
    * Split a leaf node using the "default" algorithm.  If necessary, this split
    * will propagate upwards through the tree.
@@ -48,45 +48,24 @@ class RPlusTreeSplit
     int n;
   };
 
-  template<typename ElemType>
-  static bool StructComp(const SortStruct<ElemType>& s1,
-                         const SortStruct<ElemType>& s2)
-  {
-    return s1.d < s2.d;
-  }
-
   template<typename TreeType>
-  static void SplitLeafNodeAlongPartition(TreeType* tree,
-      TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+  static void SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+      TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut);
 
   template<typename TreeType>
-  static void SplitNonLeafNodeAlongPartition(TreeType* tree,
-      TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+  static void SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+      TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut);
 
   template<typename TreeType>
   static void AddFakeNodes(const TreeType* tree, TreeType* emptyTree);
 
   template<typename TreeType>
-  static bool PartitionNode(const TreeType* node, size_t fillFactor,
-      size_t& minCutAxis, double& minCut);
-
-  template<typename TreeType>
-  static double SweepLeafNode(size_t axis, const TreeType* node,
-      size_t fillFactor, double& axisCut);
-
-  template<typename TreeType>
-  static double SweepNonLeafNode(size_t axis, const TreeType* node,
-      size_t fillFactor, double& axisCut);
+  static bool PartitionNode(const TreeType* node, size_t& minCutAxis,
+      typename TreeType::ElemType& minCut);
 
   template<typename TreeType>
   static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
 
-  template<typename TreeType>
-  static bool CheckNonLeafSweep(const TreeType* node,
-      size_t cutAxis, double cut);
-
-  template<typename TreeType>
-  static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, double cut);
 };
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
index e484872..9010595 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
@@ -16,9 +16,11 @@
 namespace mlpack {
 namespace tree {
 
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
 {
   if (tree->Count() == 1)
   {
@@ -55,11 +57,10 @@ void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<
     return;
   }
 
-  const size_t fillFactor = tree->MaxLeafSize() * fillFactorFraction;
   size_t cutAxis;
-  double cut;
+  typename TreeType::ElemType cut;
 
-  if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+  if ( !PartitionNode(tree, cutAxis, cut))
     return;
 
   assert(cutAxis < tree->Bound().Dim());
@@ -92,10 +93,11 @@ void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<
   tree->SoftDelete();
 }
 
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
-    std::vector<bool>& relevels)
+bool RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
 {
   // If we are splitting the root node, we need will do things differently so
   // that the constructor and other methods don't confuse the end user by giving
@@ -113,11 +115,10 @@ bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
     RPlusTreeSplit::SplitNonLeafNode(copy,relevels);
     return true;
   }
-  const size_t fillFactor = tree->MaxNumChildren() * fillFactorFraction;
   size_t cutAxis;
-  double cut;
+  typename TreeType::ElemType cut;
 
-  if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+  if ( !PartitionNode(tree, cutAxis, cut))
     return false;
 
   assert(cutAxis < tree->Bound().Dim());
@@ -153,10 +154,12 @@ bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
   return false;
 }
 
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::SplitLeafNodeAlongPartition(TreeType* tree,
-  TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+    TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut)
 {
   tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
 
@@ -180,10 +183,12 @@ void RPlusTreeSplit<SplitPolicyType>::SplitLeafNodeAlongPartition(TreeType* tree
   assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo());
 }
 
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNodeAlongPartition(TreeType* tree,
-  TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+    TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut)
 {
   tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
 
@@ -234,9 +239,10 @@ void RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNodeAlongPartition(TreeType* t
   assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
 }
 
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
 AddFakeNodes(const TreeType* tree, TreeType* emptyTree)
 {
   size_t numDescendantNodes = 1;
@@ -258,79 +264,32 @@ AddFakeNodes(const TreeType* tree, TreeType* emptyTree)
   }
 }
 
-
-template<typename SplitPolicyType>
-template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::CheckNonLeafSweep(const TreeType* node,
-    size_t cutAxis, double cut)
-{
-  size_t numTreeOneChildren = 0;
-  size_t numTreeTwoChildren = 0;
-
-  for (size_t i = 0; i < node->NumChildren(); i++)
-  {
-    TreeType* child = node->Children()[i];
-    int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut);
-    if (policy == SplitPolicyType::AssignToFirstTree)
-      numTreeOneChildren++;
-    else if (policy == SplitPolicyType::AssignToSecondTree)
-      numTreeTwoChildren++;
-    else
-    {
-      numTreeOneChildren++;
-      numTreeTwoChildren++;
-    }
-  }
-
-  if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
-      numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
-    return true;
-  return false;
-}
-
-template<typename SplitPolicyType>
-template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::CheckLeafSweep(const TreeType* node,
-    size_t cutAxis, double cut)
-{
-  size_t numTreeOnePoints = 0;
-  size_t numTreeTwoPoints = 0;
-
-  for (size_t i = 0; i < node->NumPoints(); i++)
-  {
-    if (node->Dataset().col(node->Point(i))[cutAxis] <= cut)
-      numTreeOnePoints++;
-    else
-      numTreeTwoPoints++;
-  }
-
-  if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 &&
-      numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0)
-    return true;
-  return false;
-}
-
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::PartitionNode(const TreeType* node, size_t fillFactor,
-    size_t& minCutAxis, double& minCut)
+bool RPlusTreeSplit<SplitPolicyType, SweepType>::
+PartitionNode(const TreeType* node, size_t& minCutAxis,
+    typename TreeType::ElemType& minCut)
 {
   if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) ||
       (node->Count() <= fillFactor && node->IsLeaf()))
     return false;
 
-  double minCost = std::numeric_limits<double>::max();
+  typedef typename SweepType<SplitPolicyType>::template SweepCost<TreeType>::type
+      SweepCostType;
+
+  SweepCostType minCost = std::numeric_limits<SweepCostType>::max();
   minCutAxis = node->Bound().Dim();
 
   for (size_t k = 0; k < node->Bound().Dim(); k++)
   {
-    double cut;
-    double cost;
+    typename TreeType::ElemType cut;
+    SweepCostType cost;
 
     if (node->IsLeaf())
-      cost = SweepLeafNode(k, node, fillFactor, cut);
+      cost = SweepType<SplitPolicyType>::SweepLeafNode(k, node, cut);
     else
-      cost = SweepNonLeafNode(k, node, fillFactor, cut);
+      cost = SweepType<SplitPolicyType>::SweepNonLeafNode(k, node, cut);
     
 
     if (cost < minCost)
@@ -343,172 +302,10 @@ bool RPlusTreeSplit<SplitPolicyType>::PartitionNode(const TreeType* node, size_t
   return true;
 }
 
-template<typename SplitPolicyType>
-template<typename TreeType>
-double RPlusTreeSplit<SplitPolicyType>::SweepNonLeafNode(size_t axis, const TreeType* node,
-    size_t fillFactor, double& axisCut)
-{
-  typedef typename TreeType::ElemType ElemType;
-
-  std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
-
-  for (size_t i = 0; i < node->NumChildren(); i++)
-  {
-    sorted[i].d = SplitPolicyType::Bound(node->Children()[i])[axis].Hi();
-    sorted[i].n = i;
-  }
-  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
-
-  size_t splitPointer = fillFactor;
-
-  axisCut = sorted[splitPointer - 1].d;
-
-  if (!CheckNonLeafSweep(node, axis, axisCut))
-  {
-    for (splitPointer = 1; splitPointer < node->NumChildren(); splitPointer++)
-    {
-      axisCut = sorted[splitPointer - 1].d;
-      if (CheckNonLeafSweep(node, axis, axisCut))
-        break;
-    }
-
-    if (splitPointer == node->NumChildren())
-      return std::numeric_limits<double>::max();
-  }
-
-  std::vector<ElemType> lowerBound1(node->Bound().Dim());
-  std::vector<ElemType> highBound1(node->Bound().Dim());
-  std::vector<ElemType> lowerBound2(node->Bound().Dim());
-  std::vector<ElemType> highBound2(node->Bound().Dim());
-
-  for (size_t k = 0; k < node->Bound().Dim(); k++)
-  {
-    lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
-    highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
-
-    for (size_t i = 1; i < splitPointer; i++)
-    {
-      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
-        lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
-      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k])
-        highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
-    }
-
-    lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo();
-    highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi();
-
-    for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
-    {
-      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
-        lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
-      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k])
-        highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
-    }
-  }
-
-  ElemType area1 = 1.0, area2 = 1.0;
-  ElemType overlappedArea = 1.0;
-
-  for (size_t k = 0; k < node->Bound().Dim(); k++)
-  {
-    if (lowerBound1[k] >= highBound1[k])
-    {
-      overlappedArea *= 0;
-      area1 *= 0;
-    }
-    else
-      area1 *= highBound1[k] - lowerBound1[k];
-
-    if (lowerBound2[k] >= highBound2[k])
-    {
-      overlappedArea *= 0;
-      area1 *= 0;
-    }
-    else
-      area2 *= highBound2[k] - lowerBound2[k];
-
-    if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k])
-    {
-      if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k])
-        overlappedArea *= 0;
-      else
-        overlappedArea *= std::min(highBound1[k], highBound2[k]) -
-            std::max(lowerBound1[k], lowerBound2[k]);
-    }
-  }
-
-  return area1 + area2 - overlappedArea;
-}
-
-template<typename SplitPolicyType>
-template<typename TreeType>
-double RPlusTreeSplit<SplitPolicyType>::SweepLeafNode(size_t axis, const TreeType* node,
-    size_t fillFactor, double& axisCut)
-{
-  typedef typename TreeType::ElemType ElemType;
-
-  std::vector<SortStruct<ElemType>> sorted(node->Count());
-
-  sorted.resize(node->Count());
-
-  for (size_t i = 0; i < node->NumPoints(); i++)
-  {
-    sorted[i].d = node->Dataset().col(node->Point(i))[axis];
-    sorted[i].n = i;
-  }
-
-  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
-
-  axisCut = sorted[fillFactor - 1].d;
-
-  if (!CheckLeafSweep(node, axis, axisCut))
-    return std::numeric_limits<double>::max();
-
-  std::vector<ElemType> lowerBound1(node->Bound().Dim());
-  std::vector<ElemType> highBound1(node->Bound().Dim());
-  std::vector<ElemType> lowerBound2(node->Bound().Dim());
-  std::vector<ElemType> highBound2(node->Bound().Dim());
-
-  for (size_t k = 0; k < node->Bound().Dim(); k++)
-  {
-    lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
-    highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
-
-    for (size_t i = 1; i < fillFactor; i++)
-    {
-      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k])
-        lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
-      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k])
-        highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
-    }
-
-    lowerBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
-    highBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
-
-    for (size_t i = fillFactor + 1; i < node->NumChildren(); i++)
-    {
-      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k])
-        lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
-      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k])
-        highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
-    }
-  }
-
-  ElemType area1 = 1.0, area2 = 1.0;
-  ElemType overlappedArea = 1.0;
-
-  for (size_t k = 0; k < node->Bound().Dim(); k++)
-  {
-    area1 *= highBound1[k] - lowerBound1[k];
-    area2 *= highBound2[k] - lowerBound2[k];
-  }
-
-  return area1 + area2 - overlappedArea;
-}
-
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+         template<typename> class SweepType>
 template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
 InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
 {
   destTree->Bound() |= srcNode->Bound();
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index b22c3b1..6371d4d 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -129,7 +129,8 @@ template<typename MetricType, typename StatisticType, typename MatType>
 using RPlusTree = RectangleTree<MetricType,
                             StatisticType,
                             MatType,
-                            RPlusTreeSplit<RPlusTreeSplitPolicy>,
+                            RPlusTreeSplit<RPlusTreeSplitPolicy,
+                                           MinimalCoverageSweep>,
                             RPlusTreeDescentHeuristic,
                             NoAuxiliaryInformation>;
 
@@ -137,7 +138,8 @@ template<typename MetricType, typename StatisticType, typename MatType>
 using RPlusPlusTree = RectangleTree<MetricType,
                             StatisticType,
                             MatType,
-                            RPlusTreeSplit<RPlusPlusTreeSplitPolicy>,
+                            RPlusTreeSplit<RPlusPlusTreeSplitPolicy,
+                                           MinimalCoverageSweep>,
                             RPlusPlusTreeDescentHeuristic,
                             RPlusPlusTreeAuxiliaryInformation>;
 } // namespace tree
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 3568a78..0aedb5c 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -137,7 +137,17 @@ void CheckContainment(const TreeType& tree)
     for (size_t i = 0; i < tree.NumChildren(); i++)
     {
       for (size_t j = 0; j < tree.Bound().Dim(); j++)
-        BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]));
+      {
+        //  All children should be covered by the parent node.
+        //  Some children can be empty (only in case of the R++ tree)
+        bool success = (tree.Children()[i]->Bound()[j].Hi() ==
+                std::numeric_limits<typename TreeType::ElemType>::lowest() &&
+                tree.Children()[i]->Bound()[j].Lo() ==
+                std::numeric_limits<typename TreeType::ElemType>::max()) ||
+            tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]);
+
+        BOOST_REQUIRE(success);
+      }
 
       CheckContainment(*(tree.Children()[i]));
     }
@@ -921,6 +931,17 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest)
   TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
 
   CheckRPlusPlusTreeBound(&rPlusPlusTree);
+
+  typedef RectangleTree<EuclideanDistance,
+      NeighborSearchStat<NearestNeighborSort>, arma::mat,
+      RPlusTreeSplit<RPlusPlusTreeSplitPolicy, MinimalSplitsNumberSweep>,
+      RPlusPlusTreeDescentHeuristic, RPlusPlusTreeAuxiliaryInformation>
+          RPlusPlusTreeMinimalSplits;
+
+  RPlusPlusTreeMinimalSplits rPlusPlusTree2(dataset, 20, 6, 5, 2, 0);
+
+  CheckRPlusPlusTreeBound(&rPlusPlusTree2);
+
 }
 
 BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest)




More information about the mlpack-git mailing list