[mlpack-git] master: R++ tree implementation (147617a)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Jul 7 17:30:32 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6147ed01bab6eadcd6a5e796e259a6afacae4662...e0fd69006b17a845f066ea4de1e205fc0922739d
>---------------------------------------------------------------
commit 147617add993a5b3d037aa883412c6ec3f672bbb
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Sun Jun 19 16:34:28 2016 +0300
R++ tree implementation
>---------------------------------------------------------------
147617add993a5b3d037aa883412c6ec3f672bbb
src/mlpack/core/tree/CMakeLists.txt | 6 +
src/mlpack/core/tree/rectangle_tree.hpp | 4 +
.../rectangle_tree/no_auxiliary_information.hpp | 7 ++
.../r_plus_plus_tree_auxiliary_information.hpp | 88 ++++++++++++++
...r_plus_plus_tree_auxiliary_information_impl.hpp | 131 ++++++++++++++++++++
....hpp => r_plus_plus_tree_descent_heuristic.hpp} | 16 +--
.../r_plus_plus_tree_descent_heuristic_impl.hpp | 47 ++++++++
.../r_plus_plus_tree_split_policy.hpp | 46 +++++++
.../core/tree/rectangle_tree/r_plus_tree_split.hpp | 4 +
.../tree/rectangle_tree/r_plus_tree_split_impl.hpp | 133 +++++++++++++++++----
.../rectangle_tree/r_plus_tree_split_policy.hpp | 46 +++++++
src/mlpack/core/tree/rectangle_tree/typedef.hpp | 9 +-
src/mlpack/tests/rectangle_tree_test.cpp | 117 +++++++++++++++++-
13 files changed, 617 insertions(+), 37 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 527dd59..1cefc80 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -67,6 +67,12 @@ set(SOURCES
rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
rectangle_tree/r_plus_tree_split.hpp
rectangle_tree/r_plus_tree_split_impl.hpp
+ rectangle_tree/r_plus_tree_split_policy.hpp
+ rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
+ rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
+ rectangle_tree/r_plus_plus_tree_split_policy.hpp
+ rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp
+ rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp
statistic.hpp
traversal_info.hpp
tree_traits.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index a28cd9f..d418314 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -31,7 +31,11 @@
#include "rectangle_tree/recursive_hilbert_value.hpp"
#include "rectangle_tree/discrete_hilbert_value.hpp"
#include "rectangle_tree/r_plus_tree_descent_heuristic.hpp"
+#include "rectangle_tree/r_plus_tree_split_policy.hpp"
#include "rectangle_tree/r_plus_tree_split.hpp"
+#include "rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp"
+#include "rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp"
+#include "rectangle_tree/r_plus_plus_tree_split_policy.hpp"
#include "rectangle_tree/typedef.hpp"
#endif
diff --git a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp
index ac37908..046075d 100644
--- a/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/no_auxiliary_information.hpp
@@ -65,6 +65,13 @@ class NoAuxiliaryInformation
}
/**
+ * Nothing to split.
+ */
+ void SplitAuxiliaryInfo(TreeType* , TreeType* , size_t ,
+ typename TreeType::ElemType)
+ { } // Intentionally empty: every argument is unnamed and ignored.
+
+ /**
* Nothing to copy.
*/
void Copy(TreeType* , TreeType* )
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp
new file mode 100644
index 0000000..926c300
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information.hpp
@@ -0,0 +1,88 @@
+/**
+ * @file r_plus_plus_tree_auxiliary_information.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition of the RPlusPlusTreeAuxiliaryInformation class,
+ * a class that provides some r++-tree specific information
+ * about the nodes.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP
+
+#include <mlpack/core.hpp>
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Auxiliary information attached to a node of an R++ tree.  The only state it
+ * holds is an "outer bound" (an HRectBound), which is kept in addition to the
+ * bound returned by the node's own Bound() method.
+ *
+ * @tparam TreeType Type of the tree node this information belongs to.
+ */
+template<typename TreeType>
+class RPlusPlusTreeAuxiliaryInformation
+{
+ public:
+ //! Element type of the outer bound, taken from the tree type.
+ typedef typename TreeType::ElemType ElemType;
+
+ //! Construct with the outer bound initialized from the value 0.
+ RPlusPlusTreeAuxiliaryInformation();
+ /**
+ * Construct the information for the given node.  If the node has a parent,
+ * the parent's outer bound is copied; otherwise the outer bound gets as many
+ * dimensions as the node's bound and, in each of them, spans from the lowest
+ * to the largest value representable by ElemType.
+ */
+ RPlusPlusTreeAuxiliaryInformation(const TreeType* );
+ //! Construct a copy that holds the same outer bound as the given object.
+ RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& );
+
+ /**
+ * Some tree types need to save some properties during the insertion process.
+ * This method should return false if it does not handle the process; this
+ * implementation always returns false.
+ */
+ bool HandlePointInsertion(TreeType* , const size_t);
+
+ /**
+ * Some tree types need to save some properties during the insertion process.
+ * This method should return false if it does not handle the process; this
+ * implementation asserts (assert(false)) and then returns false.
+ */
+ bool HandleNodeInsertion(TreeType* , TreeType* ,bool);
+
+ /**
+ * Some tree types need to save some properties during the deletion process.
+ * This method should return false if it does not handle the process; this
+ * implementation always returns false.
+ */
+ bool HandlePointDeletion(TreeType* , const size_t);
+
+ /**
+ * Some tree types need to save some properties during the deletion process.
+ * This method should return false if it does not handle the process; this
+ * implementation always returns false.
+ */
+ bool HandleNodeRemoval(TreeType* , const size_t);
+
+
+ /**
+ * Some tree types need to propagate the information downward.
+ * This method should return false if this is not the case; this
+ * implementation always returns false.
+ */
+ bool UpdateAuxiliaryInfo(TreeType* );
+
+ /**
+ * Set the outer bounds of the two nodes that result from a split.  Both
+ * nodes receive a copy of this object's outer bound; along the given axis
+ * the upper limit of treeOne's copy and the lower limit of treeTwo's copy
+ * are then set to the cut value.
+ *
+ * @param treeOne Node whose outer bound is limited from above by the cut.
+ * @param treeTwo Node whose outer bound is limited from below by the cut.
+ * @param axis Dimension along which the split is made.
+ * @param cut Coordinate of the split along that dimension.
+ */
+ void SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo,
+ size_t axis, ElemType cut);
+
+ //! Copy the outer bound from the second argument to the first one.
+ static void Copy(TreeType* ,const TreeType* );
+
+ //! Does nothing in this class (the implementation has an empty body).
+ void NullifyData();
+
+
+ //! Modify the outer bound.
+ bound::HRectBound<metric::EuclideanDistance, ElemType>& OuterBound()
+ { return outerBound; }
+
+ //! Get the outer bound.
+ const bound::HRectBound<metric::EuclideanDistance, ElemType>& OuterBound() const
+ { return outerBound; }
+ private:
+
+ //! The outer bound of the node; the only data member of this class.
+ bound::HRectBound<metric::EuclideanDistance, ElemType> outerBound;
+ public:
+ /**
+ * Serialize the information (the outer bound).
+ */
+ template<typename Archive>
+ void Serialize(Archive &, const unsigned int /* version */);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+#include "r_plus_plus_tree_auxiliary_information_impl.hpp"
+
+#endif//MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp
new file mode 100644
index 0000000..b6b2aea
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_auxiliary_information_impl.hpp
@@ -0,0 +1,131 @@
+/**
+ * @file r_plus_plus_tree_auxiliary_information_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of the RPlusPlusTreeAuxiliaryInformation class,
+ * a class that provides some r++-tree specific information
+ * about the nodes.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP
+
+#include "r_plus_plus_tree_auxiliary_information.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Default constructor: the outer bound is initialized from the value 0
+ * (presumably the dimensionality of the bound -- confirm against HRectBound);
+ * nothing else is done.
+ */
+template<typename TreeType>
+RPlusPlusTreeAuxiliaryInformation<TreeType>::
+RPlusPlusTreeAuxiliaryInformation() :
+ outerBound(0)
+{
+
+}
+
+/**
+ * Construct the auxiliary information for the given node.  A node that has a
+ * parent starts with a copy of the parent's outer bound.  A node without a
+ * parent gets an outer bound with tree->Bound().Dim() dimensions which is then
+ * widened, in every dimension, to the full range representable by ElemType.
+ *
+ * @param tree Node whose parent link and bound dimensionality are consulted.
+ */
+template<typename TreeType>
+RPlusPlusTreeAuxiliaryInformation<TreeType>::
+RPlusPlusTreeAuxiliaryInformation(const TreeType* tree) :
+    outerBound(tree->Parent() ?
+               tree->Parent()->AuxiliaryInfo().OuterBound() :
+               tree->Bound().Dim())
+{
+  // A node with a parent keeps the inherited bound untouched.
+  if (tree->Parent())
+    return;
+
+  // No parent: open the bound up as far as ElemType allows.
+  const size_t numDims = outerBound.Dim();
+  for (size_t dim = 0; dim != numDims; ++dim)
+  {
+    outerBound[dim].Hi() = std::numeric_limits<ElemType>::max();
+    outerBound[dim].Lo() = std::numeric_limits<ElemType>::lowest();
+  }
+}
+
+/**
+ * Copy constructor: the outer bound is initialized from the outer bound of
+ * the other object; nothing else is done.
+ *
+ * @param other Object whose outer bound is copied.
+ */
+template<typename TreeType>
+RPlusPlusTreeAuxiliaryInformation<TreeType>::
+RPlusPlusTreeAuxiliaryInformation(const RPlusPlusTreeAuxiliaryInformation& other) :
+ outerBound(other.OuterBound())
+{
+
+}
+
+/**
+ * Point-insertion hook.  Both arguments are ignored.
+ *
+ * @return Always false, i.e. the insertion is not handled here.
+ */
+template<typename TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandlePointInsertion(TreeType* , const size_t )
+{
+ return false;
+}
+
+/**
+ * Node-insertion hook.  All arguments are ignored.  The body asserts
+ * unconditionally, so reaching this method aborts in builds where assert()
+ * is active.
+ *
+ * @return false (only reached when assert() is compiled out).
+ */
+template<typename TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandleNodeInsertion(TreeType* , TreeType* ,bool)
+{
+ assert(false);
+ return false;
+}
+
+/**
+ * Point-deletion hook.  Both arguments are ignored.
+ *
+ * @return Always false, i.e. the deletion is not handled here.
+ */
+template<typename TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandlePointDeletion(TreeType* , const size_t)
+{
+ return false;
+}
+
+/**
+ * Node-removal hook.  Both arguments are ignored.
+ *
+ * @return Always false, i.e. the removal is not handled here.
+ */
+template<typename TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+HandleNodeRemoval(TreeType* , const size_t)
+{
+ return false;
+}
+
+/**
+ * Hook for propagating auxiliary information.  The argument is ignored.
+ *
+ * @return Always false.
+ */
+template<typename TreeType>
+bool RPlusPlusTreeAuxiliaryInformation<TreeType>::
+UpdateAuxiliaryInfo(TreeType* )
+{
+ return false;
+}
+
+/**
+ * Derive the outer bounds of the two nodes produced by a split from this
+ * object's outer bound.  Each node first receives a full copy of it; treeOne's
+ * copy is then limited from above, and treeTwo's copy from below, by the cut
+ * along the split axis.
+ *
+ * @param treeOne Node that receives the part of the region up to the cut.
+ * @param treeTwo Node that receives the part of the region from the cut on.
+ * @param axis Dimension along which the split is made.
+ * @param cut Coordinate of the split along that dimension.
+ */
+template<typename TreeType>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+SplitAuxiliaryInfo(TreeType* treeOne, TreeType* treeTwo, size_t axis,
+                   typename TreeType::ElemType cut)
+{
+  // Both halves start out with the whole region of the node being split.
+  treeOne->AuxiliaryInfo().OuterBound() = outerBound;
+  treeTwo->AuxiliaryInfo().OuterBound() = outerBound;
+
+  // The cut becomes the shared face of the two regions.
+  treeOne->AuxiliaryInfo().OuterBound()[axis].Hi() = cut;
+  treeTwo->AuxiliaryInfo().OuterBound()[axis].Lo() = cut;
+}
+
+
+/**
+ * Copy the auxiliary information (the outer bound) of one node into another.
+ *
+ * @param dst Node whose outer bound is overwritten.
+ * @param src Node whose outer bound is copied.
+ */
+template<typename TreeType>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+Copy(TreeType* dst, const TreeType* src)
+{
+  // Both arguments are pointers to tree nodes, and OuterBound() is a member of
+  // the auxiliary information rather than of the node, so the bound has to be
+  // reached through operator-> and AuxiliaryInfo(), as SplitAuxiliaryInfo()
+  // does.  The previous `dst.OuterBound() = src.OuterBound()` could not
+  // compile once this function was instantiated.
+  dst->AuxiliaryInfo().OuterBound() = src->AuxiliaryInfo().OuterBound();
+}
+
+/**
+ * Intentionally empty: this method takes no arguments and does nothing.
+ */
+template<typename TreeType>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+NullifyData()
+{
+
+}
+
+/**
+ * Serialize the auxiliary information.  The outer bound is the only member
+ * passed to the archive, under the name "outerBound".
+ *
+ * @param ar Archive the outer bound is passed to.
+ */
+template<typename TreeType>
+template<typename Archive>
+void RPlusPlusTreeAuxiliaryInformation<TreeType>::
+Serialize(Archive& ar, const unsigned int /* version */)
+{
+  ar & data::CreateNVP(outerBound, "outerBound");
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif//MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_AUXILIARY_INFORMATION_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
similarity index 66%
copy from src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
copy to src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
index dfe8e0a..18166f4 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic.hpp
@@ -1,19 +1,19 @@
/**
- * @file r_plus_tree_descent_heuristic.hpp
+ * @file r_plus_plus_tree_descent_heuristic.hpp
* @author Mikhail Lozhnikov
*
- * Definition of RPlusTreeDescentHeuristic, a class that chooses the best child of a
- * node in an R tree when inserting a new point.
+ * Definition of RPlusPlusTreeDescentHeuristic, a class that chooses the best child of a
+ * node in an R++ tree when inserting a new point.
*/
-#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
-#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP
#include <mlpack/core.hpp>
namespace mlpack {
namespace tree {
-class RPlusTreeDescentHeuristic
+class RPlusPlusTreeDescentHeuristic
{
public:
/**
@@ -44,6 +44,6 @@ class RPlusTreeDescentHeuristic
} // namespace tree
} // namespace mlpack
-#include "r_plus_tree_descent_heuristic_impl.hpp"
+#include "r_plus_plus_tree_descent_heuristic_impl.hpp"
-#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
new file mode 100644
index 0000000..566f686
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
@@ -0,0 +1,47 @@
+/**
+ * @file r_plus_plus_tree_descent_heuristic_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of RPlusPlusTreeDescentHeuristic, a class that chooses the best child
+ * of a node in an R++ tree when inserting a new point.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+
+#include "r_plus_plus_tree_descent_heuristic.hpp"
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Choose the child into which a point should descend: the first child, in
+ * index order, whose outer bound contains the point.
+ *
+ * @param node Node whose children are examined.
+ * @param point Index of the point's column in node->Dataset().
+ * @return Index of the first child whose outer bound contains the point.  If
+ *     no child qualifies, assert(false) fires; with assertions compiled out,
+ *     0 is returned.
+ */
+template<typename TreeType>
+size_t RPlusPlusTreeDescentHeuristic::
+ChooseDescentNode(TreeType* node, const size_t point)
+{
+  size_t childIndex = 0;
+  while (childIndex < node->NumChildren())
+  {
+    if (node->Children()[childIndex]->AuxiliaryInfo().OuterBound().Contains(
+        node->Dataset().col(point)))
+      return childIndex;
+
+    ++childIndex;
+  }
+
+  // No child's outer bound contains the point.
+  assert(false);
+
+  return 0;
+}
+
+/**
+ * Overload that takes two nodes.  Both arguments are ignored and the body
+ * asserts unconditionally.
+ *
+ * @return 0 (only reached when assert() is compiled out).
+ */
+template<typename TreeType>
+size_t RPlusPlusTreeDescentHeuristic::
+ChooseDescentNode(const TreeType* , const TreeType* )
+{
+  assert(false);
+
+  return 0;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp
new file mode 100644
index 0000000..e0484af
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp
@@ -0,0 +1,46 @@
+/**
+ * @file r_plus_plus_tree_split_policy.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition and implementation of the RPlusPlusTreeSplitPolicy class, a class that
+ * helps to determine the node into which we should insert an intermediate node.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Split policy that works on the outer bound stored in a node's auxiliary
+ * information: it decides on which side of a cut a child node falls and tells
+ * the caller which bound to use for a node.
+ */
+class RPlusPlusTreeSplitPolicy
+{
+ public:
+  //! Returned when neither of the other two results applies, i.e. the child's
+  //! outer bound extends to both sides of the cut.
+  static const int SplitRequired = 0;
+  //! Returned when the child's outer bound ends at or before the cut.
+  static const int AssignToFirstTree = 1;
+  //! Returned when the child's outer bound starts at or after the cut.
+  static const int AssignToSecondTree = 2;
+
+  /**
+   * Classify a child node with respect to a cut.
+   *
+   * @param child Node whose outer bound is examined.
+   * @param axis Dimension of the cut.
+   * @param cut Coordinate of the cut along that dimension.
+   * @return AssignToFirstTree if the upper limit of the outer bound along the
+   *     axis is <= cut; otherwise AssignToSecondTree if its lower limit is
+   *     >= cut; otherwise SplitRequired.
+   */
+  template<typename TreeType>
+  static int GetSplitPolicy(const TreeType* child, size_t axis,
+                            typename TreeType::ElemType cut)
+  {
+    const bound::HRectBound<metric::EuclideanDistance,
+        typename TreeType::ElemType>& outer = Bound(child);
+
+    if (outer[axis].Hi() <= cut)
+      return AssignToFirstTree;
+
+    if (outer[axis].Lo() >= cut)
+      return AssignToSecondTree;
+
+    return SplitRequired;
+  }
+
+  /**
+   * Return the bound this policy works with: the outer bound kept in the
+   * node's auxiliary information.
+   *
+   * @param node Node whose outer bound is requested.
+   */
+  template<typename TreeType>
+  static const
+      bound::HRectBound<metric::EuclideanDistance, typename TreeType::ElemType>&
+  Bound(const TreeType* node)
+  {
+    return node->AuxiliaryInfo().OuterBound();
+  }
+};
+
+} // namespace tree
+} // namespace mlpack
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
index f06b813..8c2d056 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
@@ -15,6 +15,7 @@ const double fillFactorFraction = 0.5;
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
+template<typename SplitPolicyType>
class RPlusTreeSplit
{
public:
@@ -63,6 +64,9 @@ class RPlusTreeSplit
TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
template<typename TreeType>
+ static void AddFakeNodes(const TreeType* tree, TreeType* emptyTree);
+
+ template<typename TreeType>
static bool PartitionNode(const TreeType* node, size_t fillFactor,
size_t& minCutAxis, double& minCut);
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
index fca51fc..e484872 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
@@ -9,12 +9,16 @@
#include "r_plus_tree_split.hpp"
#include "rectangle_tree.hpp"
+#include "r_plus_plus_tree_auxiliary_information.hpp"
+#include "r_plus_tree_split_policy.hpp"
+#include "r_plus_plus_tree_split_policy.hpp"
namespace mlpack {
namespace tree {
+template<typename SplitPolicyType>
template<typename TreeType>
-void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RPlusTreeSplit<SplitPolicyType>::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
{
if (tree->Count() == 1)
{
@@ -88,8 +92,9 @@ void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
tree->SoftDelete();
}
+template<typename SplitPolicyType>
template<typename TreeType>
-bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree,
+bool RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNode(TreeType* tree,
std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so
@@ -148,10 +153,13 @@ bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree,
return false;
}
+template<typename SplitPolicyType>
template<typename TreeType>
-void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree,
+void RPlusTreeSplit<SplitPolicyType>::SplitLeafNodeAlongPartition(TreeType* tree,
TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
{
+ tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
+
for (size_t i = 0; i < tree->NumPoints(); i++)
{
if (tree->Dataset().col(tree->Point(i))[cutAxis] <= cut)
@@ -172,19 +180,24 @@ void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree,
assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo());
}
+template<typename SplitPolicyType>
template<typename TreeType>
-void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree,
+void RPlusTreeSplit<SplitPolicyType>::SplitNonLeafNodeAlongPartition(TreeType* tree,
TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
{
+ tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
+
for (size_t i = 0; i < tree->NumChildren(); i++)
{
TreeType* child = tree->Children()[i];
- if (child->Bound()[cutAxis].Hi() <= cut)
+ int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut);
+
+ if (policy == SplitPolicyType::AssignToFirstTree)
{
InsertNodeIntoTree(treeOne, child);
child->Parent() = treeOne;
}
- else if (child->Bound()[cutAxis].Lo() >= cut)
+ else if (policy == SplitPolicyType::AssignToSecondTree)
{
InsertNodeIntoTree(treeTwo, child);
child->Parent() = treeTwo;
@@ -209,12 +222,46 @@ void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree,
child->SoftDelete();
}
}
+
+ assert(treeOne->NumChildren() + treeTwo->NumChildren() != 0);
+
+ if (treeOne->NumChildren() == 0)
+ AddFakeNodes(treeTwo, treeOne);
+ else if (treeTwo->NumChildren() == 0)
+ AddFakeNodes(treeOne, treeTwo);
+
assert(treeOne->NumChildren() <= treeOne->MaxNumChildren());
assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
}
+template<typename SplitPolicyType>
template<typename TreeType>
-bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
+void RPlusTreeSplit<SplitPolicyType>::
+AddFakeNodes(const TreeType* tree, TreeType* emptyTree)
+{
+ // Purpose: create, below emptyTree, a chain of newly allocated nodes that
+ // has as many levels as there are below `tree` along its first-child path.
+ // SplitNonLeafNodeAlongPartition() calls this when one of the two resulting
+ // nodes received no children.
+
+ // Count the levels below `tree`: its first child, that child's first child,
+ // and so on until a leaf is reached (the leaf itself is counted).
+ size_t numDescendantNodes = 1;
+
+ TreeType* node = tree->Children()[0];
+
+ while (!node->IsLeaf())
+ {
+ numDescendantNodes++;
+ node = node->Children()[0];
+ }
+
+ // Build the chain: every new node is constructed with the previously
+ // created node (initially emptyTree) as its constructor argument.
+ // NOTE(review): nothing in this function stores `child` in the child list of
+ // `node`; unless the TreeType(TreeType*) constructor registers the new node
+ // with its parent, the allocated nodes are unreachable from emptyTree --
+ // confirm against the RectangleTree constructor.
+ node = emptyTree;
+ for (size_t i = 0; i < numDescendantNodes; i++)
+ {
+ TreeType* child = new TreeType(node);
+
+ node = child;
+ }
+}
+
+
+template<typename SplitPolicyType>
+template<typename TreeType>
+bool RPlusTreeSplit<SplitPolicyType>::CheckNonLeafSweep(const TreeType* node,
size_t cutAxis, double cut)
{
size_t numTreeOneChildren = 0;
@@ -223,9 +270,10 @@ bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
for (size_t i = 0; i < node->NumChildren(); i++)
{
TreeType* child = node->Children()[i];
- if (child->Bound()[cutAxis].Hi() <= cut)
+ int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut);
+ if (policy == SplitPolicyType::AssignToFirstTree)
numTreeOneChildren++;
- else if (child->Bound()[cutAxis].Lo() >= cut)
+ else if (policy == SplitPolicyType::AssignToSecondTree)
numTreeTwoChildren++;
else
{
@@ -240,8 +288,9 @@ bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
return false;
}
+template<typename SplitPolicyType>
template<typename TreeType>
-bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node,
+bool RPlusTreeSplit<SplitPolicyType>::CheckLeafSweep(const TreeType* node,
size_t cutAxis, double cut)
{
size_t numTreeOnePoints = 0;
@@ -261,8 +310,9 @@ bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node,
return false;
}
+template<typename SplitPolicyType>
template<typename TreeType>
-bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor,
+bool RPlusTreeSplit<SplitPolicyType>::PartitionNode(const TreeType* node, size_t fillFactor,
size_t& minCutAxis, double& minCut)
{
if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) ||
@@ -293,8 +343,9 @@ bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor,
return true;
}
+template<typename SplitPolicyType>
template<typename TreeType>
-double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
+double RPlusTreeSplit<SplitPolicyType>::SweepNonLeafNode(size_t axis, const TreeType* node,
size_t fillFactor, double& axisCut)
{
typedef typename TreeType::ElemType ElemType;
@@ -303,15 +354,27 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
for (size_t i = 0; i < node->NumChildren(); i++)
{
- sorted[i].d = node->Children()[i]->Bound()[axis].Hi();
+ sorted[i].d = SplitPolicyType::Bound(node->Children()[i])[axis].Hi();
sorted[i].n = i;
}
std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
- axisCut = sorted[fillFactor - 1].d;
+ size_t splitPointer = fillFactor;
+
+ axisCut = sorted[splitPointer - 1].d;
if (!CheckNonLeafSweep(node, axis, axisCut))
- return std::numeric_limits<double>::max();
+ {
+ for (splitPointer = 1; splitPointer < node->NumChildren(); splitPointer++)
+ {
+ axisCut = sorted[splitPointer - 1].d;
+ if (CheckNonLeafSweep(node, axis, axisCut))
+ break;
+ }
+
+ if (splitPointer == node->NumChildren())
+ return std::numeric_limits<double>::max();
+ }
std::vector<ElemType> lowerBound1(node->Bound().Dim());
std::vector<ElemType> highBound1(node->Bound().Dim());
@@ -323,7 +386,7 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
- for (size_t i = 1; i < fillFactor; i++)
+ for (size_t i = 1; i < splitPointer; i++)
{
if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
@@ -331,10 +394,10 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
}
- lowerBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Lo();
- highBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Hi();
+ lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo();
+ highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi();
- for (size_t i = fillFactor + 1; i < node->NumChildren(); i++)
+ for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
{
if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
@@ -348,21 +411,38 @@ double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
for (size_t k = 0; k < node->Bound().Dim(); k++)
{
- area1 *= highBound1[k] - lowerBound1[k];
- area2 *= highBound2[k] - lowerBound2[k];
+ if (lowerBound1[k] >= highBound1[k])
+ {
+ overlappedArea *= 0;
+ area1 *= 0;
+ }
+ else
+ area1 *= highBound1[k] - lowerBound1[k];
- if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k])
+ if (lowerBound2[k] >= highBound2[k])
+ {
overlappedArea *= 0;
+ area1 *= 0;
+ }
else
- overlappedArea *= std::min(highBound1[k], highBound2[k]) -
- std::max(lowerBound1[k], lowerBound2[k]);
+ area2 *= highBound2[k] - lowerBound2[k];
+
+ if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k])
+ {
+ if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound2[k])
+ overlappedArea *= 0;
+ else
+ overlappedArea *= std::min(highBound1[k], highBound2[k]) -
+ std::max(lowerBound1[k], lowerBound2[k]);
+ }
}
return area1 + area2 - overlappedArea;
}
+template<typename SplitPolicyType>
template<typename TreeType>
-double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node,
+double RPlusTreeSplit<SplitPolicyType>::SweepLeafNode(size_t axis, const TreeType* node,
size_t fillFactor, double& axisCut)
{
typedef typename TreeType::ElemType ElemType;
@@ -426,8 +506,9 @@ double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node,
return area1 + area2 - overlappedArea;
}
+template<typename SplitPolicyType>
template<typename TreeType>
-void RPlusTreeSplit::
+void RPlusTreeSplit<SplitPolicyType>::
InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp
new file mode 100644
index 0000000..219eb1a
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp
@@ -0,0 +1,46 @@
+/**
+ * @file r_plus_tree_split_policy.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition and implementation of the RPlusTreeSplitPolicy class, a class that
+ * helps to determine the node into which we should insert an intermediate node.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Split policy that works on a node's own bound (the one returned by
+ * Bound()): it decides on which side of a cut a child node falls and tells
+ * the caller which bound to use for a node.
+ */
+class RPlusTreeSplitPolicy
+{
+ public:
+  //! Returned when neither of the other two results applies, i.e. the child's
+  //! bound extends to both sides of the cut.
+  static const int SplitRequired = 0;
+  //! Returned when the child's bound ends at or before the cut.
+  static const int AssignToFirstTree = 1;
+  //! Returned when the child's bound starts at or after the cut.
+  static const int AssignToSecondTree = 2;
+
+  /**
+   * Classify a child node with respect to a cut.
+   *
+   * @param child Node whose bound is examined.
+   * @param axis Dimension of the cut.
+   * @param cut Coordinate of the cut along that dimension.
+   * @return AssignToFirstTree if the upper limit of the bound along the axis
+   *     is <= cut; otherwise AssignToSecondTree if its lower limit is >= cut;
+   *     otherwise SplitRequired.
+   */
+  template<typename TreeType>
+  static int GetSplitPolicy(const TreeType* child, size_t axis,
+                            typename TreeType::ElemType cut)
+  {
+    // Entirely on the low side of the cut.
+    if (child->Bound()[axis].Hi() <= cut)
+      return AssignToFirstTree;
+
+    // Entirely on the high side of the cut.
+    if (child->Bound()[axis].Lo() >= cut)
+      return AssignToSecondTree;
+
+    return SplitRequired;
+  }
+
+  /**
+   * Return the bound this policy works with: the node's own bound.
+   *
+   * @param node Node whose bound is requested.
+   */
+  template<typename TreeType>
+  static const
+      bound::HRectBound<metric::EuclideanDistance, typename TreeType::ElemType>&
+  Bound(const TreeType* node)
+  {
+    return node->Bound();
+  }
+};
+
+} // namespace tree
+} // namespace mlpack
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+
+
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index 4737d5a..b22c3b1 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -129,10 +129,17 @@ template<typename MetricType, typename StatisticType, typename MatType>
using RPlusTree = RectangleTree<MetricType,
StatisticType,
MatType,
- RPlusTreeSplit,
+ RPlusTreeSplit<RPlusTreeSplitPolicy>,
RPlusTreeDescentHeuristic,
NoAuxiliaryInformation>;
+template<typename MetricType, typename StatisticType, typename MatType>
+using RPlusPlusTree = RectangleTree<MetricType,
+ StatisticType,
+ MatType,
+ RPlusTreeSplit<RPlusPlusTreeSplitPolicy>,
+ RPlusPlusTreeDescentHeuristic,
+ RPlusPlusTreeAuxiliaryInformation>;
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 0845147..3568a78 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -789,10 +789,10 @@ void CheckOverlap(TreeType* tree)
if (!success)
break;
}
- if (success)
+ if (!success)
break;
}
- assert(success == true);
+ BOOST_REQUIRE_EQUAL(success, true);
for (size_t i = 0; i < tree->NumChildren(); i++)
CheckOverlap(tree->Children()[i]);
@@ -853,6 +853,119 @@ BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest)
}
}
+/**
+ * Recursively check the outer bounds of an R++ tree.  For every node:
+ *  - the node's bound must lie inside its outer bound in every dimension;
+ *  - if the node is a leaf, every point it holds must be contained in the
+ *    node's bound;
+ *  - otherwise, the outer bounds of any two distinct children must be
+ *    separated in at least one dimension.
+ *
+ * @param tree Node at which the check starts.
+ */
+template<typename TreeType>
+void CheckRPlusPlusTreeBound(const TreeType* tree)
+{
+ typedef bound::HRectBound<metric::EuclideanDistance,
+ typename TreeType::ElemType> Bound;
+
+ bool success = true;
+
+ // The bound of the node must not stick out of the outer bound.
+ for (size_t k = 0; k < tree->Bound().Dim(); k++)
+ {
+ BOOST_REQUIRE_LE(tree->Bound()[k].Hi(),
+ tree->AuxiliaryInfo().OuterBound()[k].Hi());
+ BOOST_REQUIRE_LE(tree->AuxiliaryInfo().OuterBound()[k].Lo(),
+ tree->Bound()[k].Lo());
+ }
+
+ // Leaves: check the points and stop the recursion here.
+ if (tree->IsLeaf())
+ {
+ for (size_t i = 0; i < tree->Count(); i++)
+ BOOST_REQUIRE_EQUAL(true,
+ tree->Bound().Contains(tree->Dataset().col(tree->Points()[i])));
+
+ return;
+ }
+
+ // Compare the outer bound of every child with that of every other child.
+ for (size_t i = 0; i < tree->NumChildren(); i++)
+ {
+ const Bound& bound1 = tree->Children()[i]->AuxiliaryInfo().OuterBound();
+ success = true;
+
+ for (size_t j = 0; j < tree->NumChildren(); j++)
+ {
+ if (j == i)
+ continue;
+ const Bound& bound2 = tree->Children()[j]->AuxiliaryInfo().OuterBound();
+
+ // The pair passes as soon as one dimension is found in which one bound
+ // starts at or after the point where the other one ends.
+ success = false;
+ for (size_t k = 0; k < tree->Bound().Dim(); k++)
+ {
+ if (bound1[k].Lo() >= bound2[k].Hi() ||
+ bound2[k].Lo() >= bound1[k].Hi())
+ {
+ success = true;
+ break;
+ }
+ }
+ // A failing pair ends both loops with success == false.
+ if (!success)
+ break;
+ }
+ if (!success)
+ break;
+ }
+ BOOST_REQUIRE_EQUAL(success, true);
+
+ for (size_t i = 0; i < tree->NumChildren(); i++)
+ CheckRPlusPlusTreeBound(tree->Children()[i]);
+}
+
+// Build an R++ tree on a random dataset and check its outer bounds with
+// CheckRPlusPlusTreeBound().
+BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest)
+{
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ typedef RPlusPlusTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>,arma::mat> TreeType;
+ TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ CheckRPlusPlusTreeBound(&rPlusPlusTree);
+}
+
+// Build an R++ tree on a random dataset, run the structural checks on it, and
+// make sure that a 5-nearest-neighbor search using the tree gives exactly the
+// same neighbors and distances as the reference search below.
+BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest)
+{
+ arma::mat dataset;
+
+ const int numP = 1000;
+
+ dataset.randu(8, numP); // 1000 points in 8 dimensions.
+ arma::Mat<size_t> neighbors1;
+ arma::mat distances1;
+ arma::Mat<size_t> neighbors2;
+ arma::mat distances2;
+
+ typedef RPlusPlusTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> TreeType;
+ TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ // Nearest neighbor search with the R++ tree.
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ arma::mat, RPlusPlusTree > knn1(&rPlusPlusTree, true);
+
+ // Every point of the dataset must be a descendant of the root.
+ BOOST_REQUIRE_EQUAL(rPlusPlusTree.NumDescendants(), numP);
+
+ CheckContainment(rPlusPlusTree);
+ CheckExactContainment(rPlusPlusTree);
+ CheckHierarchy(rPlusPlusTree);
+ CheckRPlusPlusTreeBound(&rPlusPlusTree);
+
+ knn1.Search(5, neighbors1, distances1);
+
+ // Nearest neighbor search the naive way.
+ KNN knn2(dataset, true, true);
+
+ knn2.Search(5, neighbors2, distances2);
+
+ // Both searches must agree element by element.
+ for (size_t i = 0; i < neighbors1.size(); i++)
+ {
+ BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+ BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+ }
+}
+
+
// Test the tree splitting. We set MaxLeafSize and MaxNumChildren rather low
// to allow us to test by hand without adding hundreds of points.
BOOST_AUTO_TEST_CASE(RTreeSplitTest)
More information about the mlpack-git
mailing list