[mlpack-git] master: R+ tree implementation (b8da9a9)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Jul 7 17:30:32 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6147ed01bab6eadcd6a5e796e259a6afacae4662...e0fd69006b17a845f066ea4de1e205fc0922739d
>---------------------------------------------------------------
commit b8da9a9c01630b5455ac31347f69865b85bce370
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Thu Jun 16 05:45:59 2016 +0300
R+ tree implementation
>---------------------------------------------------------------
b8da9a9c01630b5455ac31347f69865b85bce370
src/mlpack/core/tree/CMakeLists.txt | 4 +
src/mlpack/core/tree/rectangle_tree.hpp | 2 +
.../rectangle_tree/hilbert_r_tree_split_impl.hpp | 2 +
...istic.hpp => r_plus_tree_descent_heuristic.hpp} | 17 +-
.../r_plus_tree_descent_heuristic_impl.hpp | 96 +++++
.../core/tree/rectangle_tree/r_plus_tree_split.hpp | 95 +++++
.../tree/rectangle_tree/r_plus_tree_split_impl.hpp | 441 +++++++++++++++++++++
.../tree/rectangle_tree/r_star_tree_split_impl.hpp | 3 +
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 2 +
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 4 +-
src/mlpack/core/tree/rectangle_tree/typedef.hpp | 7 +
.../core/tree/rectangle_tree/x_tree_split_impl.hpp | 3 +
src/mlpack/tests/rectangle_tree_test.cpp | 90 +++++
13 files changed, 756 insertions(+), 10 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 28415d5..527dd59 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -63,6 +63,10 @@ set(SOURCES
rectangle_tree/recursive_hilbert_value_impl.hpp
rectangle_tree/discrete_hilbert_value.hpp
rectangle_tree/discrete_hilbert_value_impl.hpp
+ rectangle_tree/r_plus_tree_descent_heuristic.hpp
+ rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
+ rectangle_tree/r_plus_tree_split.hpp
+ rectangle_tree/r_plus_tree_split_impl.hpp
statistic.hpp
traversal_info.hpp
tree_traits.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index de236ad..a28cd9f 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -30,6 +30,8 @@
#include "rectangle_tree/hilbert_r_tree_auxiliary_information.hpp"
#include "rectangle_tree/recursive_hilbert_value.hpp"
#include "rectangle_tree/discrete_hilbert_value.hpp"
+#include "rectangle_tree/r_plus_tree_descent_heuristic.hpp"
+#include "rectangle_tree/r_plus_tree_split.hpp"
#include "rectangle_tree/typedef.hpp"
#endif
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
index fd39961..0d4ed5f 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
@@ -18,6 +18,8 @@ template<typename TreeType>
void HilbertRTreeSplit::
SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
{
+ if (tree->Count() <= tree->MaxLeafSize())
+ return;
// If we are splitting the root node, we need will do things differently so
// that the constructor and other methods don't confuse the end user by giving
// an address of another node.
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_descent_heuristic.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
similarity index 65%
copy from src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_descent_heuristic.hpp
copy to src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
index c83532a..dfe8e0a 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_descent_heuristic.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic.hpp
@@ -1,19 +1,19 @@
/**
- * @file hilbert_r_tree_descent_heuristic.hpp
+ * @file r_plus_tree_descent_heuristic.hpp
* @author Mikhail Lozhnikov
*
- * Definition of HilbertRTreeDescentHeuristic, a class that chooses the best child of a
+ * Definition of RPlusTreeDescentHeuristic, a class that chooses the best child of a
* node in an R tree when inserting a new point.
*/
-#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_HILBERT_R_TREE_DESCENT_HEURISTIC_HPP
-#define MLPACK_CORE_TREE_RECTANGLE_TREE_HILBERT_R_TREE_DESCENT_HEURISTIC_HPP
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
#include <mlpack/core.hpp>
namespace mlpack {
namespace tree {
-class HilbertRTreeDescentHeuristic
+class RPlusTreeDescentHeuristic
{
public:
/**
@@ -25,7 +25,7 @@ class HilbertRTreeDescentHeuristic
* @param point The number of the point that is being inserted.
*/
template<typename TreeType>
- static size_t ChooseDescentNode(const TreeType* node, const size_t point);
+ static size_t ChooseDescentNode(TreeType* node, const size_t point);
/**
* Evaluate the node using a heuristic. Returns the number of the node
@@ -40,9 +40,10 @@ class HilbertRTreeDescentHeuristic
const TreeType* insertedNode);
};
+
} // namespace tree
} // namespace mlpack
-#include "hilbert_r_tree_descent_heuristic_impl.hpp"
+#include "r_plus_tree_descent_heuristic_impl.hpp"
-#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_HILBERT_R_TREE_DESCENT_HEURISTIC_HPP
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
new file mode 100644
index 0000000..265b739
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
@@ -0,0 +1,96 @@
+/**
+ * @file r_plus_tree_descent_heuristic_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of RPlusTreeDescentHeuristic, a class that chooses the best child
+ * of a node in an R+ tree when inserting a new point.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+
+#include "r_plus_tree_descent_heuristic.hpp"
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Choose the child of a node into which a point should be inserted.
+ *
+ * The children are tried in three stages:
+ *  1. a child whose bound already contains the point;
+ *  2. a child whose bound, once enlarged to hold the point, stays disjoint
+ *     from the bound of every sibling;
+ *  3. otherwise a new chain of nodes is appended below `node` and the index
+ *     of the newly added child is returned.
+ *
+ * @param node The node whose children are examined; it may gain a new child.
+ * @param point Index of the column of node->Dataset() that is being inserted.
+ * @return Index into node->Children() of the chosen child.
+ */
+template<typename TreeType>
+size_t RPlusTreeDescentHeuristic::
+ChooseDescentNode(TreeType* node, const size_t point)
+{
+  typedef typename TreeType::ElemType ElemType;
+  size_t bestIndex = 0;
+  // Give the flag a defined value: when the node has no children the search
+  // loop below never assigns it, and it is read right after that loop.
+  bool success = true;
+
+  // Stage 1: look for a child that already contains the point.
+  for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+  {
+    if (node->Children()[bestIndex]->Bound().Contains(node->Dataset().col(point)))
+      return bestIndex;
+  }
+
+  // Stage 2: look for a child that can be enlarged to hold the point without
+  // overlapping any of its siblings.
+  for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+  {
+    bound::HRectBound<metric::EuclideanDistance, ElemType> bound =
+        node->Children()[bestIndex]->Bound();
+    bound |= node->Dataset().col(point);
+
+    success = true;
+
+    for (size_t j = 0; j < node->NumChildren(); j++)
+    {
+      if (j == bestIndex)
+        continue;
+      success = false;
+      // Two boxes are disjoint as soon as they are separated in one dimension.
+      for (size_t k = 0; k < node->Bound().Dim(); k++)
+      {
+        if (bound[k].Lo() >= node->Children()[j]->Bound()[k].Hi() ||
+            node->Children()[j]->Bound()[k].Lo() >= bound[k].Hi())
+        {
+          success = true;
+          break;
+        }
+      }
+      if (!success) // The enlarged bound overlaps sibling j.
+        break;
+    }
+    if (success) // The enlarged bound overlaps no sibling.
+      break;
+  }
+
+  // Stage 3: no child can take the point, so append a fresh chain of nodes.
+  if (!success)
+  {
+    size_t depth = node->TreeDepth();
+
+    TreeType* tree = node;
+    while (depth > 1)
+    {
+      // Every new node hangs off the previous node of the chain, so that
+      // node (not the top-level `node`) has to be passed as its parent.
+      TreeType* child = new TreeType(tree);
+
+      tree->Children()[tree->NumChildren()++] = child;
+      tree = child;
+      depth--;
+    }
+    return node->NumChildren() - 1;
+  }
+
+  assert(bestIndex < node->NumChildren());
+
+  return bestIndex;
+}
+
+/**
+ * Node-insertion overload of the descent heuristic. It is not implemented:
+ * the assertion below fires whenever it is reached, and index 0 is returned
+ * when assertions are disabled.
+ */
+template<typename TreeType>
+size_t RPlusTreeDescentHeuristic::
+ChooseDescentNode(const TreeType* node, const TreeType* insertedNode)
+{
+  assert(false);
+
+  return 0;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
new file mode 100644
index 0000000..f06b813
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split.hpp
@@ -0,0 +1,95 @@
+/**
+ * @file r_plus_tree_split.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition of the RPlusTreeSplit class, a class that splits the nodes of an R+
+ * tree, starting at a leaf node and moving upwards if necessary.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP
+
+#include <mlpack/core.hpp>
+
+// Fraction of a node's capacity (MaxLeafSize() or MaxNumChildren()) that the
+// split routines use as the fill factor handed to PartitionNode().
+const double fillFactorFraction = 0.5;
+
+namespace mlpack {
+namespace tree /** Trees and tree-building procedures. */ {
+
+/**
+ * Split policy for the R+ tree: nodes are divided by an axis-aligned cut that
+ * is chosen by sweeping every dimension.
+ */
+class RPlusTreeSplit
+{
+ public:
+ /**
+ * Split a leaf node along an axis-aligned cut. If necessary, this split
+ * will propagate upwards through the tree.
+ *
+ * @param tree The node that is being split.
+ * @param relevels Not used by this policy; only forwarded to
+ * SplitNonLeafNode().
+ */
+ template<typename TreeType>
+ static void SplitLeafNode(TreeType *tree,std::vector<bool>& relevels);
+
+ /**
+ * Split a non-leaf node along an axis-aligned cut. If this is a root
+ * node, the tree increases in depth.
+ *
+ * @param tree The node that is being split.
+ * @param relevels Not used by this policy; only forwarded on recursion.
+ * @return true when the node was the root, false otherwise.
+ */
+ template<typename TreeType>
+ static bool SplitNonLeafNode(TreeType *tree,std::vector<bool>& relevels);
+
+
+
+ private:
+
+ // Record sorted by the sweep routines: d is the sort key and n is the index
+ // of the point or child the key was taken from.
+ template<typename ElemType>
+ struct SortStruct
+ {
+ ElemType d;
+ int n;
+ };
+
+ // Ordering used with std::sort: ascending by the key d.
+ template<typename ElemType>
+ static bool StructComp(const SortStruct<ElemType>& s1,
+ const SortStruct<ElemType>& s2)
+ {
+ return s1.d < s2.d;
+ }
+
+ // Distribute the points of a leaf: points whose coordinate on cutAxis is
+ // <= cut go to treeOne, all others go to treeTwo.
+ template<typename TreeType>
+ static void SplitLeafNodeAlongPartition(TreeType* tree,
+ TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+
+ // Distribute the children of a non-leaf node between treeOne and treeTwo;
+ // a child that straddles the cut is itself split along the same cut.
+ template<typename TreeType>
+ static void SplitNonLeafNodeAlongPartition(TreeType* tree,
+ TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut);
+
+ // Sweep every dimension and report the cheapest cut through minCutAxis and
+ // minCut. Returns false when the node holds no more than fillFactor items.
+ template<typename TreeType>
+ static bool PartitionNode(const TreeType* node, size_t fillFactor,
+ size_t& minCutAxis, double& minCut);
+
+ // Cost of cutting a leaf on the given axis; the cut position is written to
+ // axisCut. Returns the maximum double when CheckLeafSweep() rejects the cut.
+ template<typename TreeType>
+ static double SweepLeafNode(size_t axis, const TreeType* node,
+ size_t fillFactor, double& axisCut);
+
+ // Cost of cutting a non-leaf node on the given axis; the cut position is
+ // written to axisCut. Returns the maximum double when CheckNonLeafSweep()
+ // rejects the cut.
+ template<typename TreeType>
+ static double SweepNonLeafNode(size_t axis, const TreeType* node,
+ size_t fillFactor, double& axisCut);
+
+ // Append srcNode to the children of destTree and grow destTree's bound to
+ // cover it. The parent pointer of srcNode is not touched here.
+ template<typename TreeType>
+ static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
+
+ // True when cutting the non-leaf node at `cut` leaves both sides non-empty
+ // and with at most MaxNumChildren() children each.
+ template<typename TreeType>
+ static bool CheckNonLeafSweep(const TreeType* node,
+ size_t cutAxis, double cut);
+
+ // True when cutting the leaf at `cut` leaves both sides non-empty and with
+ // at most MaxLeafSize() points each.
+ template<typename TreeType>
+ static bool CheckLeafSweep(const TreeType* node, size_t cutAxis, double cut);
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation
+#include "r_plus_tree_split_impl.hpp"
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_HPP
+
diff --git a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
new file mode 100644
index 0000000..fca51fc
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
@@ -0,0 +1,441 @@
+/**
+ * @file r_plus_tree_split_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of class (RPlusTreeSplit) to split a RectangleTree.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
+
+#include "r_plus_tree_split.hpp"
+#include "rectangle_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Split a leaf node that holds more than MaxLeafSize() points into two leaves
+ * separated by an axis-aligned cut, and split the parent as well if it ends
+ * up with MaxNumChildren() + 1 children.
+ *
+ * @param tree The leaf to split.
+ * @param relevels Only forwarded to SplitNonLeafNode().
+ */
+template<typename TreeType>
+void RPlusTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+{
+ // A leaf with a single point is never split itself, but one of its
+ // ancestors may hold one child too many; split the first such ancestor.
+ // NOTE(review): presumably this covers the chain of new nodes created by
+ // RPlusTreeDescentHeuristic::ChooseDescentNode() -- confirm.
+ if (tree->Count() == 1)
+ {
+ TreeType* node = tree->Parent();
+
+ while (node != NULL)
+ {
+ if (node->NumChildren() == node->MaxNumChildren() + 1)
+ {
+ RPlusTreeSplit::SplitNonLeafNode(node,relevels);
+ return;
+ }
+ node = node->Parent();
+ }
+ return;
+ }
+ else if (tree->Count() <= tree->MaxLeafSize())
+ return;
+ // If we are splitting the root node, we need to do things differently so
+ // that the constructor and other methods don't confuse the end user by giving
+ // an address of another node.
+ if (tree->Parent() == NULL)
+ {
+ // We actually want to copy this way. Pointers and everything.
+ TreeType* copy = new TreeType(*tree, false);
+ copy->Parent() = tree;
+ tree->Count() = 0;
+ tree->NullifyData();
+ // Because this was a leaf node, numChildren must be 0.
+ tree->Children()[(tree->NumChildren())++] = copy;
+ assert(tree->NumChildren() == 1);
+
+ // The root now has the copy as its only child; split the copy instead.
+ RPlusTreeSplit::SplitLeafNode(copy,relevels);
+ return;
+ }
+
+ const size_t fillFactor = tree->MaxLeafSize() * fillFactorFraction;
+ size_t cutAxis;
+ double cut;
+
+ // PartitionNode() returns false when the leaf holds no more than fillFactor
+ // points; the leaf is then left as it is.
+ if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+ return;
+
+ assert(cutAxis < tree->Bound().Dim());
+
+ // The two halves become siblings under the parent of the old leaf.
+ TreeType* treeOne = new TreeType(tree->Parent());
+ TreeType* treeTwo = new TreeType(tree->Parent());
+ treeOne->MinLeafSize() = 0;
+ treeOne->MinNumChildren() = 0;
+ treeTwo->MinLeafSize() = 0;
+ treeTwo->MinNumChildren() = 0;
+
+ SplitLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut);
+
+ // Locate the old leaf among the children of its parent.
+ TreeType* parent = tree->Parent();
+ size_t i = 0;
+ while (parent->Children()[i] != tree)
+ i++;
+
+ assert(i < parent->NumChildren());
+
+ // Remove the old leaf by moving the last child into its slot.
+ parent->Children()[i] = parent->Children()[--parent->NumChildren()];
+
+ InsertNodeIntoTree(parent, treeOne);
+ InsertNodeIntoTree(parent, treeTwo);
+
+ // One child was removed and two were added, so the parent may now hold one
+ // child too many.
+ assert(parent->NumChildren() <= parent->MaxNumChildren() + 1);
+ if (parent->NumChildren() == parent->MaxNumChildren() + 1)
+ RPlusTreeSplit::SplitNonLeafNode(parent, relevels);
+
+ tree->SoftDelete();
+}
+
+/**
+ * Split a non-leaf node into two nodes separated by an axis-aligned cut and
+ * continue upwards while the parent holds MaxNumChildren() + 1 children.
+ *
+ * @param tree The node to split.
+ * @param relevels Only forwarded on recursion.
+ * @return true when `tree` was the root, false otherwise.
+ */
+template<typename TreeType>
+bool RPlusTreeSplit::SplitNonLeafNode(TreeType* tree,
+ std::vector<bool>& relevels)
+{
+ // If we are splitting the root node, we need to do things differently so
+ // that the constructor and other methods don't confuse the end user by giving
+ // an address of another node.
+ if (tree->Parent() == NULL)
+ {
+ // We actually want to copy this way. Pointers and everything.
+ TreeType* copy = new TreeType(*tree, false);
+
+ copy->Parent() = tree;
+ tree->NumChildren() = 0;
+ tree->NullifyData();
+ tree->Children()[(tree->NumChildren())++] = copy;
+
+ // The root now has the copy as its only child; split the copy instead.
+ RPlusTreeSplit::SplitNonLeafNode(copy,relevels);
+ return true;
+ }
+ const size_t fillFactor = tree->MaxNumChildren() * fillFactorFraction;
+ size_t cutAxis;
+ double cut;
+
+ // PartitionNode() returns false when the node holds no more than fillFactor
+ // children; the node is then left as it is.
+ if ( !PartitionNode(tree, fillFactor, cutAxis, cut))
+ return false;
+
+ assert(cutAxis < tree->Bound().Dim());
+
+ // The two halves become siblings under the parent of the old node.
+ TreeType* treeOne = new TreeType(tree->Parent());
+ TreeType* treeTwo = new TreeType(tree->Parent());
+ treeOne->MinLeafSize() = 0;
+ treeOne->MinNumChildren() = 0;
+ treeTwo->MinLeafSize() = 0;
+ treeTwo->MinNumChildren() = 0;
+
+ SplitNonLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut);
+
+ // Locate the old node among the children of its parent.
+ TreeType* parent = tree->Parent();
+ size_t i = 0;
+ while (parent->Children()[i] != tree)
+ i++;
+
+ assert(i < parent->NumChildren());
+
+ // Remove the old node by moving the last child into its slot.
+ parent->Children()[i] = parent->Children()[--parent->NumChildren()];
+
+ InsertNodeIntoTree(parent, treeOne);
+ InsertNodeIntoTree(parent, treeTwo);
+
+ tree->SoftDelete();
+
+ assert(parent->NumChildren() <= parent->MaxNumChildren() + 1);
+
+ // One child was removed and two were added, so the parent may now hold one
+ // child too many.
+ if (parent->NumChildren() == parent->MaxNumChildren() + 1)
+ RPlusTreeSplit::SplitNonLeafNode(parent, relevels);
+
+ return false;
+}
+
+/**
+ * Hand every point of a leaf to one of two new leaves: a point whose
+ * coordinate on cutAxis is <= cut goes to treeOne, any other point goes to
+ * treeTwo. The bound of the receiving leaf is grown to cover the point.
+ *
+ * @param tree The leaf whose points are distributed (left unchanged).
+ * @param treeOne Receives the points on the lower side of the cut.
+ * @param treeTwo Receives the points on the upper side of the cut.
+ * @param cutAxis Dimension along which the cut is made.
+ * @param cut Coordinate of the cut.
+ */
+template<typename TreeType>
+void RPlusTreeSplit::SplitLeafNodeAlongPartition(TreeType* tree,
+    TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+{
+  for (size_t idx = 0; idx < tree->NumPoints(); ++idx)
+  {
+    // Pick the destination first, then perform the same two updates on it.
+    TreeType* target = treeTwo;
+    if (tree->Dataset().col(tree->Point(idx))[cutAxis] <= cut)
+      target = treeOne;
+
+    target->Points()[target->Count()++] = tree->Point(idx);
+    target->Bound() |= tree->Dataset().col(tree->Point(idx));
+  }
+
+  // Neither half may overflow, and no point may be lost or duplicated.
+  assert(treeOne->Count() <= treeOne->MaxLeafSize());
+  assert(treeTwo->Count() <= treeTwo->MaxLeafSize());
+
+  assert(tree->Count() == treeOne->Count() + treeTwo->Count());
+  // The two halves must be strictly separated on the cut axis.
+  assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo());
+}
+
+/**
+ * Distribute the children of a non-leaf node between two new nodes. A child
+ * lying entirely on one side of the cut is moved to that side; a child that
+ * straddles the cut is split along the same cut (recursively for non-leaf
+ * children) and then soft-deleted.
+ *
+ * @param tree The node whose children are distributed.
+ * @param treeOne Receives the children on the lower side of the cut.
+ * @param treeTwo Receives the children on the upper side of the cut.
+ * @param cutAxis Dimension along which the cut is made.
+ * @param cut Coordinate of the cut.
+ */
+template<typename TreeType>
+void RPlusTreeSplit::SplitNonLeafNodeAlongPartition(TreeType* tree,
+ TreeType* treeOne, TreeType* treeTwo, size_t cutAxis, double cut)
+{
+ for (size_t i = 0; i < tree->NumChildren(); i++)
+ {
+ TreeType* child = tree->Children()[i];
+ // The child lies entirely on the lower side of the cut.
+ if (child->Bound()[cutAxis].Hi() <= cut)
+ {
+ InsertNodeIntoTree(treeOne, child);
+ child->Parent() = treeOne;
+ }
+ // The child lies entirely on the upper side of the cut.
+ else if (child->Bound()[cutAxis].Lo() >= cut)
+ {
+ InsertNodeIntoTree(treeTwo, child);
+ child->Parent() = treeTwo;
+ }
+ // The child straddles the cut and has to be split itself.
+ else
+ {
+ TreeType* childOne = new TreeType(treeOne);
+ TreeType* childTwo = new TreeType(treeTwo);
+ // NOTE(review): the minimums are reset on treeOne/treeTwo here, not on the
+ // freshly created childOne/childTwo -- confirm that this is intended.
+ treeOne->MinLeafSize() = 0;
+ treeOne->MinNumChildren() = 0;
+ treeTwo->MinLeafSize() = 0;
+ treeTwo->MinNumChildren() = 0;
+
+ if (child->IsLeaf())
+ SplitLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut);
+ else
+ SplitNonLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut);
+
+ InsertNodeIntoTree(treeOne, childOne);
+ InsertNodeIntoTree(treeTwo, childTwo);
+
+ child->SoftDelete();
+ }
+ }
+ assert(treeOne->NumChildren() <= treeOne->MaxNumChildren());
+ assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
+}
+
+/**
+ * Check whether cutting a non-leaf node at `cut` on `cutAxis` is acceptable.
+ * A child that straddles the cut is counted on both sides.
+ *
+ * @param node The node that would be cut.
+ * @param cutAxis Dimension along which the cut is made.
+ * @param cut Coordinate of the cut.
+ * @return true when both sides receive at least one and at most
+ *     MaxNumChildren() children.
+ */
+template<typename TreeType>
+bool RPlusTreeSplit::CheckNonLeafSweep(const TreeType* node,
+    size_t cutAxis, double cut)
+{
+  size_t lowerSide = 0;
+  size_t upperSide = 0;
+
+  for (size_t c = 0; c < node->NumChildren(); ++c)
+  {
+    const bool entirelyBelow =
+        node->Children()[c]->Bound()[cutAxis].Hi() <= cut;
+    const bool entirelyAbove = !entirelyBelow &&
+        node->Children()[c]->Bound()[cutAxis].Lo() >= cut;
+
+    // Everything except a child entirely above the cut touches the lower
+    // side, and everything except a child entirely below touches the upper.
+    if (!entirelyAbove)
+      ++lowerSide;
+    if (!entirelyBelow)
+      ++upperSide;
+  }
+
+  return lowerSide <= node->MaxNumChildren() && lowerSide > 0 &&
+      upperSide <= node->MaxNumChildren() && upperSide > 0;
+}
+
+/**
+ * Check whether cutting a leaf at `cut` on `cutAxis` is acceptable. Points
+ * whose coordinate is <= cut count for the lower side, all others for the
+ * upper side.
+ *
+ * @param node The leaf that would be cut.
+ * @param cutAxis Dimension along which the cut is made.
+ * @param cut Coordinate of the cut.
+ * @return true when both sides receive at least one and at most
+ *     MaxLeafSize() points.
+ */
+template<typename TreeType>
+bool RPlusTreeSplit::CheckLeafSweep(const TreeType* node,
+    size_t cutAxis, double cut)
+{
+  size_t lowerSide = 0;
+  size_t upperSide = 0;
+
+  for (size_t p = 0; p < node->NumPoints(); ++p)
+  {
+    const bool below = node->Dataset().col(node->Point(p))[cutAxis] <= cut;
+    if (below)
+      ++lowerSide;
+    else
+      ++upperSide;
+  }
+
+  const bool lowerOk = lowerSide <= node->MaxLeafSize() && lowerSide > 0;
+  const bool upperOk = upperSide <= node->MaxLeafSize() && upperSide > 0;
+
+  return lowerOk && upperOk;
+}
+
+/**
+ * Sweep every dimension of a node and report the cheapest cut.
+ *
+ * @param node The node to partition.
+ * @param fillFactor Number of items the sweep places on the lower side.
+ * @param minCutAxis Receives the dimension of the cheapest cut. It is set to
+ *     node->Bound().Dim() when no sweep reports a cost below the maximum
+ *     double; minCut is not assigned in that case.
+ * @param minCut Receives the coordinate of the cheapest cut.
+ * @return false when the node holds no more than fillFactor items (children
+ *     for a non-leaf node, points for a leaf); true otherwise.
+ */
+template<typename TreeType>
+bool RPlusTreeSplit::PartitionNode(const TreeType* node, size_t fillFactor,
+    size_t& minCutAxis, double& minCut)
+{
+  // A node that is not filled beyond the fill factor is not partitioned.
+  if (node->IsLeaf())
+  {
+    if (node->Count() <= fillFactor)
+      return false;
+  }
+  else if (node->NumChildren() <= fillFactor)
+  {
+    return false;
+  }
+
+  double bestCost = std::numeric_limits<double>::max();
+  minCutAxis = node->Bound().Dim();
+
+  for (size_t axis = 0; axis < node->Bound().Dim(); ++axis)
+  {
+    double axisCut;
+    const double axisCost = node->IsLeaf() ?
+        SweepLeafNode(axis, node, fillFactor, axisCut) :
+        SweepNonLeafNode(axis, node, fillFactor, axisCut);
+
+    if (axisCost < bestCost)
+    {
+      bestCost = axisCost;
+      minCutAxis = axis;
+      minCut = axisCut;
+    }
+  }
+
+  return true;
+}
+
+/**
+ * Evaluate a cut of a non-leaf node along one axis. The children are sorted
+ * by the high edge of their bound on that axis; the first fillFactor of them
+ * form the first group and the rest form the second group.
+ *
+ * @param axis Dimension along which to sweep.
+ * @param node The node to evaluate.
+ * @param fillFactor Number of children placed in the first group.
+ * @param axisCut Receives the cut coordinate: the high edge of the last child
+ *     of the first group.
+ * @return The sum of the volumes of the two group bounding boxes minus the
+ *     volume of their intersection, or the maximum double when
+ *     CheckNonLeafSweep() rejects the cut.
+ */
+template<typename TreeType>
+double RPlusTreeSplit::SweepNonLeafNode(size_t axis, const TreeType* node,
+    size_t fillFactor, double& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
+
+  for (size_t i = 0; i < node->NumChildren(); i++)
+  {
+    sorted[i].d = node->Children()[i]->Bound()[axis].Hi();
+    sorted[i].n = i;
+  }
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  axisCut = sorted[fillFactor - 1].d;
+
+  if (!CheckNonLeafSweep(node, axis, axisCut))
+    return std::numeric_limits<double>::max();
+
+  std::vector<ElemType> lowerBound1(node->Bound().Dim());
+  std::vector<ElemType> highBound1(node->Bound().Dim());
+  std::vector<ElemType> lowerBound2(node->Bound().Dim());
+  std::vector<ElemType> highBound2(node->Bound().Dim());
+
+  // Build the bounding box of each group, one dimension at a time.
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
+    highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
+
+    for (size_t i = 1; i < fillFactor; i++)
+    {
+      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
+        lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
+      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k])
+        highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
+    }
+
+    lowerBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Lo();
+    highBound2[k] = node->Children()[sorted[fillFactor].n]->Bound()[k].Hi();
+
+    for (size_t i = fillFactor + 1; i < node->NumChildren(); i++)
+    {
+      if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
+        lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
+      if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k])
+        highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
+    }
+  }
+
+  ElemType area1 = 1.0, area2 = 1.0;
+  ElemType overlappedArea = 1.0;
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    area1 *= highBound1[k] - lowerBound1[k];
+    area2 *= highBound2[k] - lowerBound2[k];
+
+    // The boxes are separated in dimension k when one starts after the other
+    // ends. Each low edge must be compared with the high edge of the OTHER
+    // box; comparing box 2 with itself would never detect a separation.
+    if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound1[k])
+      overlappedArea *= 0;
+    else
+      overlappedArea *= std::min(highBound1[k], highBound2[k]) -
+          std::max(lowerBound1[k], lowerBound2[k]);
+  }
+
+  return area1 + area2 - overlappedArea;
+}
+
+/**
+ * Evaluate a cut of a leaf along one axis. The points are sorted by their
+ * coordinate on that axis; the first fillFactor of them form the first group
+ * and the rest form the second group.
+ *
+ * @param axis Dimension along which to sweep.
+ * @param node The leaf to evaluate.
+ * @param fillFactor Number of points placed in the first group.
+ * @param axisCut Receives the cut coordinate: the coordinate of the last
+ *     point of the first group.
+ * @return The sum of the volumes of the two group bounding boxes, or the
+ *     maximum double when CheckLeafSweep() rejects the cut.
+ */
+template<typename TreeType>
+double RPlusTreeSplit::SweepLeafNode(size_t axis, const TreeType* node,
+    size_t fillFactor, double& axisCut)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  std::vector<SortStruct<ElemType>> sorted(node->Count());
+
+  for (size_t i = 0; i < node->NumPoints(); i++)
+  {
+    sorted[i].d = node->Dataset().col(node->Point(i))[axis];
+    sorted[i].n = i;
+  }
+
+  std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+  axisCut = sorted[fillFactor - 1].d;
+
+  if (!CheckLeafSweep(node, axis, axisCut))
+    return std::numeric_limits<double>::max();
+
+  std::vector<ElemType> lowerBound1(node->Bound().Dim());
+  std::vector<ElemType> highBound1(node->Bound().Dim());
+  std::vector<ElemType> lowerBound2(node->Bound().Dim());
+  std::vector<ElemType> highBound2(node->Bound().Dim());
+
+  // Build the bounding box of each group, one dimension at a time.
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+    highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+
+    for (size_t i = 1; i < fillFactor; i++)
+    {
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k])
+        lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k])
+        highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+    }
+
+    lowerBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
+    highBound2[k] = node->Dataset().col(node->Point(sorted[fillFactor].n))[k];
+
+    // The second group consists of the remaining POINTS of the leaf, so the
+    // loop has to run up to NumPoints(); bounding it by NumChildren() would
+    // leave every point after the first one out of the box.
+    for (size_t i = fillFactor + 1; i < node->NumPoints(); i++)
+    {
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k])
+        lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+      if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k])
+        highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+    }
+  }
+
+  ElemType area1 = 1.0, area2 = 1.0;
+
+  for (size_t k = 0; k < node->Bound().Dim(); k++)
+  {
+    area1 *= highBound1[k] - lowerBound1[k];
+    area2 *= highBound2[k] - lowerBound2[k];
+  }
+
+  // No overlap term is subtracted: a term that is the same for every axis
+  // cannot change which axis PartitionNode() ranks as the cheapest.
+  return area1 + area2;
+}
+
+/**
+ * Append a node to the children of another node and grow the bound of the
+ * destination so that it covers the new child. The parent pointer of the
+ * inserted node is not modified.
+ *
+ * @param destTree The node that receives the child.
+ * @param srcNode The node that is appended.
+ */
+template<typename TreeType>
+void RPlusTreeSplit::
+InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+{
+  destTree->Bound() |= srcNode->Bound();
+
+  // Take the next free slot and advance the child counter.
+  const size_t slot = destTree->NumChildren()++;
+  destTree->Children()[slot] = srcNode;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index 0ec5c51..6b8e73d 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -27,6 +27,9 @@ void RStarTreeSplit::SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
// Convenience typedef.
typedef typename TreeType::ElemType ElemType;
+ if (tree->Count() <= tree->MaxLeafSize())
+ return;
+
// If we are splitting the root node, we need will do things differently so
// that the constructor and other methods don't confuse the end user by giving
// an address of another node.
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index db67fbe..3ad3629 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -23,6 +23,8 @@ namespace tree {
template<typename TreeType>
void RTreeSplit::SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
{
+ if (tree->Count() <= tree->MaxLeafSize())
+ return;
// If we are splitting the root node, we need will do things differently so
// that the constructor and other methods don't confuse the end user by giving
// an address of another node.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index b087f79..c6d8ac2 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -694,8 +694,8 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (numChildren == 0)
{
// Check to see if we are full.
- if (count <= maxLeafSize)
- return; // We don't need to split.
+// if (count <= maxLeafSize)
+// return; // We don't need to split.
// If we are full, then we need to split (or at least try). The SplitType
// takes care of this and of moving up the tree if necessary.
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index 6622099..4737d5a 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -125,6 +125,13 @@ using DiscreteHilbertRTree = RectangleTree<MetricType,
HilbertRTreeDescentHeuristic,
DiscreteHilbertRTreeAuxiliaryInformation>;
+/**
+ * The R+ tree: a RectangleTree that splits nodes with RPlusTreeSplit, chooses
+ * the insertion path with RPlusTreeDescentHeuristic and stores no auxiliary
+ * information.
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+using RPlusTree = RectangleTree<MetricType,
+ StatisticType,
+ MatType,
+ RPlusTreeSplit,
+ RPlusTreeDescentHeuristic,
+ NoAuxiliaryInformation>;
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 89372a4..58591d7 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -26,6 +26,9 @@ void XTreeSplit::SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
// Convenience typedef.
typedef typename TreeType::ElemType ElemType;
+ if (tree->Count() <= tree->MaxLeafSize())
+ return;
+
// If we are splitting the root node, we need will do things differently so
// that the constructor and other methods don't confuse the end user by giving
// an address of another node.
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 2af8ec3..0845147 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -763,6 +763,96 @@ BOOST_AUTO_TEST_CASE(DiscreteHilbertValueTest)
BOOST_REQUIRE_EQUAL(DiscreteHilbertValue<double>::ComparePoints(point1,point2), 1);
}
+// Recursively verify that no two children of any node have overlapping
+// bounds. Two bounds are treated as disjoint when they are separated in at
+// least one dimension (touching edges count as separated).
+template<typename TreeType>
+void CheckOverlap(TreeType* tree)
+{
+  bool success = true;
+
+  for (size_t i = 0; i < tree->NumChildren(); i++)
+  {
+    success = true;
+
+    for (size_t j = 0; j < tree->NumChildren(); j++)
+    {
+      if (j == i)
+        continue;
+      success = false;
+      for (size_t k = 0; k < tree->Bound().Dim(); k++)
+      {
+        if (tree->Children()[i]->Bound()[k].Lo() >= tree->Children()[j]->Bound()[k].Hi() ||
+            tree->Children()[j]->Bound()[k].Lo() >= tree->Children()[i]->Bound()[k].Hi())
+        {
+          success = true;
+          break;
+        }
+      }
+      if (!success)
+        break;
+    }
+    // Stop at the first overlapping pair. Stopping at the first child that
+    // overlaps nothing would leave all remaining children unchecked.
+    if (!success)
+      break;
+  }
+  // Use the test framework so that the check is not compiled out by NDEBUG.
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  for (size_t i = 0; i < tree->NumChildren(); i++)
+    CheckOverlap(tree->Children()[i]);
+}
+
+// Build an R+ tree on random data and verify with CheckOverlap() that sibling
+// nodes do not overlap.
+BOOST_AUTO_TEST_CASE(RPlusTreeOverlapTest)
+{
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ typedef RPlusTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>,arma::mat> TreeType;
+ // NOTE(review): the numeric arguments are presumably the leaf-size and
+ // child-count limits of the RectangleTree constructor -- confirm.
+ TreeType rPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ CheckOverlap(&rPlusTree);
+}
+
+
+// Build an R+ tree on random data, check its structural invariants and
+// compare the 5-nearest-neighbor results obtained with the tree against the
+// results of a second KNN search on the same data.
+BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest)
+{
+ arma::mat dataset;
+
+ const int numP = 1000;
+
+ dataset.randu(8, numP); // 1000 points in 8 dimensions.
+ arma::Mat<size_t> neighbors1;
+ arma::mat distances1;
+ arma::Mat<size_t> neighbors2;
+ arma::mat distances2;
+
+ typedef RPlusTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> TreeType;
+ TreeType rPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ // Nearest neighbor search with the R+ tree.
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, arma::mat, RPlusTree >
+ knn1(&rPlusTree, true);
+
+ // Every point of the dataset must be a descendant of the root.
+ BOOST_REQUIRE_EQUAL(rPlusTree.NumDescendants(), numP);
+
+ CheckContainment(rPlusTree);
+ CheckExactContainment(rPlusTree);
+ CheckHierarchy(rPlusTree);
+ CheckOverlap(&rPlusTree);
+
+ knn1.Search(5, neighbors1, distances1);
+
+ // Nearest neighbor search the naive way.
+ KNN knn2(dataset, true, true);
+
+ knn2.Search(5, neighbors2, distances2);
+
+ // Both searches must agree on every neighbor and every distance.
+ for (size_t i = 0; i < neighbors1.size(); i++)
+ {
+ BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+ BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+ }
+}
+
// Test the tree splitting. We set MaxLeafSize and MaxNumChildren rather low
// to allow us to test by hand without adding hundreds of points.
BOOST_AUTO_TEST_CASE(RTreeSplitTest)
More information about the mlpack-git
mailing list