[mlpack-git] master: R+ and R++ trees refactoring. Added a template parameter SweepType. Implemented MinimalSplitsNumberSweep. (da4b598)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Jun 23 15:59:29 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6147ed01bab6eadcd6a5e796e259a6afacae4662...e0fd69006b17a845f066ea4de1e205fc0922739d
>---------------------------------------------------------------
commit da4b598cc0eacec8807611824887feea1fd397ad
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Thu Jun 23 22:59:29 2016 +0300
R+ and R++ trees refactoring.
Added a template parameter SweepType.
Implemented MinimalSplitsNumberSweep.
>---------------------------------------------------------------
da4b598cc0eacec8807611824887feea1fd397ad
src/mlpack/core/tree/CMakeLists.txt | 4 +
src/mlpack/core/tree/rectangle_tree.hpp | 2 +
.../tree/rectangle_tree/minimal_coverage_sweep.hpp | 64 +++++
.../rectangle_tree/minimal_coverage_sweep_impl.hpp | 236 +++++++++++++++++
.../rectangle_tree/minimal_splits_number_sweep.hpp | 53 ++++
.../minimal_splits_number_sweep_impl.hpp | 89 +++++++
.../r_plus_tree_descent_heuristic_impl.hpp | 2 +-
.../core/tree/rectangle_tree/r_plus_tree_split.hpp | 39 +--
.../tree/rectangle_tree/r_plus_tree_split_impl.hpp | 285 +++------------------
src/mlpack/core/tree/rectangle_tree/typedef.hpp | 6 +-
src/mlpack/tests/rectangle_tree_test.cpp | 23 +-
11 files changed, 525 insertions(+), 278 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 1cefc80..0a46a1e 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -65,6 +65,10 @@ set(SOURCES
rectangle_tree/discrete_hilbert_value_impl.hpp
rectangle_tree/r_plus_tree_descent_heuristic.hpp
rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
+ rectangle_tree/minimal_coverage_sweep.hpp
+ rectangle_tree/minimal_coverage_sweep_impl.hpp
+ rectangle_tree/minimal_splits_number_sweep.hpp
+ rectangle_tree/minimal_splits_number_sweep_impl.hpp
rectangle_tree/r_plus_tree_split.hpp
rectangle_tree/r_plus_tree_split_impl.hpp
rectangle_tree/r_plus_tree_split_policy.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index d418314..a31b75b 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -32,6 +32,8 @@
#include "rectangle_tree/discrete_hilbert_value.hpp"
#include "rectangle_tree/r_plus_tree_descent_heuristic.hpp"
#include "rectangle_tree/r_plus_tree_split_policy.hpp"
+#include "rectangle_tree/minimal_coverage_sweep.hpp"
+#include "rectangle_tree/minimal_splits_number_sweep.hpp"
#include "rectangle_tree/r_plus_tree_split.hpp"
#include "rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp"
#include "rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp"
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp
new file mode 100644
index 0000000..4e9ce5a
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep.hpp
@@ -0,0 +1,64 @@
+/**
+ * @file minimal_coverage_sweep.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Fraction of a node's children (or points) that MinimalCoverageSweep puts on
+ * the lower side of its first candidate cut.
+ */
+constexpr double fillFactor = 0.5;
+
+/**
+ * Sweep strategy whose cost is the coverage of the two nodes a cut would
+ * produce: the sum of their volumes minus the volume of their overlap.
+ *
+ * @tparam SplitPolicy Policy providing Bound() and GetSplitPolicy(), used to
+ *     order children and to decide on which side of a cut a child falls.
+ */
+template<typename SplitPolicy>
+class MinimalCoverageSweep
+{
+ private:
+  /**
+   * Sort record that pairs a coordinate with the index of the child or point
+   * it was taken from.
+   */
+  template<typename ElemType>
+  struct SortStruct
+  {
+    //! Coordinate used as the sort key.
+    ElemType d;
+    //! Index of the child or point inside the node.  Kept as size_t because
+    //! it is filled from size_t loop counters and used as a container index.
+    size_t n;
+  };
+
+  //! Strict-weak ordering of SortStruct records by their coordinate.
+  template<typename ElemType>
+  static bool StructComp(const SortStruct<ElemType>& s1,
+                         const SortStruct<ElemType>& s2)
+  {
+    return s1.d < s2.d;
+  }
+
+ public:
+  //! Exposes the type in which this sweep reports its cost.
+  template<typename TreeType>
+  struct SweepCost
+  {
+    typedef typename TreeType::ElemType type;
+  };
+
+  /**
+   * Find a cut of a non-leaf node along the given axis.
+   *
+   * @param axis Axis along which the node is partitioned.
+   * @param node Node to be partitioned.
+   * @param axisCut Receives the coordinate of the candidate cut.
+   * @return Cost of the partition, or the maximum ElemType value if no
+   *     acceptable cut was found.
+   */
+  template<typename TreeType>
+  static typename TreeType::ElemType SweepNonLeafNode(size_t axis,
+      const TreeType* node, typename TreeType::ElemType& axisCut);
+
+  /**
+   * Find a cut of a leaf node along the given axis.
+   *
+   * @param axis Axis along which the node is partitioned.
+   * @param node Node to be partitioned.
+   * @param axisCut Receives the coordinate of the candidate cut.
+   * @return Cost of the partition, or the maximum ElemType value if the cut
+   *     is not acceptable.
+   */
+  template<typename TreeType>
+  static typename TreeType::ElemType SweepLeafNode(size_t axis,
+      const TreeType* node, typename TreeType::ElemType& axisCut);
+
+  /**
+   * Check whether cutting a non-leaf node would leave both resulting nodes
+   * non-empty and with at most MaxNumChildren() children each.
+   *
+   * @param node Node to be partitioned.
+   * @param cutAxis Axis along which the node is partitioned.
+   * @param cut Coordinate of the cut.
+   */
+  template<typename TreeType, typename ElemType>
+  static bool CheckNonLeafSweep(const TreeType* node, size_t cutAxis,
+      ElemType cut);
+
+  /**
+   * Check whether cutting a leaf node would leave both resulting nodes
+   * non-empty and with at most MaxLeafSize() points each.
+   *
+   * @param node Node to be partitioned.
+   * @param cutAxis Axis along which the node is partitioned.
+   * @param cut Coordinate of the cut.
+   */
+  template<typename TreeType, typename ElemType>
+  static bool CheckLeafSweep(const TreeType* node, size_t cutAxis,
+      ElemType cut);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation
+#include "minimal_coverage_sweep_impl.hpp"
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_HPP
+
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp
new file mode 100644
index 0000000..b300cfa
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp
@@ -0,0 +1,236 @@
+/**
+ * @file minimal_coverage_sweep_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+
+#include "minimal_coverage_sweep.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Evaluate a partition of a non-leaf node along one axis.
+ *
+ * Children are ordered by the upper edge, on that axis, of the bound reported
+ * by SplitPolicy::Bound().  The first candidate cut is taken after
+ * fillFactor * NumChildren() children; if CheckNonLeafSweep() rejects it,
+ * every other position is tried in order.  The cost is the volume of the two
+ * resulting bounds minus the volume of their overlap.
+ *
+ * @param axis Axis along which the node is partitioned.
+ * @param node Node to be partitioned.
+ * @param axisCut Receives the coordinate of the last cut that was examined.
+ * @return Cost of the partition, or the maximum ElemType value if the node
+ *     has fewer than two children or no acceptable cut exists.
+ */
+template<typename SplitPolicy>
+template<typename TreeType>
+typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
+SweepNonLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  const size_t numChildren = node->NumChildren();
+
+  // Two non-empty groups need at least two children; this also keeps
+  // sorted[splitPointer - 1] below from indexing before the first element.
+  if (numChildren < 2)
+    return std::numeric_limits<ElemType>::max();
+
+  std::vector<SortStruct<ElemType>> sorted(numChildren);
+
+  for (size_t i = 0; i < numChildren; i++)
+  {
+    sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi();
+    sorted[i].n = i;
+  }
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  size_t splitPointer = fillFactor * numChildren;
+
+  axisCut = sorted[splitPointer - 1].d;
+
+  if (!CheckNonLeafSweep(node, axis, axisCut))
+  {
+    // The default cut is not acceptable: take the first position that is.
+    for (splitPointer = 1; splitPointer < numChildren; splitPointer++)
+    {
+      axisCut = sorted[splitPointer - 1].d;
+      if (CheckNonLeafSweep(node, axis, axisCut))
+        break;
+    }
+
+    if (splitPointer == numChildren)
+      return std::numeric_limits<ElemType>::max();
+  }
+
+  const size_t dim = node->Bound().Dim();
+
+  std::vector<ElemType> lowerBound1(dim);
+  std::vector<ElemType> highBound1(dim);
+  std::vector<ElemType> lowerBound2(dim);
+  std::vector<ElemType> highBound2(dim);
+
+  // Computes the bounding box of the children sorted[begin] .. sorted[end-1].
+  auto extent = [&](const size_t begin, const size_t end,
+                    std::vector<ElemType>& lower,
+                    std::vector<ElemType>& upper)
+  {
+    for (size_t k = 0; k < dim; k++)
+    {
+      lower[k] = node->Children()[sorted[begin].n]->Bound()[k].Lo();
+      upper[k] = node->Children()[sorted[begin].n]->Bound()[k].Hi();
+
+      for (size_t i = begin + 1; i < end; i++)
+      {
+        const ElemType lo = node->Children()[sorted[i].n]->Bound()[k].Lo();
+        const ElemType hi = node->Children()[sorted[i].n]->Bound()[k].Hi();
+
+        if (lo < lower[k])
+          lower[k] = lo;
+        if (hi > upper[k])
+          upper[k] = hi;
+      }
+    }
+  };
+
+  extent(0, splitPointer, lowerBound1, highBound1);
+  extent(splitPointer, numChildren, lowerBound2, highBound2);
+
+  ElemType area1 = 1.0, area2 = 1.0;
+  ElemType overlappedArea = 1.0;
+
+  for (size_t k = 0; k < dim; k++)
+  {
+    const bool empty1 = lowerBound1[k] >= highBound1[k];
+    const bool empty2 = lowerBound2[k] >= highBound2[k];
+
+    if (empty1)
+    {
+      overlappedArea *= 0;
+      area1 *= 0;
+    }
+    else
+      area1 *= highBound1[k] - lowerBound1[k];
+
+    if (empty2)
+    {
+      overlappedArea *= 0;
+      // A degenerate second bound cancels the second volume, not the first.
+      area2 *= 0;
+    }
+    else
+      area2 *= highBound2[k] - lowerBound2[k];
+
+    if (!empty1 && !empty2)
+    {
+      // The boxes are disjoint on this axis when either one starts after the
+      // other one ends.
+      if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound1[k])
+        overlappedArea *= 0;
+      else
+        overlappedArea *= std::min(highBound1[k], highBound2[k]) -
+            std::max(lowerBound1[k], lowerBound2[k]);
+    }
+  }
+
+  return area1 + area2 - overlappedArea;
+}
+
+/**
+ * Evaluate a partition of a leaf node along one axis.
+ *
+ * Points are ordered by their coordinate on that axis and the cut is placed
+ * after fillFactor * Count() of them.  The cost is the sum of the volumes of
+ * the bounding boxes of the two groups.  Because the groups are consecutive
+ * runs of the sorted order, their boxes have no positive-width overlap along
+ * the axis, so no overlap term is subtracted.
+ *
+ * @param axis Axis along which the node is partitioned.
+ * @param node Node to be partitioned.
+ * @param axisCut Receives the coordinate of the candidate cut.
+ * @return Cost of the partition, or the maximum ElemType value if the node
+ *     holds fewer than two points or CheckLeafSweep() rejects the cut.
+ */
+template<typename SplitPolicy>
+template<typename TreeType>
+typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
+SweepLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  // Two non-empty groups need at least two points; this also keeps
+  // sorted[splitPointer - 1] below from indexing before the first element.
+  if (node->Count() < 2)
+    return std::numeric_limits<ElemType>::max();
+
+  std::vector<SortStruct<ElemType>> sorted(node->Count());
+
+  for (size_t i = 0; i < node->NumPoints(); i++)
+  {
+    sorted[i].d = node->Dataset().col(node->Point(i))[axis];
+    sorted[i].n = i;
+  }
+
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  const size_t splitPointer = fillFactor * node->Count();
+
+  axisCut = sorted[splitPointer - 1].d;
+
+  if (!CheckLeafSweep(node, axis, axisCut))
+    return std::numeric_limits<ElemType>::max();
+
+  const size_t dim = node->Bound().Dim();
+
+  std::vector<ElemType> lowerBound1(dim);
+  std::vector<ElemType> highBound1(dim);
+  std::vector<ElemType> lowerBound2(dim);
+  std::vector<ElemType> highBound2(dim);
+
+  // Coordinate k of the point stored at position i of the sorted order.
+  auto coordinate = [&](const size_t i, const size_t k) -> ElemType
+  {
+    return node->Dataset().col(node->Point(sorted[i].n))[k];
+  };
+
+  // Computes the bounding box of the points sorted[begin] .. sorted[end-1].
+  auto extent = [&](const size_t begin, const size_t end,
+                    std::vector<ElemType>& lower,
+                    std::vector<ElemType>& upper)
+  {
+    for (size_t k = 0; k < dim; k++)
+    {
+      lower[k] = coordinate(begin, k);
+      upper[k] = coordinate(begin, k);
+
+      for (size_t i = begin + 1; i < end; i++)
+      {
+        const ElemType value = coordinate(i, k);
+
+        if (value < lower[k])
+          lower[k] = value;
+        if (value > upper[k])
+          upper[k] = value;
+      }
+    }
+  };
+
+  extent(0, splitPointer, lowerBound1, highBound1);
+  // The second group runs to the end of the sorted points; the number of
+  // children is not a valid limit for a leaf sweep.
+  extent(splitPointer, sorted.size(), lowerBound2, highBound2);
+
+  ElemType area1 = 1.0, area2 = 1.0;
+
+  for (size_t k = 0; k < dim; k++)
+  {
+    area1 *= highBound1[k] - lowerBound1[k];
+    area2 *= highBound2[k] - lowerBound2[k];
+  }
+
+  return area1 + area2;
+}
+
+/**
+ * Decide whether a cut of a non-leaf node is acceptable: each of the two
+ * resulting nodes must receive at least one child and no more than
+ * MaxNumChildren() children.
+ *
+ * @param node Node to be partitioned.
+ * @param cutAxis Axis along which the node is partitioned.
+ * @param cut Coordinate of the cut.
+ */
+template<typename SplitPolicy>
+template<typename TreeType, typename ElemType>
+bool MinimalCoverageSweep<SplitPolicy>::
+CheckNonLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
+{
+  // Number of children that each resulting node would receive.
+  size_t firstCount = 0;
+  size_t secondCount = 0;
+
+  for (size_t idx = 0; idx < node->NumChildren(); ++idx)
+  {
+    TreeType* current = node->Children()[idx];
+    const int decision = SplitPolicy::GetSplitPolicy(current, cutAxis, cut);
+
+    if (decision == SplitPolicy::AssignToFirstTree)
+    {
+      ++firstCount;
+      continue;
+    }
+
+    if (decision == SplitPolicy::AssignToSecondTree)
+    {
+      ++secondCount;
+      continue;
+    }
+
+    // Any other answer counts the child on both sides.
+    ++firstCount;
+    ++secondCount;
+  }
+
+  const bool firstFits =
+      firstCount <= node->MaxNumChildren() && firstCount > 0;
+
+  return firstFits &&
+      secondCount <= node->MaxNumChildren() && secondCount > 0;
+}
+
+/**
+ * Decide whether a cut of a leaf node is acceptable: each of the two
+ * resulting nodes must receive at least one point and no more than
+ * MaxLeafSize() points.  Points whose coordinate equals the cut are counted
+ * on the first side.
+ *
+ * @param node Node to be partitioned.
+ * @param cutAxis Axis along which the node is partitioned.
+ * @param cut Coordinate of the cut.
+ */
+template<typename SplitPolicy>
+template<typename TreeType, typename ElemType>
+bool MinimalCoverageSweep<SplitPolicy>::
+CheckLeafSweep(const TreeType* node, size_t cutAxis, ElemType cut)
+{
+  // Number of points on each side of the cut.
+  size_t below = 0;
+  size_t above = 0;
+
+  for (size_t p = 0; p < node->NumPoints(); ++p)
+  {
+    const bool onFirstSide =
+        node->Dataset().col(node->Point(p))[cutAxis] <= cut;
+
+    if (onFirstSide)
+      ++below;
+    else
+      ++above;
+  }
+
+  const bool firstFits = below <= node->MaxLeafSize() && below > 0;
+
+  return firstFits && above <= node->MaxLeafSize() && above > 0;
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp
new file mode 100644
index 0000000..60cbe2d
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep.hpp
@@ -0,0 +1,53 @@
+/**
+ * @file minimal_splits_number_sweep.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Sweep strategy whose cost for a non-leaf node is the number of children a
+ * cut would split, weighted by how far the cut lies from the middle of the
+ * sorted children.  Leaf nodes are cut at the midpoint of their bound.
+ *
+ * @tparam SplitPolicy Policy providing Bound() and GetSplitPolicy(), used to
+ *     order children and to decide on which side of a cut a child falls.
+ */
+template<typename SplitPolicy>
+class MinimalSplitsNumberSweep
+{
+ private:
+  /**
+   * Sort record that pairs a coordinate with the index of the child it was
+   * taken from.
+   */
+  template<typename ElemType>
+  struct SortStruct
+  {
+    //! Coordinate used as the sort key.
+    ElemType d;
+    //! Index of the child inside the node.  Kept as size_t because it is
+    //! filled from a size_t loop counter.
+    size_t n;
+  };
+
+  //! Strict-weak ordering of SortStruct records by their coordinate.
+  template<typename ElemType>
+  static bool StructComp(const SortStruct<ElemType>& s1,
+                         const SortStruct<ElemType>& s2)
+  {
+    return s1.d < s2.d;
+  }
+
+ public:
+  //! Exposes the type in which this sweep reports its cost.
+  template<typename>
+  struct SweepCost
+  {
+    typedef size_t type;
+  };
+
+  /**
+   * Find a cut of a non-leaf node along the given axis.
+   *
+   * @param axis Axis along which the node is partitioned.
+   * @param node Node to be partitioned.
+   * @param axisCut Receives the coordinate of the best cut, if one exists.
+   * @return Cost of the best cut, or SIZE_MAX if no acceptable cut exists.
+   */
+  template<typename TreeType>
+  static size_t SweepNonLeafNode(size_t axis, const TreeType* node,
+      typename TreeType::ElemType& axisCut);
+
+  /**
+   * Find a cut of a leaf node along the given axis.
+   *
+   * @param axis Axis along which the node is partitioned.
+   * @param node Node to be partitioned.
+   * @param axisCut Receives the midpoint of the node's bound on the axis.
+   * @return 0, or SIZE_MAX if the midpoint equals the lower edge of the bound.
+   */
+  template<typename TreeType>
+  static size_t SweepLeafNode(size_t axis, const TreeType* node,
+      typename TreeType::ElemType& axisCut);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation
+#include "minimal_splits_number_sweep_impl.hpp"
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp
new file mode 100644
index 0000000..21c0dcb
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp
@@ -0,0 +1,89 @@
+/**
+ * @file minimal_splits_number_sweep_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+
+#include "minimal_splits_number_sweep.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Find the cheapest acceptable cut of a non-leaf node along one axis.
+ *
+ * Every upper edge of a child's bound (as reported by SplitPolicy::Bound())
+ * is tried as a cut.  A cut is acceptable when both resulting nodes would be
+ * non-empty and hold at most MaxNumChildren() children.  Its cost is the
+ * number of children it would split multiplied by the distance between the
+ * cut's position in the sorted order and the middle of that order.
+ *
+ * @param axis Axis along which the node is partitioned.
+ * @param node Node to be partitioned.
+ * @param axisCut Receives the coordinate of the cheapest acceptable cut; left
+ *     untouched if there is none.
+ * @return Cost of the cheapest acceptable cut, or SIZE_MAX if there is none.
+ */
+template<typename SplitPolicy>
+template<typename TreeType>
+size_t MinimalSplitsNumberSweep<SplitPolicy>::
+SweepNonLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
+
+  for (size_t i = 0; i < node->NumChildren(); i++)
+  {
+    sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi();
+    sorted[i].n = i;
+  }
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  size_t minCost = SIZE_MAX;
+  const size_t middle = sorted.size() / 2;
+
+  for (size_t i = 0; i < sorted.size(); i++)
+  {
+    size_t numTreeOneChildren = 0;
+    size_t numTreeTwoChildren = 0;
+    size_t numSplits = 0;
+
+    // Count how the children would be distributed by a cut at sorted[i].d.
+    for (size_t j = 0; j < node->NumChildren(); j++)
+    {
+      TreeType* child = node->Children()[j];
+      int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].d);
+      if (policy == SplitPolicy::AssignToFirstTree)
+        numTreeOneChildren++;
+      else if (policy == SplitPolicy::AssignToSecondTree)
+        numTreeTwoChildren++;
+      else
+      {
+        // The child has to be split and ends up on both sides.
+        numTreeOneChildren++;
+        numTreeTwoChildren++;
+        numSplits++;
+      }
+    }
+
+    if (numTreeOneChildren <= node->MaxNumChildren() &&
+        numTreeOneChildren > 0 &&
+        numTreeTwoChildren <= node->MaxNumChildren() &&
+        numTreeTwoChildren > 0)
+    {
+      // Distance from the middle of the sorted order.  Both operands are
+      // unsigned, so subtract the smaller from the larger rather than taking
+      // the absolute value of a difference that would wrap around.
+      const size_t balance = (i > middle) ? (i - middle) : (middle - i);
+      const size_t cost = numSplits * balance;
+
+      if (cost < minCost)
+      {
+        minCost = cost;
+        axisCut = sorted[i].d;
+      }
+    }
+  }
+  return minCost;
+}
+
+/**
+ * Propose the midpoint of a leaf node's bound on the given axis as the cut.
+ *
+ * @param axis Axis along which the node is partitioned.
+ * @param node Node to be partitioned.
+ * @param axisCut Receives the midpoint of the node's bound on the axis.
+ * @return 0, or SIZE_MAX if the midpoint equals the lower edge of the bound.
+ */
+template<typename SplitPolicy>
+template<typename TreeType>
+size_t MinimalSplitsNumberSweep<SplitPolicy>::
+SweepLeafNode(size_t axis, const TreeType* node,
+    typename TreeType::ElemType& axisCut)
+{
+  const typename TreeType::ElemType lowerEdge = node->Bound()[axis].Lo();
+  const typename TreeType::ElemType upperEdge = node->Bound()[axis].Hi();
+
+  axisCut = (lowerEdge + upperEdge) * 0.5;
+
+  // A midpoint equal to the lower edge is reported as the worst possible cost.
+  return (lowerEdge == axisCut) ? SIZE_MAX : 0;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
index 265b739..3d495f9 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
@@ -64,7 +64,7 @@ ChooseDescentNode(TreeType* node, const size_t point)
TreeType* tree = node;
while (depth > 1)
{
- TreeType* child = new TreeType(node);
+ TreeType* child = new TreeType(tree);
tree->Children()[tree->NumChildren()++] = child;
tree = child;
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
index 8c2d056..5d5e273 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
@@ -10,15 +10,15 @@
#include <mlpack/core.hpp>
-const double fillFactorFraction = 0.5;
-
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
class RPlusTreeSplit
{
public:
+ typedef SplitPolicyType SplitPolicy;
/**
* Split a leaf node using the "default" algorithm. If necessary, this split
* will propagate upwards through the tree.
@@ -48,45 +48,24 @@ class RPlusTreeSplit
int n;
};
- template<typename ElemType>
- static bool StructComp(const SortStruct<ElemType>& s1,
- const SortStruct<ElemType>& s2)
- {
- return s1.d < s2.d;
- }
-
template<typename TreeType>
- static void SplitLeafNodeAlongPartition(TreeType* tree,
- TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+ static void SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+ TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut);
template<typename TreeType>
- static void SplitNonLeafNodeAlongPartition(TreeType* tree,
- TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+ static void SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+ TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut);
template<typename TreeType>
static void AddFakeNodes(const TreeType* tree, TreeType* emptyTree);
template<typename TreeType>
- static bool PartitionNode(const TreeType* node, size_t fillFactor,
- size_t& minCutAxis, double& minCut);
-
- template<typename TreeType>
- static double SweepLeafNode(size_t axis, const TreeType* node,
- size_t fillFactor, double& axisCut);
-
- template<typename TreeType>
- static double SweepNonLeafNode(size_t axis, const TreeType* node,
- size_t fillFactor, double& axisCut);
+ static bool PartitionNode(const TreeType* node, size_t& minCutAxis,
+ typename TreeType::ElemType& minCut);
template<typename TreeType>
static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
- template<typename TreeType>
- static bool CheckNonLeafSweep(const TreeType* node,
- size_t cutAxis, double cut);
-
- template<typename TreeType>
- static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, double cut);
};
} // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
index e484872..9010595 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
@@ -16,9 +16,11 @@
namespace mlpack {
namespace tree {
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
{
if (tree->Count() == 1)
{
@@ -55,11 +57,10 @@ void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<
return;
}
- const size_t fillFactor = tree->MaxLeafSize() * fillFactorFraction;
size_t cutAxis;
- double cut;
+ typename TreeType::ElemType cut;
- if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+ if ( !PartitionNode(tree, cutAxis, cut))
return;
assert(cutAxis < tree->Bound().Dim());
@@ -92,10 +93,11 @@ void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<
tree->SoftDelete();
}
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
- std::vector<bool>& relevels)
+bool RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so
// that the constructor and other methods don't confuse the end user by giving
@@ -113,11 +115,10 @@ bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
RPlusTreeSplit::SplitNonLeafNode(copy,relevels);
return true;
}
- const size_t fillFactor = tree->MaxNumChildren() * fillFactorFraction;
size_t cutAxis;
- double cut;
+ typename TreeType::ElemType cut;
- if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+ if ( !PartitionNode(tree, cutAxis, cut))
return false;
assert(cutAxis < tree->Bound().Dim());
@@ -153,10 +154,12 @@ bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
return false;
}
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::SplitLeafNodeAlongPartition(TreeType* tree,
- TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+ TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut)
{
tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
@@ -180,10 +183,12 @@ void RPlusTreeSplit<SplitPolicyType>::SplitLeafNodeAlongPartition(TreeType* tree
assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo());
}
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNodeAlongPartition(TreeType* tree,
- TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitNonLeafNodeAlongPartition(TreeType* tree, TreeType* treeOne,
+ TreeType* treeTwo, size_t cutAxis, typename TreeType::ElemType cut)
{
tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
@@ -234,9 +239,10 @@ void RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNodeAlongPartition(TreeType* t
assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
}
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
AddFakeNodes(const TreeType* tree, TreeType* emptyTree)
{
size_t numDescendantNodes = 1;
@@ -258,79 +264,32 @@ AddFakeNodes(const TreeType* tree, TreeType* emptyTree)
}
}
-
-template<typename SplitPolicyType>
-template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::CheckNonLeafSweep(const TreeType* node,
- size_t cutAxis, double cut)
-{
- size_t numTreeOneChildren = 0;
- size_t numTreeTwoChildren = 0;
-
- for (size_t i = 0; i < node->NumChildren(); i++)
- {
- TreeType* child = node->Children()[i];
- int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut);
- if (policy == SplitPolicyType::AssignToFirstTree)
- numTreeOneChildren++;
- else if (policy == SplitPolicyType::AssignToSecondTree)
- numTreeTwoChildren++;
- else
- {
- numTreeOneChildren++;
- numTreeTwoChildren++;
- }
- }
-
- if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
- numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
- return true;
- return false;
-}
-
-template<typename SplitPolicyType>
-template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::CheckLeafSweep(const TreeType* node,
- size_t cutAxis, double cut)
-{
- size_t numTreeOnePoints = 0;
- size_t numTreeTwoPoints = 0;
-
- for (size_t i = 0; i < node->NumPoints(); i++)
- {
- if (node->Dataset().col(node->Point(i))[cutAxis] <= cut)
- numTreeOnePoints++;
- else
- numTreeTwoPoints++;
- }
-
- if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 &&
- numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0)
- return true;
- return false;
-}
-
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
template<typename TreeType>
-bool RPlusTreeSplit<SplitPolicyType>::PartitionNode(const TreeType* node, size_t fillFactor,
- size_t& minCutAxis, double& minCut)
+bool RPlusTreeSplit<SplitPolicyType, SweepType>::
+PartitionNode(const TreeType* node, size_t& minCutAxis,
+ typename TreeType::ElemType& minCut)
{
if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) ||
(node->Count() <= fillFactor && node->IsLeaf()))
return false;
- double minCost = std::numeric_limits<double>::max();
+ typedef typename SweepType<SplitPolicyType>::template SweepCost<TreeType>::type
+ SweepCostType;
+
+ SweepCostType minCost = std::numeric_limits<SweepCostType>::max();
minCutAxis = node->Bound().Dim();
for (size_t k = 0; k < node->Bound().Dim(); k++)
{
- double cut;
- double cost;
+ typename TreeType::ElemType cut;
+ SweepCostType cost;
if (node->IsLeaf())
- cost = SweepLeafNode(k, node, fillFactor, cut);
+ cost = SweepType<SplitPolicyType>::SweepLeafNode(k, node, cut);
else
- cost = SweepNonLeafNode(k, node, fillFactor, cut);
+ cost = SweepType<SplitPolicyType>::SweepNonLeafNode(k, node, cut);
if (cost < minCost)
@@ -343,172 +302,10 @@ bool RPlusTreeSplit<SplitPolicyType>::PartitionNode(const TreeType* node, size_t
return true;
}
-template<typename SplitPolicyType>
-template<typename TreeType>
-double RPlusTreeSplit<SplitPolicyType>::SweepNonLeafNode(size_t axis, const TreeType* node,
- size_t fillFactor, double& axisCut)
-{
- typedef typename TreeType::ElemType ElemType;
-
- std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
-
- for (size_t i = 0; i < node->NumChildren(); i++)
- {
- sorted[i].d = SplitPolicyType::Bound(node->Children()[i])[axis].Hi();
- sorted[i].n = i;
- }
- std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
-
- size_t splitPointer = fillFactor;
-
- axisCut = sorted[splitPointer - 1].d;
-
- if (!CheckNonLeafSweep(node, axis, axisCut))
- {
- for (splitPointer = 1; splitPointer < node->NumChildren(); splitPointer++)
- {
- axisCut = sorted[splitPointer - 1].d;
- if (CheckNonLeafSweep(node, axis, axisCut))
- break;
- }
-
- if (splitPointer == node->NumChildren())
- return std::numeric_limits<double>::max();
- }
-
- std::vector<ElemType> lowerBound1(node->Bound().Dim());
- std::vector<ElemType> highBound1(node->Bound().Dim());
- std::vector<ElemType> lowerBound2(node->Bound().Dim());
- std::vector<ElemType> highBound2(node->Bound().Dim());
-
- for (size_t k = 0; k < node->Bound().Dim(); k++)
- {
- lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
- highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
-
- for (size_t i = 1; i < splitPointer; i++)
- {
- if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
- lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
- if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k])
- highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
- }
-
- lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo();
- highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi();
-
- for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
- {
- if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
- lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
- if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k])
- highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
- }
- }
-
- ElemType area1 = 1.0, area2 = 1.0;
- ElemType overlappedArea = 1.0;
-
- for (size_t k = 0; k < node->Bound().Dim(); k++)
- {
- if (lowerBound1[k] >= highBound1[k])
- {
- overlappedArea *= 0;
- area1 *= 0;
- }
- else
- area1 *= highBound1[k] - lowerBound1[k];
-
- if (lowerBound2[k] >= highBound2[k])
- {
- overlappedArea *= 0;
- area1 *= 0;
- }
- else
- area2 *= highBound2[k] - lowerBound2[k];
-
- if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k])
- {
- if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k])
- overlappedArea *= 0;
- else
- overlappedArea *= std::min(highBound1[k], highBound2[k]) -
- std::max(lowerBound1[k], lowerBound2[k]);
- }
- }
-
- return area1 + area2 - overlappedArea;
-}
-
-template<typename SplitPolicyType>
-template<typename TreeType>
-double RPlusTreeSplit<SplitPolicyType>::SweepLeafNode(size_t axis, const TreeType* node,
- size_t fillFactor, double& axisCut)
-{
- typedef typename TreeType::ElemType ElemType;
-
- std::vector<SortStruct<ElemType>> sorted(node->Count());
-
- sorted.resize(node->Count());
-
- for (size_t i = 0; i < node->NumPoints(); i++)
- {
- sorted[i].d = node->Dataset().col(node->Point(i))[axis];
- sorted[i].n = i;
- }
-
- std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
-
- axisCut = sorted[fillFactor - 1].d;
-
- if (!CheckLeafSweep(node, axis, axisCut))
- return std::numeric_limits<double>::max();
-
- std::vector<ElemType> lowerBound1(node->Bound().Dim());
- std::vector<ElemType> highBound1(node->Bound().Dim());
- std::vector<ElemType> lowerBound2(node->Bound().Dim());
- std::vector<ElemType> highBound2(node->Bound().Dim());
-
- for (size_t k = 0; k < node->Bound().Dim(); k++)
- {
- lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
- highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
-
- for (size_t i = 1; i < fillFactor; i++)
- {
- if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k])
- lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
- if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k])
- highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
- }
-
- lowerBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
- highBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
-
- for (size_t i = fillFactor + 1; i < node->NumChildren(); i++)
- {
- if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k])
- lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
- if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k])
- highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
- }
- }
-
- ElemType area1 = 1.0, area2 = 1.0;
- ElemType overlappedArea = 1.0;
-
- for (size_t k = 0; k < node->Bound().Dim(); k++)
- {
- area1 *= highBound1[k] - lowerBound1[k];
- area2 *= highBound2[k] - lowerBound2[k];
- }
-
- return area1 + area2 - overlappedArea;
-}
-
-template<typename SplitPolicyType>
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
template<typename TreeType>
-void RPlusTreeSplit<SplitPolicyType>::
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index b22c3b1..6371d4d 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -129,7 +129,8 @@ template<typename MetricType, typename StatisticType, typename MatType>
using RPlusTree = RectangleTree<MetricType,
StatisticType,
MatType,
- RPlusTreeSplit<RPlusTreeSplitPolicy>,
+ RPlusTreeSplit<RPlusTreeSplitPolicy,
+ MinimalCoverageSweep>,
RPlusTreeDescentHeuristic,
NoAuxiliaryInformation>;
@@ -137,7 +138,8 @@ template<typename MetricType, typename StatisticType, typename MatType>
using RPlusPlusTree = RectangleTree<MetricType,
StatisticType,
MatType,
- RPlusTreeSplit<RPlusPlusTreeSplitPolicy>,
+ RPlusTreeSplit<RPlusPlusTreeSplitPolicy,
+ MinimalCoverageSweep>,
RPlusPlusTreeDescentHeuristic,
RPlusPlusTreeAuxiliaryInformation>;
} // namespace tree
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 3568a78..0aedb5c 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -137,7 +137,17 @@ void CheckContainment(const TreeType& tree)
for (size_t i = 0; i < tree.NumChildren(); i++)
{
for (size_t j = 0; j < tree.Bound().Dim(); j++)
- BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]));
+ {
+ // All children should be covered by the parent node.
+ // Some children can be empty (only in case of the R++ tree)
+ bool success = (tree.Children()[i]->Bound()[j].Hi() ==
+ std::numeric_limits<typename TreeType::ElemType>::lowest() &&
+ tree.Children()[i]->Bound()[j].Lo() ==
+ std::numeric_limits<typename TreeType::ElemType>::max()) ||
+ tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]);
+
+ BOOST_REQUIRE(success);
+ }
CheckContainment(*(tree.Children()[i]));
}
@@ -921,6 +931,17 @@ BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest)
TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
CheckRPlusPlusTreeBound(&rPlusPlusTree);
+
+ typedef RectangleTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>, arma::mat,
+ RPlusTreeSplit<RPlusPlusTreeSplitPolicy, MinimalSplitsNumberSweep>,
+ RPlusPlusTreeDescentHeuristic, RPlusPlusTreeAuxiliaryInformation>
+ RPlusPlusTreeMinimalSplits;
+
+ RPlusPlusTreeMinimalSplits rPlusPlusTree2(dataset, 20, 6, 5, 2, 0);
+
+ CheckRPlusPlusTreeBound(&rPlusPlusTree2);
+
}
BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest)
More information about the mlpack-git
mailing list