[mlpack-git] master: RectangleTree:NumDescendants() optimization (fbded36)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 29 07:10:49 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eea2aa43b9b914b4d64f45bdf1f5358faefe2522...809ed4bf33cef9de8412fc167cb0e356a369e3b6
>---------------------------------------------------------------
commit fbded362ff8fb33b35faf4fafbbdb7326334bff2
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Wed Jun 29 14:10:49 2016 +0300
RectangleTree:NumDescendants() optimization
>---------------------------------------------------------------
fbded362ff8fb33b35faf4fafbbdb7326334bff2
.../rectangle_tree/hilbert_r_tree_split_impl.hpp | 5 ++
.../tree/rectangle_tree/r_star_tree_split_impl.hpp | 1 +
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 1 +
.../core/tree/rectangle_tree/rectangle_tree.hpp | 2 +
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 54 +++++++++++++++++-----
.../core/tree/rectangle_tree/x_tree_split_impl.hpp | 1 +
src/mlpack/tests/rectangle_tree_test.cpp | 27 +++++++++++
7 files changed, 80 insertions(+), 11 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
index 1fc6b28..7742e9d 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
@@ -231,10 +231,12 @@ RedistributeNodesEvenly(const TreeType *parent,
// Since we redistribute children of a sibling we should recalculate the
// bound.
parent->Child(i).Bound().Clear();
+ parent->Child(i).numDescendants = 0;
for (size_t j = 0; j < numChildrenPerNode; j++)
{
parent->Child(i).Bound() |= children[iChild]->Bound();
+ parent->Child(i).numDescendants += children[iChild]->numDescendants;
parent->Child(i).children[j] = children[iChild];
children[iChild]->Parent() = parent->children[i];
iChild++;
@@ -242,6 +244,7 @@ RedistributeNodesEvenly(const TreeType *parent,
if (numRestChildren > 0)
{
parent->Child(i).Bound() |= children[iChild]->Bound();
+ parent->Child(i).numDescendants += children[iChild]->numDescendants;
parent->Child(i).children[numChildrenPerNode] = children[iChild];
children[iChild]->Parent() = parent->children[i];
parent->Child(i).NumChildren() = numChildrenPerNode + 1;
@@ -313,6 +316,8 @@ RedistributePointsEvenly(TreeType* parent,
{
parent->Child(i).Count() = numPointsPerNode;
}
+ parent->Child(i).numDescendants = parent->Child(i).Count();
+
assert(parent->Child(i).NumPoints() <=
parent->Child(i).MaxLeafSize());
}
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index 12a7b4a..9347b87 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -675,6 +675,7 @@ template<typename TreeType>
void RStarTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
+ destTree->numDescendants += srcNode->numDescendants;
destTree->children[destTree->NumChildren()++] = srcNode;
}
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 98701ff..faf75f3 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -521,6 +521,7 @@ template<typename TreeType>
void RTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
+ destTree->numDescendants += srcNode->numDescendants;
destTree->children[destTree->NumChildren()++] = srcNode;
}
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 46bccfd..90ac546 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -76,6 +76,8 @@ class RectangleTree
//! The number of points in the dataset contained in this node (and its
//! children).
size_t count;
+ //! The number of descendants of this node.
+ size_t numDescendants;
//! The max leaf size.
size_t maxLeafSize;
//! The minimum leaf size.
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 0da0fd7..6c4b42f 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -37,6 +37,7 @@ RectangleTree(const MatType& data,
parent(NULL),
begin(0),
count(0),
+ numDescendants(0),
maxLeafSize(maxLeafSize),
minLeafSize(minLeafSize),
bound(data.n_rows),
@@ -76,6 +77,7 @@ RectangleTree(MatType&& data,
parent(NULL),
begin(0),
count(0),
+ numDescendants(0),
maxLeafSize(maxLeafSize),
minLeafSize(minLeafSize),
bound(data.n_rows),
@@ -114,6 +116,7 @@ RectangleTree(
parent(parentNode),
begin(0),
count(0),
+ numDescendants(0),
maxLeafSize(parentNode->MaxLeafSize()),
minLeafSize(parentNode->MinLeafSize()),
bound(parentNode->Bound().Dim()),
@@ -148,6 +151,7 @@ RectangleTree(
parent(other.Parent()),
begin(other.Begin()),
count(other.Count()),
+ numDescendants(other.numDescendants),
maxLeafSize(other.MaxLeafSize()),
minLeafSize(other.MinLeafSize()),
bound(other.bound),
@@ -269,6 +273,8 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
// Expand the bound regardless of whether it is a leaf node.
bound |= dataset->col(point);
+ numDescendants++;
+
std::vector<bool> lvls(TreeDepth());
for (size_t i = 0; i < lvls.size(); i++)
lvls[i] = true;
@@ -306,6 +312,8 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
// Expand the bound regardless of whether it is a leaf node.
bound |= dataset->col(point);
+ numDescendants++;
+
// If this is a leaf node, we stop here and add the point.
if (numChildren == 0)
{
@@ -345,6 +353,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
{
// Expand the bound regardless of the level.
bound |= node->Bound();
+ numDescendants += node->numDescendants;
if (level == TreeDepth())
{
if (!auxiliaryInfo.HandleNodeInsertion(this, node, true))
@@ -395,6 +404,12 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (!auxiliaryInfo.HandlePointDeletion(this, i))
points[i] = points[--count];
+ RectangleTree* tree = this;
+ while (tree != NULL)
+ {
+ tree->numDescendants--;
+ tree = tree->Parent();
+ }
// This function will ensure that minFill is satisfied.
CondenseTree(dataset->col(point), lvls, true);
return true;
@@ -433,6 +448,12 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (!auxiliaryInfo.HandlePointDeletion(this, i))
points[i] = points[--count];
+ RectangleTree* tree = this;
+ while (tree != NULL)
+ {
+ tree->numDescendants--;
+ tree = tree->Parent();
+ }
// This function will ensure that minFill is satisfied.
CondenseTree(dataset->col(point), relevels, true);
return true;
@@ -471,6 +492,12 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
{
children[i] = children[--numChildren]; // Decrement numChildren.
}
+ RectangleTree* tree = this;
+ while (tree != NULL)
+ {
+ tree->numDescendants -= node->numDescendants;
+ tree = tree->Parent();
+ }
CondenseTree(arma::vec(), relevels, false);
return true;
}
@@ -613,17 +640,7 @@ template<typename MetricType,
inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType, AuxiliaryInformationType>::NumDescendants() const
{
- if (numChildren == 0)
- {
- return count;
- }
- else
- {
- size_t n = 0;
- for (size_t i = 0; i < numChildren; i++)
- n += children[i]->NumDescendants();
- return n;
- }
+ return numDescendants;
}
/**
@@ -763,6 +780,13 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (stillShrinking)
stillShrinking = root->ShrinkBoundForBound(bound);
+ root = parent;
+ while (root != NULL)
+ {
+ root->numDescendants -= numDescendants;
+ root = root->Parent();
+ }
+
stillShrinking = true;
root = parent;
while (root->Parent() != NULL)
@@ -817,6 +841,13 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (stillShrinking)
stillShrinking = root->ShrinkBoundForBound(bound);
+ root = parent;
+ while (root != NULL)
+ {
+ root->numDescendants -= numDescendants;
+ root = root->Parent();
+ }
+
stillShrinking = true;
root = parent;
while (root->Parent() != NULL)
@@ -1068,6 +1099,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
ar & CreateNVP(begin, "begin");
ar & CreateNVP(count, "count");
+ ar & CreateNVP(numDescendants, "numDescendants");
ar & CreateNVP(maxLeafSize, "maxLeafSize");
ar & CreateNVP(minLeafSize, "minLeafSize");
ar & CreateNVP(bound, "bound");
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 8f8d144..8d8fc6d 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -840,6 +840,7 @@ template<typename TreeType>
void XTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
+ destTree->numDescendants += srcNode->numDescendants;
destTree->children[destTree->NumChildren()] = srcNode;
destTree->NumChildren()++;
}
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 5522a03..e6aedd9 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -316,6 +316,28 @@ int GetMinLevel(const TreeType& tree)
return min;
}
+/**
+ * A function to check that numDescendants values are set correctly.
+ */
+template<typename TreeType>
+size_t CheckNumDescendants(const TreeType& tree)
+{
+ if (tree.IsLeaf())
+ {
+ BOOST_REQUIRE_EQUAL(tree.NumDescendants(), tree.Count());
+ return tree.Count();
+ }
+
+ size_t numDescendants = 0;
+
+ for (size_t i = 0; i < tree.NumChildren(); i++)
+ numDescendants += CheckNumDescendants(tree.Child(i));
+
+ BOOST_REQUIRE_EQUAL(tree.NumDescendants(), numDescendants);
+
+ return numDescendants;
+}
+
// A test to ensure that all leaf nodes are stored on the same level of the
// tree.
BOOST_AUTO_TEST_CASE(TreeBalance)
@@ -378,6 +400,7 @@ BOOST_AUTO_TEST_CASE(PointDeletion)
CheckContainment(tree);
CheckExactContainment(tree);
+ CheckNumDescendants(tree);
// Single-tree search.
NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, arma::mat,
@@ -460,6 +483,7 @@ BOOST_AUTO_TEST_CASE(PointDynamicAdd)
BOOST_REQUIRE_EQUAL(tree.NumDescendants(), 1000 + numIter);
CheckContainment(tree);
CheckExactContainment(tree);
+ CheckNumDescendants(tree);
// Now we will compare the output of the R Tree vs the output of a naive
// search.
@@ -510,6 +534,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeTraverserTest)
CheckContainment(rTree);
CheckExactContainment(rTree);
CheckHierarchy(rTree);
+ CheckNumDescendants(rTree);
knn1.Search(5, neighbors1, distances1);
@@ -552,6 +577,7 @@ BOOST_AUTO_TEST_CASE(XTreeTraverserTest)
CheckContainment(xTree);
CheckExactContainment(xTree);
CheckHierarchy(xTree);
+ CheckNumDescendants(xTree);
knn1.Search(5, neighbors1, distances1);
@@ -592,6 +618,7 @@ BOOST_AUTO_TEST_CASE(HilbertRTreeTraverserTest)
CheckContainment(hilbertRTree);
CheckExactContainment(hilbertRTree);
CheckHierarchy(hilbertRTree);
+ CheckNumDescendants(hilbertRTree);
knn1.Search(5, neighbors1, distances1);
More information about the mlpack-git mailing list