[mlpack-git] master: Merge remote-tracking branch 'upstream/master' into r_plus_tree-cherry_pick Remove RectangleTree::Children() from the R+/R++ tree. Fix errors. (e165d75)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 29 11:59:27 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6147ed01bab6eadcd6a5e796e259a6afacae4662...e0fd69006b17a845f066ea4de1e205fc0922739d
>---------------------------------------------------------------
commit e165d759f9ae612b9965f70fbbf8abdb19dc8d07
Merge: c2d7a55 809ed4b
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Wed Jun 29 18:59:27 2016 +0300
Merge remote-tracking branch 'upstream/master' into r_plus_tree-cherry_pick
Remove RectangleTree::Children() from the R+/R++ tree.
Fix errors.
>---------------------------------------------------------------
e165d759f9ae612b9965f70fbbf8abdb19dc8d07
.../rectangle_tree/discrete_hilbert_value_impl.hpp | 28 ++--
.../rectangle_tree/dual_tree_traverser_impl.hpp | 4 +-
.../hilbert_r_tree_auxiliary_information_impl.hpp | 14 +-
.../hilbert_r_tree_descent_heuristic_impl.hpp | 4 +-
.../rectangle_tree/hilbert_r_tree_split_impl.hpp | 87 ++++++-----
.../rectangle_tree/minimal_coverage_sweep_impl.hpp | 28 ++--
.../minimal_splits_number_sweep_impl.hpp | 4 +-
.../r_plus_plus_tree_descent_heuristic_impl.hpp | 2 +-
.../r_plus_plus_tree_split_policy.hpp | 10 +-
.../r_plus_tree_descent_heuristic_impl.hpp | 10 +-
.../tree/rectangle_tree/r_plus_tree_split_impl.hpp | 35 +++--
.../rectangle_tree/r_plus_tree_split_policy.hpp | 10 +-
.../r_star_tree_descent_heuristic_impl.hpp | 44 ++++--
.../tree/rectangle_tree/r_star_tree_split_impl.hpp | 87 +++++------
.../r_tree_descent_heuristic_impl.hpp | 30 ++--
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 59 ++++----
.../core/tree/rectangle_tree/rectangle_tree.hpp | 19 ++-
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 72 ++++++---
.../rectangle_tree/single_tree_traverser_impl.hpp | 2 +-
.../core/tree/rectangle_tree/x_tree_split_impl.hpp | 103 +++++++------
src/mlpack/methods/lsh/lsh_search_impl.hpp | 27 +++-
src/mlpack/tests/rectangle_tree_test.cpp | 165 ++++++++++++---------
22 files changed, 474 insertions(+), 370 deletions(-)
diff --cc src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp
index 50b04b9,0000000..e0723d5
mode 100644,000000..100644
--- a/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_coverage_sweep_impl.hpp
@@@ -1,260 -1,0 +1,260 @@@
+/**
+ * @file minimal_coverage_sweep_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of the MinimalCoverageSweep class, a class that finds a
+ * partition of a node along an axis.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+
+#include "minimal_coverage_sweep.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Evaluate a sweep of a non-leaf node along the given axis.
+ *
+ * The children are ordered by the high bound (as reported by
+ * SplitPolicy::Bound()) along the axis. The first candidate cut is taken at
+ * position fillFactor * NumChildren() of that order; if CheckNonLeafSweep()
+ * rejects it, every other position is tried in ascending order.
+ *
+ * @param axis The axis along which the node is swept.
+ * @param node The non-leaf node that is being split.
+ * @param axisCut Receives the coordinate at which the node is cut.
+ * @return The cost of the sweep (area1 + area2 - overlappedArea of the two
+ * resulting bounds), or the maximum value of ElemType if no acceptable
+ * cut was found.
+ */
+template<typename SplitPolicy>
+template<typename TreeType>
+typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
+SweepNonLeafNode(const size_t axis,
+ const TreeType* node,
+ typename TreeType::ElemType& axisCut)
+{
+ typedef typename TreeType::ElemType ElemType;
+
+ std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
+
+ for (size_t i = 0; i < node->NumChildren(); i++)
+ {
- sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi();
++ sorted[i].d = SplitPolicy::Bound(node->Child(i))[axis].Hi();
+ sorted[i].n = i;
+ }
+ // Sort high bounds of children.
+ std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+ size_t splitPointer = fillFactor * node->NumChildren();
+
+ axisCut = sorted[splitPointer - 1].d;
+
+ // Check if the partition is suitable.
+ if (!CheckNonLeafSweep(node, axis, axisCut))
+ {
+ // Fall back to trying every cut position in ascending order.
+ for (splitPointer = 1; splitPointer < sorted.size(); splitPointer++)
+ {
+ axisCut = sorted[splitPointer - 1].d;
+ if (CheckNonLeafSweep(node, axis, axisCut))
+ break;
+ }
+
+ // The loop ran to completion: no cut was accepted.
+ if (splitPointer == node->NumChildren())
+ return std::numeric_limits<ElemType>::max();
+ }
+
+ std::vector<ElemType> lowerBound1(node->Bound().Dim());
+ std::vector<ElemType> highBound1(node->Bound().Dim());
+ std::vector<ElemType> lowerBound2(node->Bound().Dim());
+ std::vector<ElemType> highBound2(node->Bound().Dim());
+
+ // Find lower and high bounds of two resulting nodes.
+ for (size_t k = 0; k < node->Bound().Dim(); k++)
+ {
- lowerBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Lo();
- highBound1[k] = node->Children()[sorted[0].n]->Bound()[k].Hi();
++ lowerBound1[k] = node->Child(sorted[0].n).Bound()[k].Lo();
++ highBound1[k] = node->Child(sorted[0].n).Bound()[k].Hi();
+
+ for (size_t i = 1; i < splitPointer; i++)
+ {
- if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound1[k])
- lowerBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
- if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound1[k])
- highBound1[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
++ if (node->Child(sorted[i].n).Bound()[k].Lo() < lowerBound1[k])
++ lowerBound1[k] = node->Child(sorted[i].n).Bound()[k].Lo();
++ if (node->Child(sorted[i].n).Bound()[k].Hi() > highBound1[k])
++ highBound1[k] = node->Child(sorted[i].n).Bound()[k].Hi();
+ }
+
- lowerBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Lo();
- highBound2[k] = node->Children()[sorted[splitPointer].n]->Bound()[k].Hi();
++ lowerBound2[k] = node->Child(sorted[splitPointer].n).Bound()[k].Lo();
++ highBound2[k] = node->Child(sorted[splitPointer].n).Bound()[k].Hi();
+
+ for (size_t i = splitPointer + 1; i < node->NumChildren(); i++)
+ {
- if (node->Children()[sorted[i].n]->Bound()[k].Lo() < lowerBound2[k])
- lowerBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Lo();
- if (node->Children()[sorted[i].n]->Bound()[k].Hi() > highBound2[k])
- highBound2[k] = node->Children()[sorted[i].n]->Bound()[k].Hi();
++ if (node->Child(sorted[i].n).Bound()[k].Lo() < lowerBound2[k])
++ lowerBound2[k] = node->Child(sorted[i].n).Bound()[k].Lo();
++ if (node->Child(sorted[i].n).Bound()[k].Hi() > highBound2[k])
++ highBound2[k] = node->Child(sorted[i].n).Bound()[k].Hi();
+ }
+ }
+
+ // Evaluate the cost of the split i.e. calculate the total coverage
+ // of two resulting nodes.
+
+ ElemType area1 = 1.0, area2 = 1.0;
+ ElemType overlappedArea = 1.0;
+
+ for (size_t k = 0; k < node->Bound().Dim(); k++)
+ {
+ if (lowerBound1[k] >= highBound1[k])
+ {
+ overlappedArea *= 0;
+ area1 *= 0;
+ }
+ else
+ area1 *= highBound1[k] - lowerBound1[k];
+
+ if (lowerBound2[k] >= highBound2[k])
+ {
+ overlappedArea *= 0;
+ // The second bound is degenerate here, so it is area2 that vanishes.
+ area2 *= 0;
+ }
+ else
+ area2 *= highBound2[k] - lowerBound2[k];
+
+ if (lowerBound1[k] < highBound1[k] && lowerBound2[k] < highBound2[k])
+ {
+ // The intervals are disjoint when one starts after the other one ends.
+ if (lowerBound1[k] > highBound2[k] || lowerBound2[k] > highBound1[k])
+ overlappedArea *= 0;
+ else
+ overlappedArea *= std::min(highBound1[k], highBound2[k]) -
+ std::max(lowerBound1[k], lowerBound2[k]);
+ }
+ }
+
+ return area1 + area2 - overlappedArea;
+}
+
+/**
+ * Evaluate a sweep of a leaf node along the given axis.
+ *
+ * The points of the node are ordered by their coordinate along the axis. The
+ * first fillFactor * Count() points of that order form the first candidate
+ * node and the remaining points form the second one; the cut is the
+ * coordinate of the last point of the first group.
+ *
+ * @param axis The axis along which the node is swept.
+ * @param node The leaf node that is being split.
+ * @param axisCut Receives the coordinate at which the node is cut.
+ * @return The cost of the sweep, or the maximum value of ElemType if
+ * CheckLeafSweep() rejects the cut.
+ */
+template<typename SplitPolicy>
+template<typename TreeType>
+typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
+SweepLeafNode(const size_t axis,
+ const TreeType* node,
+ typename TreeType::ElemType& axisCut)
+{
+ typedef typename TreeType::ElemType ElemType;
+
+ // The vector is created with one entry per point; no resize is needed.
+ std::vector<SortStruct<ElemType>> sorted(node->Count());
+
+ for (size_t i = 0; i < node->NumPoints(); i++)
+ {
+ sorted[i].d = node->Dataset().col(node->Point(i))[axis];
+ sorted[i].n = i;
+ }
+
+ // Sort the points by their coordinate along the axis.
+ std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+ size_t splitPointer = fillFactor * node->Count();
+
+ axisCut = sorted[splitPointer - 1].d;
+
+ // Check if the partition is suitable.
+ if (!CheckLeafSweep(node, axis, axisCut))
+ return std::numeric_limits<ElemType>::max();
+
+ std::vector<ElemType> lowerBound1(node->Bound().Dim());
+ std::vector<ElemType> highBound1(node->Bound().Dim());
+ std::vector<ElemType> lowerBound2(node->Bound().Dim());
+ std::vector<ElemType> highBound2(node->Bound().Dim());
+
+ // Find lower and high bounds of two resulting nodes.
+ for (size_t k = 0; k < node->Bound().Dim(); k++)
+ {
+ lowerBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+ highBound1[k] = node->Dataset().col(node->Point(sorted[0].n))[k];
+
+ for (size_t i = 1; i < splitPointer; i++)
+ {
+ if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound1[k])
+ lowerBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+ if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound1[k])
+ highBound1[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+ }
+
+ lowerBound2[k] = node->Dataset().col(
+ node->Point(sorted[splitPointer].n))[k];
+ highBound2[k] = node->Dataset().col(node->Point(sorted[splitPointer].n))[k];
+
+ // Walk the remaining points of the leaf. The limit is the number of points,
+ // matching the loop that filled the sorted vector.
+ for (size_t i = splitPointer + 1; i < node->NumPoints(); i++)
+ {
+ if (node->Dataset().col(node->Point(sorted[i].n))[k] < lowerBound2[k])
+ lowerBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+ if (node->Dataset().col(node->Point(sorted[i].n))[k] > highBound2[k])
+ highBound2[k] = node->Dataset().col(node->Point(sorted[i].n))[k];
+ }
+ }
+
+ // Evaluate the cost of the split i.e. calculate the total coverage
+ // of two resulting nodes.
+
+ ElemType area1 = 1.0, area2 = 1.0;
+ // NOTE(review): overlappedArea is never updated below, so a constant 1.0 is
+ // subtracted from every cost -- confirm whether this is intended.
+ ElemType overlappedArea = 1.0;
+
+ for (size_t k = 0; k < node->Bound().Dim(); k++)
+ {
+ area1 *= highBound1[k] - lowerBound1[k];
+ area2 *= highBound2[k] - lowerBound2[k];
+ }
+
+ return area1 + area2 - overlappedArea;
+}
+
+template<typename SplitPolicy>
+template<typename TreeType, typename ElemType>
+bool MinimalCoverageSweep<SplitPolicy>::
+CheckNonLeafSweep(const TreeType* node,
+ const size_t cutAxis,
+ const ElemType cut)
+{
+ size_t numTreeOneChildren = 0;
+ size_t numTreeTwoChildren = 0;
+
+ // Calculate the number of children in the resulting nodes.
+ for (size_t i = 0; i < node->NumChildren(); i++)
+ {
- TreeType* child = node->Children()[i];
++ const TreeType& child = node->Child(i);
+ int policy = SplitPolicy::GetSplitPolicy(child, cutAxis, cut);
+ if (policy == SplitPolicy::AssignToFirstTree)
+ numTreeOneChildren++;
+ else if (policy == SplitPolicy::AssignToSecondTree)
+ numTreeTwoChildren++;
+ else
+ {
+ // The split is required.
+ numTreeOneChildren++;
+ numTreeTwoChildren++;
+ }
+ }
+
+ if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
+ numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
+ return true;
+ return false;
+}
+
+/**
+ * Check whether cutting a leaf node at the given coordinate is acceptable.
+ *
+ * @param node The leaf node that is being split.
+ * @param cutAxis The axis along which the node is cut.
+ * @param cut The coordinate at which the node is cut.
+ * @return True if both sides of the cut receive at least one point and at
+ * most MaxLeafSize() points.
+ */
+template<typename SplitPolicy>
+template<typename TreeType, typename ElemType>
+bool MinimalCoverageSweep<SplitPolicy>::
+CheckLeafSweep(const TreeType* node,
+ const size_t cutAxis,
+ const ElemType cut)
+{
+ // Count the points that land on each side of the cut; a point lying exactly
+ // on the cut belongs to the first side.
+ size_t firstCount = 0;
+ size_t secondCount = 0;
+
+ for (size_t p = 0; p < node->NumPoints(); p++)
+ {
+ const bool inFirst = node->Dataset().col(node->Point(p))[cutAxis] <= cut;
+ if (inFirst)
+ firstCount++;
+ else
+ secondCount++;
+ }
+
+ // Each side has to be non-empty and has to fit into a leaf.
+ const bool firstFits = firstCount <= node->MaxLeafSize() && firstCount > 0;
+ const bool secondFits = secondCount <= node->MaxLeafSize() && secondCount > 0;
+
+ return firstFits && secondFits;
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
+
diff --cc src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp
index 7f63568,0000000..8bcc0a6
mode 100644,000000..100644
--- a/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/minimal_splits_number_sweep_impl.hpp
@@@ -1,107 -1,0 +1,107 @@@
+/**
+ * @file minimal_splits_number_sweep_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of the MinimalSplitsNumberSweep class, a class that finds a
+ * partition of a node along an axis.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+
+#include "minimal_splits_number_sweep.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitPolicy>
+template<typename TreeType>
+size_t MinimalSplitsNumberSweep<SplitPolicy>::SweepNonLeafNode(
+ const size_t axis,
+ const TreeType* node,
+ typename TreeType::ElemType& axisCut)
+{
+ typedef typename TreeType::ElemType ElemType;
+
+ std::vector<SortStruct<ElemType>> sorted(node->NumChildren());
+
+ for (size_t i = 0; i < node->NumChildren(); i++)
+ {
- sorted[i].d = SplitPolicy::Bound(node->Children()[i])[axis].Hi();
++ sorted[i].d = SplitPolicy::Bound(node->Child(i))[axis].Hi();
+ sorted[i].n = i;
+ }
+
+ // Sort candidates in order to check balancing.
+ std::sort(sorted.begin(), sorted.end(), StructComp<ElemType>);
+
+ size_t minCost = SIZE_MAX;
+
+ // Find a split with the minimal cost.
+ for (size_t i = 0; i < sorted.size(); i++)
+ {
+ size_t numTreeOneChildren = 0;
+ size_t numTreeTwoChildren = 0;
+ size_t numSplits = 0;
+
+ // Calculate the number of splits.
+ for (size_t j = 0; j < node->NumChildren(); j++)
+ {
- TreeType* child = node->Children()[j];
++ const TreeType& child = node->Child(j);
+ int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].d);
+ if (policy == SplitPolicy::AssignToFirstTree)
+ numTreeOneChildren++;
+ else if (policy == SplitPolicy::AssignToSecondTree)
+ numTreeTwoChildren++;
+ else
+ {
+ numTreeOneChildren++;
+ numTreeTwoChildren++;
+ numSplits++;
+ }
+ }
+
+ // Check if the split is possible.
+ if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
+ numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
+ {
+ // Evaluate the cost using the number of splits and balancing.
+ size_t balance;
+
+ if (sorted.size() / 2 > i )
+ balance = sorted.size() / 2 - i;
+ else
+ balance = i - sorted.size() / 2;
+
+ size_t cost = numSplits * balance;
+ if (cost < minCost)
+ {
+ minCost = cost;
+ axisCut = sorted[i].d;
+ }
+ }
+ }
+ return minCost;
+}
+
+template<typename SplitPolicy>
+template<typename TreeType>
+size_t MinimalSplitsNumberSweep<SplitPolicy>::SweepLeafNode(
+ const size_t axis,
+ const TreeType* node,
+ typename TreeType::ElemType& axisCut)
+{
+ // Split at the midpoint of the node's bound along the axis.
+ axisCut = (node->Bound()[axis].Lo() + node->Bound()[axis].Hi()) * 0.5;
+
+ if (node->Bound()[axis].Lo() == axisCut)
+ return SIZE_MAX;
+
+ return 0;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
+
+
diff --cc src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
index c2a5456,0000000..eca2d0f
mode 100644,000000..100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_descent_heuristic_impl.hpp
@@@ -1,49 -1,0 +1,49 @@@
+/**
+ * @file r_plus_plus_tree_descent_heuristic_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of RPlusPlusTreeDescentHeuristic, a class that chooses the
+ * best child of a node in an R++ tree when inserting a new point.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+
+#include "r_plus_plus_tree_descent_heuristic.hpp"
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename TreeType>
+size_t RPlusPlusTreeDescentHeuristic::ChooseDescentNode(
+ TreeType* node, const size_t point)
+{
+ // Find the node whose maximum bounding rectangle contains the point.
+ for (size_t bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+ {
- if (node->Children()[bestIndex]->AuxiliaryInfo().OuterBound().Contains(
++ if (node->Child(bestIndex).AuxiliaryInfo().OuterBound().Contains(
+ node->Dataset().col(point)))
+ return bestIndex;
+ }
+
+ // We should never reach this point.
+ assert(false);
+
+ return 0;
+}
+
+template<typename TreeType>
+size_t RPlusPlusTreeDescentHeuristic::ChooseDescentNode(
+ const TreeType* /* node */, const TreeType* /* insertedNode */)
+{
+ // Should never be used.
+ assert(false);
+
+ return 0;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
diff --cc src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp
index f4c6e3e,0000000..d729153
mode 100644,000000..100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_plus_tree_split_policy.hpp
@@@ -1,75 -1,0 +1,75 @@@
+/**
+ * @file r_plus_plus_tree_split_policy.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition and implementation of the RPlusPlusTreeSplitPolicy class, a class
+ * that helps to determine the subtree into which we should insert an
+ * intermediate node.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The RPlusPlusTreeSplitPolicy helps to determine the subtree into which
+ * we should insert a child of an intermediate node that is being split.
+ * This class is designed for the R++ tree.
+ */
+class RPlusPlusTreeSplitPolicy
+{
+ public:
+ //! Indicate that the child should be split.
+ static const int SplitRequired = 0;
+ //! Indicate that the child should be inserted to the first subtree.
+ static const int AssignToFirstTree = 1;
+ //! Indicate that the child should be inserted to the second subtree.
+ static const int AssignToSecondTree = 2;
+
+ /**
+ * This method returns SplitRequired if a child of an intermediate node should
+ * be split, AssignToFirstTree if the child should be inserted to the first
+ * subtree, AssignToSecondTree if the child should be inserted to the second
+ * subtree. The method makes its decision according to the maximum bounding
+ * rectangle of the child, the axis along which the intermediate node is being
+ * split and the coordinate at which the node is being split.
+ *
+ * @param child A child of the node that is being split.
+ * @param axis The axis along which the node is being split.
+ * @param cut The coordinate at which the node is being split.
+ */
+ template<typename TreeType>
- static int GetSplitPolicy(const TreeType* child,
++ static int GetSplitPolicy(const TreeType& child,
+ const size_t axis,
+ const typename TreeType::ElemType cut)
+ {
- if (child->AuxiliaryInfo().OuterBound()[axis].Hi() <= cut)
++ if (child.AuxiliaryInfo().OuterBound()[axis].Hi() <= cut)
+ return AssignToFirstTree;
- else if (child->AuxiliaryInfo().OuterBound()[axis].Lo() >= cut)
++ else if (child.AuxiliaryInfo().OuterBound()[axis].Lo() >= cut)
+ return AssignToSecondTree;
+
+ return SplitRequired;
+ }
+
+ /**
+ * Return the maximum bounding rectangle of the node.
+ * This method should always return the bound that is used for the
+ * decision-making in GetSplitPolicy().
+ *
+ * @param node The node whose bound is requested.
+ */
+ template<typename TreeType>
+ static const
+ bound::HRectBound<metric::EuclideanDistance, typename TreeType::ElemType>&
- Bound(const TreeType* node)
++ Bound(const TreeType& node)
+ {
- return node->AuxiliaryInfo().OuterBound();
++ return node.AuxiliaryInfo().OuterBound();
+ }
+};
+
+} // namespace tree
+} // namespace mlpack
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_PLUS_TREE_SPLIT_POLICY_HPP
+
+
diff --cc src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
index afe3879,0000000..77312ea
mode 100644,000000..100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_descent_heuristic_impl.hpp
@@@ -1,104 -1,0 +1,104 @@@
+/**
+ * @file r_plus_tree_descent_heuristic_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of RPlusTreeDescentHeuristic, a class that chooses the best
+ * child of a node in an R+ tree when inserting a new point.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
+
+#include "r_plus_tree_descent_heuristic.hpp"
+#include "../hrectbound.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename TreeType>
+size_t RPlusTreeDescentHeuristic::
+ChooseDescentNode(TreeType* node, const size_t point)
+{
+ typedef typename TreeType::ElemType ElemType;
+ size_t bestIndex = 0;
+ bool success;
+
+ // Try to find a node that contains the point.
+ for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+ {
- if (node->Children()[bestIndex]->Bound().Contains(
++ if (node->Child(bestIndex).Bound().Contains(
+ node->Dataset().col(point)))
+ return bestIndex;
+ }
+
+ // No node contains the point. Try to enlarge a node in such a way that
+ // the resulting node does not overlap other nodes.
+ for (bestIndex = 0; bestIndex < node->NumChildren(); bestIndex++)
+ {
+ bound::HRectBound<metric::EuclideanDistance, ElemType> bound =
- node->Children()[bestIndex]->Bound();
++ node->Child(bestIndex).Bound();
+ bound |= node->Dataset().col(point);
+
+ success = true;
+
+ for (size_t j = 0; j < node->NumChildren(); j++)
+ {
+ if (j == bestIndex)
+ continue;
+ success = false;
+ // Two nodes overlap if and only if there is no dimension in which
+ // they do not overlap each other.
+ for (size_t k = 0; k < node->Bound().Dim(); k++)
+ {
- if (bound[k].Lo() >= node->Children()[j]->Bound()[k].Hi() ||
- node->Children()[j]->Bound()[k].Lo() >= bound[k].Hi())
++ if (bound[k].Lo() >= node->Child(j).Bound()[k].Hi() ||
++ node->Child(j).Bound()[k].Lo() >= bound[k].Hi())
+ {
+ // We found the dimension in which these nodes do not overlap
+ // each other.
+ success = true;
+ break;
+ }
+ }
+ if (!success) // These two nodes overlap each other.
+ break;
+ }
+ if (success) // The enlarged node does not overlap any other node.
+ break;
+ }
+
+ if (!success) // No node can be enlarged without overlapping another one.
+ {
+ size_t depth = node->TreeDepth();
+
+ // Create a new node into which we will insert the point.
+ TreeType* tree = node;
+ while (depth > 1)
+ {
+ TreeType* child = new TreeType(tree);
+
- tree->Children()[tree->NumChildren()++] = child;
++ tree->children[tree->NumChildren()++] = child;
+ tree = child;
+ depth--;
+ }
+ return node->NumChildren()-1;
+ }
+
+ assert(bestIndex < node->NumChildren());
+
+ return bestIndex;
+}
+
+template<typename TreeType>
+size_t RPlusTreeDescentHeuristic::ChooseDescentNode(
+ const TreeType* /* node */, const TreeType* /*insertedNode */)
+{
+ // Should never be used.
+ assert(false);
+
+ return 0;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif //MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_DESCENT_HEURISTIC_IMPL_HPP
diff --cc src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
index 2cca91c,0000000..9fefbbc
mode 100644,000000..100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_impl.hpp
@@@ -1,343 -1,0 +1,342 @@@
+/**
+ * @file r_plus_tree_split_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of class (RPlusTreeSplit) to split a RectangleTree.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
+
+#include "r_plus_tree_split.hpp"
+#include "rectangle_tree.hpp"
+#include "r_plus_plus_tree_auxiliary_information.hpp"
+#include "r_plus_tree_split_policy.hpp"
+#include "r_plus_plus_tree_split_policy.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
+template<typename TreeType>
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+{
+ if (tree->Count() == 1)
+ {
+ // Check if an intermediate node was added during the insertion process.
+ // i.e. we couldn't enlarge a node of the R+ tree. So, one of intermediate
+ // nodes may be overflowed.
+ TreeType* node = tree->Parent();
+
+ while (node != NULL)
+ {
+ if (node->NumChildren() == node->MaxNumChildren() + 1)
+ {
+ // Split the overflowed node.
+ RPlusTreeSplit::SplitNonLeafNode(node,relevels);
+ return;
+ }
+ node = node->Parent();
+ }
+ return;
+ }
+ else if (tree->Count() <= tree->MaxLeafSize())
+ return;
+ // If we are splitting the root node, we need to do things differently so
+ // that the constructor and other methods don't confuse the end user by giving
+ // an address of another node.
+ if (tree->Parent() == NULL)
+ {
+ // We actually want to copy this way. Pointers and everything.
+ TreeType* copy = new TreeType(*tree, false);
+ copy->Parent() = tree;
+ tree->Count() = 0;
+ tree->NullifyData();
+ // Because this was a leaf node, numChildren must be 0.
- tree->Children()[(tree->NumChildren())++] = copy;
++ tree->children[(tree->NumChildren())++] = copy;
+ assert(tree->NumChildren() == 1);
+
+ RPlusTreeSplit::SplitLeafNode(copy,relevels);
+ return;
+ }
+
+ size_t cutAxis;
+ typename TreeType::ElemType cut;
+
+ // Try to find a partition of the node.
+ if ( !PartitionNode(tree, cutAxis, cut))
+ return;
+
+ assert(cutAxis < tree->Bound().Dim());
+
+ TreeType* treeOne = new TreeType(tree->Parent());
+ TreeType* treeTwo = new TreeType(tree->Parent());
+ treeOne->MinLeafSize() = 0;
+ treeOne->MinNumChildren() = 0;
+ treeTwo->MinLeafSize() = 0;
+ treeTwo->MinNumChildren() = 0;
+
+ // Split the node into two new nodes.
+ SplitLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut);
+
+ TreeType* parent = tree->Parent();
+ size_t i = 0;
- while (parent->Children()[i] != tree)
++ while (parent->children[i] != tree)
+ i++;
+
+ assert(i < parent->NumChildren());
+
- // Remove the node from the tree.
- parent->Children()[i] = parent->Children()[--parent->NumChildren()];
-
+ // Insert two new nodes to the tree.
- InsertNodeIntoTree(parent, treeOne);
- InsertNodeIntoTree(parent, treeTwo);
++ parent->children[i] = treeOne;
++ parent->children[parent->NumChildren()++] = treeTwo;
+
+ assert(parent->NumChildren() <= parent->MaxNumChildren() + 1);
+
+ // Propagate the split upward if necessary.
+ if (parent->NumChildren() == parent->MaxNumChildren() + 1)
+ RPlusTreeSplit::SplitNonLeafNode(parent, relevels);
+
+ tree->SoftDelete();
+}
+
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
+template<typename TreeType>
+bool RPlusTreeSplit<SplitPolicyType, SweepType>::
+SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
+{
+ // If we are splitting the root node, we need to do things differently so
+ // that the constructor and other methods don't confuse the end user by giving
+ // an address of another node.
+ if (tree->Parent() == NULL)
+ {
+ // We actually want to copy this way. Pointers and everything.
+ TreeType* copy = new TreeType(*tree, false);
+
+ copy->Parent() = tree;
+ tree->NumChildren() = 0;
+ tree->NullifyData();
- tree->Children()[(tree->NumChildren())++] = copy;
++ tree->children[(tree->NumChildren())++] = copy;
+
+ RPlusTreeSplit::SplitNonLeafNode(copy,relevels);
+ return true;
+ }
+ size_t cutAxis;
+ typename TreeType::ElemType cut;
+
+ // Try to find a partition of the node.
+ if ( !PartitionNode(tree, cutAxis, cut))
+ return false;
+
+ assert(cutAxis < tree->Bound().Dim());
+
+ TreeType* treeOne = new TreeType(tree->Parent());
+ TreeType* treeTwo = new TreeType(tree->Parent());
+ treeOne->MinLeafSize() = 0;
+ treeOne->MinNumChildren() = 0;
+ treeTwo->MinLeafSize() = 0;
+ treeTwo->MinNumChildren() = 0;
+
+ // Split the node into two new nodes.
+ SplitNonLeafNodeAlongPartition(tree, treeOne, treeTwo, cutAxis, cut);
+
+ TreeType* parent = tree->Parent();
+ size_t i = 0;
- while (parent->Children()[i] != tree)
++ while (parent->children[i] != tree)
+ i++;
+
+ assert(i < parent->NumChildren());
+
- // Remove the node from the tree.
- parent->Children()[i] = parent->Children()[--parent->NumChildren()];
-
+ // Insert two new nodes to the tree.
- InsertNodeIntoTree(parent, treeOne);
- InsertNodeIntoTree(parent, treeTwo);
++ parent->children[i] = treeOne;
++ parent->children[parent->NumChildren()++] = treeTwo;
+
+ tree->SoftDelete();
+
+ assert(parent->NumChildren() <= parent->MaxNumChildren() + 1);
+
+ // Propagate the split upward if necessary.
+ if (parent->NumChildren() == parent->MaxNumChildren() + 1)
+ RPlusTreeSplit::SplitNonLeafNode(parent, relevels);
+
+ return false;
+}
+
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
+template<typename TreeType>
+void RPlusTreeSplit<SplitPolicyType, SweepType>::SplitLeafNodeAlongPartition(
+ TreeType* tree,
+ TreeType* treeOne,
+ TreeType* treeTwo,
+ const size_t cutAxis,
+ const typename TreeType::ElemType cut)
+{
+ // Split the auxiliary information.
+ tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
+
+ // Insert points into the corresponding subtree.
+ for (size_t i = 0; i < tree->NumPoints(); i++)
+ {
+ if (tree->Dataset().col(tree->Point(i))[cutAxis] <= cut)
+ {
+ treeOne->Point(treeOne->Count()++) = tree->Point(i);
+ treeOne->Bound() |= tree->Dataset().col(tree->Point(i));
+ }
+ else
+ {
+ treeTwo->Point(treeTwo->Count()++) = tree->Point(i);
+ treeTwo->Bound() |= tree->Dataset().col(tree->Point(i));
+ }
+ }
++ // Update the number of descendants.
++ treeOne->numDescendants = treeOne->Count();
++ treeTwo->numDescendants = treeTwo->Count();
++
+ assert(treeOne->Count() <= treeOne->MaxLeafSize());
+ assert(treeTwo->Count() <= treeTwo->MaxLeafSize());
+
+ assert(tree->Count() == treeOne->Count() + treeTwo->Count());
+ assert(treeOne->Bound()[cutAxis].Hi() < treeTwo->Bound()[cutAxis].Lo());
+}
+
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
+template<typename TreeType>
+void RPlusTreeSplit<SplitPolicyType, SweepType>::SplitNonLeafNodeAlongPartition(
+ TreeType* tree,
+ TreeType* treeOne,
+ TreeType* treeTwo,
+ const size_t cutAxis,
+ const typename TreeType::ElemType cut)
+{
+ // Split the auxiliary information.
+ tree->AuxiliaryInfo().SplitAuxiliaryInfo(treeOne, treeTwo, cutAxis, cut);
+
+ // Insert children into the corresponding subtree.
+ for (size_t i = 0; i < tree->NumChildren(); i++)
+ {
- TreeType* child = tree->Children()[i];
- int policy = SplitPolicyType::GetSplitPolicy(child, cutAxis, cut);
++ TreeType* child = tree->children[i];
++ int policy = SplitPolicyType::GetSplitPolicy(*child, cutAxis, cut);
+
+ if (policy == SplitPolicyType::AssignToFirstTree)
+ {
+ InsertNodeIntoTree(treeOne, child);
+ child->Parent() = treeOne;
+ }
+ else if (policy == SplitPolicyType::AssignToSecondTree)
+ {
+ InsertNodeIntoTree(treeTwo, child);
+ child->Parent() = treeTwo;
+ }
+ else
+ {
+ // The child should be split (i.e. the partition divides its bound).
+ TreeType* childOne = new TreeType(treeOne);
+ TreeType* childTwo = new TreeType(treeTwo);
+ treeOne->MinLeafSize() = 0;
+ treeOne->MinNumChildren() = 0;
+ treeTwo->MinLeafSize() = 0;
+ treeTwo->MinNumChildren() = 0;
+
+ // Propagate the split downward.
+ if (child->IsLeaf())
+ SplitLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut);
+ else
+ SplitNonLeafNodeAlongPartition(child, childOne, childTwo, cutAxis, cut);
+
+ InsertNodeIntoTree(treeOne, childOne);
+ InsertNodeIntoTree(treeTwo, childTwo);
+
+ child->SoftDelete();
+ }
+ }
+
+ assert(treeOne->NumChildren() + treeTwo->NumChildren() != 0);
+
+ // Add a fake subtree if one of the subtrees is empty.
+ if (treeOne->NumChildren() == 0)
+ AddFakeNodes(treeTwo, treeOne);
+ else if (treeTwo->NumChildren() == 0)
+ AddFakeNodes(treeOne, treeTwo);
+
+ assert(treeOne->NumChildren() <= treeOne->MaxNumChildren());
+ assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
+}
+
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
+template<typename TreeType>
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+AddFakeNodes(const TreeType* tree, TreeType* emptyTree)
+{
+ size_t numDescendantNodes = tree->TreeDepth() - 1;
+
+ TreeType* node = emptyTree;
+ for (size_t i = 0; i < numDescendantNodes; i++)
+ {
+ TreeType* child = new TreeType(node);
- node->Children()[node->NumChildren()++] = child;
++ node->children[node->NumChildren()++] = child;
+
+ node = child;
+ }
+}
+
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
+template<typename TreeType>
+bool RPlusTreeSplit<SplitPolicyType, SweepType>::
+PartitionNode(const TreeType* node, size_t& minCutAxis,
+ typename TreeType::ElemType& minCut)
+{
+ if ((node->NumChildren() <= fillFactor && !node->IsLeaf()) ||
+ (node->Count() <= fillFactor && node->IsLeaf()))
+ return false; // No partition required.
+
+ // Define the type of the sweep cost.
+ typedef typename
+ SweepType<SplitPolicyType>::template SweepCost<TreeType>::type
+ SweepCostType;
+
+ SweepCostType minCost = std::numeric_limits<SweepCostType>::max();
+ minCutAxis = node->Bound().Dim();
+
+ // Find the sweep with a minimal cost.
+ for (size_t k = 0; k < node->Bound().Dim(); k++)
+ {
+ typename TreeType::ElemType cut;
+ SweepCostType cost;
+
+ if (node->IsLeaf())
+ cost = SweepType<SplitPolicyType>::SweepLeafNode(k, node, cut);
+ else
+ cost = SweepType<SplitPolicyType>::SweepNonLeafNode(k, node, cut);
+
+
+ if (cost < minCost)
+ {
+ minCost = cost;
+ minCutAxis = k;
+ minCut = cut;
+ }
+ }
+ return true;
+}
+
+template<typename SplitPolicyType,
+ template<typename> class SweepType>
+template<typename TreeType>
+void RPlusTreeSplit<SplitPolicyType, SweepType>::
+InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+{
+ destTree->Bound() |= srcNode->Bound();
- destTree->Children()[destTree->NumChildren()++] = srcNode;
++ destTree->numDescendants += srcNode->numDescendants;
++ destTree->children[destTree->NumChildren()++] = srcNode;
+}
+
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_IMPL_HPP
diff --cc src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp
index 6d17338,0000000..1302dd3
mode 100644,000000..100644
--- a/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_plus_tree_split_policy.hpp
@@@ -1,75 -1,0 +1,75 @@@
+/**
+ * @file r_plus_tree_split_policy.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition and implementation of the RPlusTreeSplitPolicy class, a class that
+ * helps to determine the subtree into which we should insert an intermediate
+ * node.
+ */
+#ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+#define MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * The RPlusTreeSplitPolicy helps to determine the subtree into which
+ * we should insert a child of an intermediate node that is being split.
+ * This class is designed for the R+ tree.
+ */
+class RPlusTreeSplitPolicy
+{
+ public:
+ //! Indicate that the child should be split.
+ static const int SplitRequired = 0;
+ //! Indicate that the child should be inserted to the first subtree.
+ static const int AssignToFirstTree = 1;
+ //! Indicate that the child should be inserted to the second subtree.
+ static const int AssignToSecondTree = 2;
+
+ /**
+ * This method returns SplitRequired if a child of an intermediate node should
+ * be split, AssignToFirstTree if the child should be inserted to the first
+ * subtree, AssignToSecondTree if the child should be inserted to the second
+ * subtree. The method makes its decision according to the minimum bounding
+ * rectangle of the child, the axis along which the intermediate node is being
+ * split and the coordinate at which the node is being split.
+ *
+ * @param child A child of the node that is being split.
+ * @param axis The axis along which the node is being split.
+ * @param cut The coordinate at which the node is being split.
+ */
+ template<typename TreeType>
- static int GetSplitPolicy(const TreeType* child,
++ static int GetSplitPolicy(const TreeType& child,
+ const size_t axis,
+ const typename TreeType::ElemType cut)
+ {
- if (child->Bound()[axis].Hi() <= cut)
++ if (child.Bound()[axis].Hi() <= cut)
+ return AssignToFirstTree;
- else if (child->Bound()[axis].Lo() >= cut)
++ else if (child.Bound()[axis].Lo() >= cut)
+ return AssignToSecondTree;
+
+ return SplitRequired;
+ }
+
+ /**
+ * Return the minimum bounding rectangle of the node.
+ * This method should always return the bound that is used for the
+ * decision-making in GetSplitPolicy().
+ *
+ * @param node The node whose bound is requested.
+ */
+ template<typename TreeType>
+ static const
+ bound::HRectBound<metric::EuclideanDistance, typename TreeType::ElemType>&
- Bound(const TreeType* node)
++ Bound(const TreeType& node)
+ {
- return node->Bound();
++ return node.Bound();
+ }
+};
+
+} // namespace tree
+} // namespace mlpack
+#endif // MLPACK_CORE_TREE_RECTANGLE_TREE_R_PLUS_TREE_SPLIT_POLICY_HPP
+
+
diff --cc src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 37f9572,9590aa5..bbdebda
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@@ -514,6 -512,12 +512,15 @@@ class RectangleTre
//! Friend access is given for the default constructor.
friend class boost::serialization::access;
++ //! Give friend access for DescentType.
++ friend DescentType;
++
+ //! Give friend access for SplitType.
+ friend SplitType;
+
+ //! Give friend access for AuxiliaryInformationType.
+ friend AuxiliaryInformation;
+
public:
/**
* Condense the bounding rectangles for this node based on the removal of the
diff --cc src/mlpack/tests/rectangle_tree_test.cpp
index 5452d71,e6aedd9..d331bb3
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@@ -137,19 -137,9 +137,19 @@@ void CheckContainment(const TreeType& t
for (size_t i = 0; i < tree.NumChildren(); i++)
{
for (size_t j = 0; j < tree.Bound().Dim(); j++)
- BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Child(i).Bound()[j]));
+ {
+ // All children should be covered by the parent node.
+ // Some children can be empty (only in case of the R++ tree)
- bool success = (tree.Children()[i]->Bound()[j].Hi() ==
++ bool success = (tree.Child(i).Bound()[j].Hi() ==
+ std::numeric_limits<typename TreeType::ElemType>::lowest() &&
- tree.Children()[i]->Bound()[j].Lo() ==
++ tree.Child(i).Bound()[j].Lo() ==
+ std::numeric_limits<typename TreeType::ElemType>::max()) ||
- tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]);
++ tree.Bound()[j].Contains(tree.Child(i).Bound()[j]);
+
+ BOOST_REQUIRE(success);
+ }
- CheckContainment(*(tree.Children()[i]));
+ CheckContainment(tree.Child(i));
}
}
}
@@@ -837,242 -854,6 +864,244 @@@ BOOST_AUTO_TEST_CASE(DiscreteHilbertVal
point4), -1);
}
+template<typename TreeType>
+void CheckOverlap(const TreeType& tree)
+{
+ bool success = true;
+
+ // Check if two nodes overlap each other.
+ for (size_t i = 0; i < tree.NumChildren(); i++)
+ {
+ success = true;
+
+ for (size_t j = 0; j < tree.NumChildren(); j++)
+ {
+ if (j == i)
+ continue;
+ success = false;
+ // Two nodes overlap each other if and only if there is no dimension
+ // in which they do not overlap each other.
+ for (size_t k = 0; k < tree.Bound().Dim(); k++)
+ {
- if ((tree.Children()[i]->Bound()[k].Lo() >=
- tree.Children()[j]->Bound()[k].Hi()) ||
- (tree.Children()[j]->Bound()[k].Lo() >=
- tree.Children()[i]->Bound()[k].Hi()))
++ if ((tree.Child(i).Bound()[k].Lo() >=
++ tree.Child(j).Bound()[k].Hi()) ||
++ (tree.Child(j).Bound()[k].Lo() >=
++ tree.Child(i).Bound()[k].Hi()))
+ {
+ success = true;
+ break;
+ }
+ }
+ if (!success)
+ break;
+ }
+ if (!success)
+ break;
+ }
+ BOOST_REQUIRE_EQUAL(success, true);
+
+ for (size_t i = 0; i < tree.NumChildren(); i++)
+ CheckOverlap(tree.Child(i));
+}
+
+BOOST_AUTO_TEST_CASE(RPlusTreeOverlapTest)
+{
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ typedef RPlusTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>,arma::mat> TreeType;
+ TreeType rPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ CheckOverlap(rPlusTree);
+
+ // Ensure that all leaf nodes are at the same level.
+ BOOST_REQUIRE_EQUAL(GetMinLevel(rPlusTree), GetMaxLevel(rPlusTree));
+ BOOST_REQUIRE_EQUAL(rPlusTree.TreeDepth(), GetMinLevel(rPlusTree));
+}
+
+
+BOOST_AUTO_TEST_CASE(RPlusTreeTraverserTest)
+{
+ arma::mat dataset;
+
+ const int numP = 1000;
+
+ dataset.randu(8, numP); // 1000 points in 8 dimensions.
+ arma::Mat<size_t> neighbors1;
+ arma::mat distances1;
+ arma::Mat<size_t> neighbors2;
+ arma::mat distances2;
+
+ typedef RPlusTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> TreeType;
+ TreeType rPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ // Nearest neighbor search with the R+ tree.
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, arma::mat,
+ RPlusTree > knn1(&rPlusTree, true);
+
+ BOOST_REQUIRE_EQUAL(rPlusTree.NumDescendants(), numP);
+
+ CheckContainment(rPlusTree);
+ CheckExactContainment(rPlusTree);
+ CheckHierarchy(rPlusTree);
+ CheckOverlap(rPlusTree);
++ CheckNumDescendants(rPlusTree);
+
+ knn1.Search(5, neighbors1, distances1);
+
+ // Nearest neighbor search the naive way.
+ KNN knn2(dataset, true, true);
+
+ knn2.Search(5, neighbors2, distances2);
+
+ for (size_t i = 0; i < neighbors1.size(); i++)
+ {
+ BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+ BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+ }
+}
+
+template<typename TreeType>
+void CheckRPlusPlusTreeBound(const TreeType& tree)
+{
+ typedef bound::HRectBound<metric::EuclideanDistance,
+ typename TreeType::ElemType> Bound;
+
+ bool success = true;
+
+ // Ensure that the outer (maximum) bounding rectangle contains the node's bound.
+ for (size_t k = 0; k < tree.Bound().Dim(); k++)
+ {
+ BOOST_REQUIRE_LE(tree.Bound()[k].Hi(),
+ tree.AuxiliaryInfo().OuterBound()[k].Hi());
+ BOOST_REQUIRE_LE(tree.AuxiliaryInfo().OuterBound()[k].Lo(),
+ tree.Bound()[k].Lo());
+ }
+
+ if (tree.IsLeaf())
+ {
+ // Ensure that the node's bound contains all points held in the leaf.
+ for (size_t i = 0; i < tree.Count(); i++)
+ BOOST_REQUIRE_EQUAL(true,
+ tree.Bound().Contains(tree.Dataset().col(tree.Point(i))));
+
+ return;
+ }
+
+ // Ensure that two children's maximum bounding rectangles do not overlap
+ // each other.
+ for (size_t i = 0; i < tree.NumChildren(); i++)
+ {
- const Bound& bound1 = tree.Children()[i]->AuxiliaryInfo().OuterBound();
++ const Bound& bound1 = tree.Child(i).AuxiliaryInfo().OuterBound();
+ success = true;
+
+ for (size_t j = 0; j < tree.NumChildren(); j++)
+ {
+ if (j == i)
+ continue;
- const Bound& bound2 = tree.Children()[j]->AuxiliaryInfo().OuterBound();
++ const Bound& bound2 = tree.Child(j).AuxiliaryInfo().OuterBound();
+
+ // Two bounds overlap each other if and only if there is no dimension
+ // in which they do not overlap each other.
+ success = false;
+ for (size_t k = 0; k < tree.Bound().Dim(); k++)
+ {
+ if (bound1[k].Lo() >= bound2[k].Hi() ||
+ bound2[k].Lo() >= bound1[k].Hi())
+ {
+ success = true;
+ break;
+ }
+ }
+ if (!success)
+ break;
+ }
+ if (!success)
+ break;
+ }
+ BOOST_REQUIRE_EQUAL(success, true);
+
+ for (size_t i = 0; i < tree.NumChildren(); i++)
+ CheckRPlusPlusTreeBound(tree.Child(i));
+}
+
+BOOST_AUTO_TEST_CASE(RPlusPlusTreeBoundTest)
+{
+ arma::mat dataset;
+ dataset.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ // Check the MinimalCoverageSweep.
+ typedef RPlusPlusTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>,arma::mat> TreeType;
+ TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ CheckRPlusPlusTreeBound(rPlusPlusTree);
+
+ BOOST_REQUIRE_EQUAL(GetMinLevel(rPlusPlusTree), GetMaxLevel(rPlusPlusTree));
+ BOOST_REQUIRE_EQUAL(rPlusPlusTree.TreeDepth(), GetMinLevel(rPlusPlusTree));
+
+ // Check the MinimalSplitsNumberSweep.
+ typedef RectangleTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>, arma::mat,
+ RPlusTreeSplit<RPlusPlusTreeSplitPolicy, MinimalSplitsNumberSweep>,
+ RPlusPlusTreeDescentHeuristic, RPlusPlusTreeAuxiliaryInformation>
+ RPlusPlusTreeMinimalSplits;
+
+ RPlusPlusTreeMinimalSplits rPlusPlusTree2(dataset, 20, 6, 5, 2, 0);
+
+ CheckRPlusPlusTreeBound(rPlusPlusTree2);
+
+ BOOST_REQUIRE_EQUAL(GetMinLevel(rPlusPlusTree2), GetMaxLevel(rPlusPlusTree2));
+ BOOST_REQUIRE_EQUAL(rPlusPlusTree2.TreeDepth(), GetMinLevel(rPlusPlusTree2));
+}
+
+BOOST_AUTO_TEST_CASE(RPlusPlusTreeTraverserTest)
+{
+ arma::mat dataset;
+
+ const int numP = 1000;
+
+ dataset.randu(8, numP); // 1000 points in 8 dimensions.
+ arma::Mat<size_t> neighbors1;
+ arma::mat distances1;
+ arma::Mat<size_t> neighbors2;
+ arma::mat distances2;
+
+ typedef RPlusPlusTree<EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>, arma::mat> TreeType;
+ TreeType rPlusPlusTree(dataset, 20, 6, 5, 2, 0);
+
+ // Nearest neighbor search with the R++ tree.
+
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ arma::mat, RPlusPlusTree > knn1(&rPlusPlusTree, true);
+
+ BOOST_REQUIRE_EQUAL(rPlusPlusTree.NumDescendants(), numP);
+
+ CheckContainment(rPlusPlusTree);
+ CheckExactContainment(rPlusPlusTree);
+ CheckHierarchy(rPlusPlusTree);
+ CheckRPlusPlusTreeBound(rPlusPlusTree);
++ CheckNumDescendants(rPlusPlusTree);
+
+ knn1.Search(5, neighbors1, distances1);
+
+ // Nearest neighbor search the naive way.
+ KNN knn2(dataset, true, true);
+
+ knn2.Search(5, neighbors2, distances2);
+
+ for (size_t i = 0; i < neighbors1.size(); i++)
+ {
+ BOOST_REQUIRE_EQUAL(neighbors1[i], neighbors2[i]);
+ BOOST_REQUIRE_EQUAL(distances1[i], distances2[i]);
+ }
+}
+
+
// Test the tree splitting. We set MaxLeafSize and MaxNumChildren rather low
// to allow us to test by hand without adding hundreds of points.
BOOST_AUTO_TEST_CASE(RTreeSplitTest)
More information about the mlpack-git
mailing list