[mlpack-git] master: Round four. Start over. This time, I proved that the algorithm is right before implementing it. This should help make debugging a lot easier. (540480e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:05:12 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 540480ea66afba266f8791082c33ed0edc3a3993
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Feb 14 14:05:03 2015 -0500
Round four. Start over. This time, I proved that the algorithm is right before implementing it. This should help make debugging a lot easier.
>---------------------------------------------------------------
540480ea66afba266f8791082c33ed0edc3a3993
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 22 ++--
src/mlpack/methods/kmeans/dtnn_rules.hpp | 24 +++-
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 157 ++++++++++++------------
src/mlpack/methods/kmeans/dtnn_statistic.hpp | 45 +++----
src/mlpack/tests/kmeans_test.cpp | 160 +++++++++++++++++++++++++
5 files changed, 283 insertions(+), 125 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index b451647..3fa25ab 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -148,24 +148,14 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
// We won't use the AllkNN class here because we have our own set of rules.
// This is a lot of overhead. We don't need the distances.
- Timer::Start("knn");
- typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
- RuleType rules(centroids, dataset, assignments, distances, metric,
- prunedPoints, oldFromNewCentroids, visited);
-
- // Now construct the traverser ourselves.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- traverser.Traverse(*tree, *centroidTree);
- Timer::Stop("knn");
Timer::Start("tree_mod");
DecoalesceTree(*tree);
Timer::Stop("tree_mod");
- Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
- rules.Scores() << " scores.\n";
- distanceCalculations += rules.BaseCases() + rules.Scores();
+// Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
+// rules.Scores() << " scores.\n";
+// distanceCalculations += rules.BaseCases() + rules.Scores();
// From the assignments, calculate the new centroids and counts.
for (size_t i = 0; i < dataset.n_cols; ++i)
@@ -225,6 +215,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
const arma::mat& interclusterDistances,
const std::vector<size_t>& newFromOldCentroids)
{
+/*
// Update iteration.
// node.Stat().Iteration() = iteration;
@@ -510,6 +501,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().SecondBound() += clusterDistances[centroids.n_cols];
if (node.Stat().Bound() != DBL_MAX)
node.Stat().Bound() += clusterDistances[centroids.n_cols];
+*/
}
template<typename MetricType, typename MatType, typename TreeType>
@@ -517,6 +509,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
TreeType& node,
const size_t child /* Which child are we? */)
{
+/*
// If one of the two children is pruned, we hide this node.
// This assumes the BinarySpaceTree. (bad Ryan! bad!)
if (node.NumChildren() == 0)
@@ -554,11 +547,13 @@ void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
CoalesceTree(node.Child(0), 0);
CoalesceTree(node.Child(1), 1);
}
+*/
}
template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
{
+/*
node.Parent() = (TreeType*) node.Stat().TrueParent();
node.ChildPtr(0) = (TreeType*) node.Stat().TrueLeft();
node.ChildPtr(1) = (TreeType*) node.Stat().TrueRight();
@@ -568,6 +563,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
DecoalesceTree(node.Child(0));
DecoalesceTree(node.Child(1));
}
+*/
}
template<typename MetricType, typename MatType, typename TreeType>
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 0183e9e..1c3cda4 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -15,14 +15,14 @@ namespace mlpack {
namespace kmeans {
template<typename MetricType, typename TreeType>
-class DTNNKMeansRules : public neighbor::NeighborSearchRules<
- neighbor::NearestNeighborSort, MetricType, TreeType>
+class DTNNKMeansRules
{
public:
DTNNKMeansRules(const arma::mat& centroids,
const arma::mat& dataset,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
+ arma::Col<size_t>& assignments,
+ arma::vec& upperBounds,
+ arma::vec& lowerBounds,
MetricType& metric,
const std::vector<bool>& prunedPoints,
const std::vector<size_t>& oldFromNewCentroids,
@@ -39,13 +39,29 @@ class DTNNKMeansRules : public neighbor::NeighborSearchRules<
TreeType& referenceNode,
const double oldScore);
+ typedef int TraversalInfoType;
+
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+
private:
+ const arma::mat& centroids;
+ const arma::mat& dataset;
+ arma::Col<size_t>& assignments;
+ arma::vec& upperBounds;
+ arma::vec& lowerBounds;
+ MetricType& metric;
const std::vector<bool>& prunedPoints;
const std::vector<size_t>& oldFromNewCentroids;
std::vector<bool>& visited;
+
+ size_t baseCases;
+ size_t scores;
+
+ int traversalInfo;
};
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index a526c94..7e1904f 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -16,17 +16,24 @@ template<typename MetricType, typename TreeType>
DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
const arma::mat& centroids,
const arma::mat& dataset,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
+ arma::Col<size_t>& assignments,
+ arma::vec& upperBounds,
+ arma::vec& lowerBounds,
MetricType& metric,
const std::vector<bool>& prunedPoints,
const std::vector<size_t>& oldFromNewCentroids,
std::vector<bool>& visited) :
- neighbor::NeighborSearchRules<neighbor::NearestNeighborSort, MetricType,
- TreeType>(centroids, dataset, neighbors, distances, metric),
+ centroids(centroids),
+ dataset(dataset),
+ assignments(assignments),
+ upperBounds(upperBounds),
+ lowerBounds(lowerBounds),
+ metric(metric),
prunedPoints(prunedPoints),
oldFromNewCentroids(oldFromNewCentroids),
- visited(visited)
+ visited(visited),
+ baseCases(0),
+ scores(0)
{
// Nothing to do.
}
@@ -36,88 +43,42 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
const size_t queryIndex,
const size_t referenceIndex)
{
- // We'll check if the query point has been pruned. If so, don't continue.
-// Log::Debug << "Base case " << queryIndex << ", " << referenceIndex <<
-//".\n";
if (prunedPoints[queryIndex])
return 0.0; // Returning 0 shouldn't be a problem.
-// Log::Debug << "(not pruned.)\n";
// Any base cases imply that we will get a result.
visited[queryIndex] = true;
- // This is basically an inlined NeighborSearchRules::BaseCase(), but it
- // differs in that it applies the mappings to the results automatically.
- // We can also skip a check or two.
+ // Calculate the distance.
+ ++baseCases;
+ const double distance = metric.Evaluate(dataset.col(queryIndex),
+ centroids.col(referenceIndex));
- // By the way, all of the this-> is necessary because the parent class is a
- // dependent name, so all of the members of that parent aren't resolvable
- // before type substitution. The 'this->' turns that member into a dependent
- // name too (since the type of 'this' is dependent), and thus the compiler
- // resolves the name later and we get no error. Hooray C++!
- //
- // See also:
- // http://stackoverflow.com/questions/10639053/name-lookups-in-c-templates
-
- // If we have already performed this base case, do not perform it again.
- if ((this->lastQueryIndex == queryIndex) &&
- (this->lastReferenceIndex == referenceIndex))
- return this->lastBaseCase;
-
- double distance = this->metric.Evaluate(this->querySet.col(queryIndex),
- this->referenceSet.col(referenceIndex));
- ++this->baseCases;
-
- const size_t cluster = oldFromNewCentroids[referenceIndex];
-
- // Is this better than either existing candidate?
- if (distance < this->distances(0, queryIndex))
+ if (distance < upperBounds[queryIndex])
{
- // Do we need to replace the assignment, or is it an old assignment from a
- // previous iteration?
- if (this->neighbors(0, queryIndex) != cluster &&
- this->neighbors(0, queryIndex) < this->referenceSet.n_cols)
- {
- // We must push the old closest assignment down the stack.
- this->neighbors(1, queryIndex) = this->neighbors(0, queryIndex);
- this->distances(1, queryIndex) = this->distances(0, queryIndex);
- this->neighbors(0, queryIndex) = cluster;
- }
- else if (this->neighbors(0, queryIndex) >= this->referenceSet.n_cols)
- {
- this->neighbors(0, queryIndex) = cluster;
- }
-
- this->distances(0, queryIndex) = distance;
+ lowerBounds[queryIndex] = upperBounds[queryIndex];
+ upperBounds[queryIndex] = distance;
+ assignments[queryIndex] = referenceIndex;
}
- else if (distance < this->distances(1, queryIndex))
+ else if (distance < lowerBounds[queryIndex])
{
- // Here it doesn't actually matter if the assignment is the same.
- this->neighbors(1, queryIndex) = cluster;
- this->distances(1, queryIndex) = distance;
+ lowerBounds[queryIndex] = distance;
}
- this->lastQueryIndex = queryIndex;
- this->lastReferenceIndex = referenceIndex;
- this->lastBaseCase = distance;
-
return distance;
}
template<typename MetricType, typename TreeType>
inline double DTNNKMeansRules<MetricType, TreeType>::Score(
const size_t queryIndex,
- TreeType& referenceNode)
+ TreeType& /* referenceNode */)
{
// If the query point has already been pruned, then don't recurse further.
-// Log::Debug << "Score " << queryIndex << ", r" << referenceNode.Point(0) << "c"
-// << referenceNode.NumDescendants() << ".\n";
if (prunedPoints[queryIndex])
return DBL_MAX;
-// Log::Debug << "(not pruned)\n";
- return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
- MetricType, TreeType>::Score(queryIndex, referenceNode);
+ // No pruning at this level (for now).
+ return 0;
}
template<typename MetricType, typename TreeType>
@@ -125,38 +86,70 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
-// Log::Debug << "Score q" << queryNode.Point(0) << "c" <<
-//queryNode.NumDescendants() << ", r" << referenceNode.Point(0) << "c" <<
-//referenceNode.NumDescendants() << ".\n";
- if (queryNode.Stat().Pruned())
+ // Pruned() for the root node must never be set to size_t(-1).
+ if (queryNode.Stat().Pruned() == size_t(-1))
+ {
+ queryNode.Stat().Pruned() = queryNode.Parent()->Stat().Pruned();
+ queryNode.Stat().LowerBound() = queryNode.Parent()->Stat().LowerBound();
+ }
+
+ if (queryNode.Stat().Pruned() == centroids.n_cols)
+ {
+ return DBL_MAX;
+ }
+
+ // Get minimum and maximum distances.
+ math::Range distances = queryNode.RangeDistance(&referenceNode);
+ double score = distances.Lo();
+ ++scores;
+ if (distances.Lo() > queryNode.Stat().UpperBound())
+ {
+ // The reference node can own no points in this query node. We may improve
+ // the lower bound on pruned nodes, though.
+ if (distances.Lo() < queryNode.Stat().LowerBound())
+ queryNode.Stat().LowerBound() = distances.Lo();
+
+ // This assumes that reference clusters don't appear elsewhere in the tree.
+ queryNode.Stat().Pruned() += referenceNode.NumDescendants();
+ score = DBL_MAX;
+ }
+ else if (distances.Hi() < queryNode.Stat().UpperBound())
+ {
+ // We can improve the best estimate.
+ queryNode.Stat().UpperBound() = distances.Hi();
+ // If this node has only one descendant, then it may be the owner.
+ if (referenceNode.NumDescendants() == 1)
+ queryNode.Stat().Owner() = referenceNode.Descendant(0);
+ }
+
+ // Is everything pruned?
+ if (queryNode.Stat().Pruned() == centroids.n_cols - 1)
+ {
+ queryNode.Stat().Pruned() = centroids.n_cols; // Owner() is already set.
return DBL_MAX;
-// Log::Debug << "(not pruned.)\n";
+ }
- // Check if the query node is Hamerly pruned, and if not, then don't continue.
- return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
- MetricType, TreeType>::Score(queryNode, referenceNode);
+ return score;
}
template<typename MetricType, typename TreeType>
inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
- const size_t queryIndex,
- TreeType& referenceNode,
+ const size_t /* queryIndex */,
+ TreeType& /* referenceNode */,
const double oldScore)
{
- return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
- MetricType, TreeType>::Rescore(queryIndex, referenceNode, oldScore);
+ // No rescoring (for now).
+ return oldScore;
}
template<typename MetricType, typename TreeType>
inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
- TreeType& queryNode,
- TreeType& referenceNode,
+ TreeType& /* queryNode */,
+ TreeType& /* referenceNode */,
const double oldScore)
{
- // No need to check for a Hamerly prune. Because we've already done that in
- // Score().
- return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
- MetricType, TreeType>::Rescore(queryNode, referenceNode, oldScore);
+ // No rescoring (for now).
+ return oldScore;
}
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index 691a643..b8c04e6 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -18,11 +18,10 @@ class DTNNStatistic : public
public:
DTNNStatistic() :
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
- pruned(false),
- iteration(0),
- maxClusterDistance(DBL_MAX),
- secondClusterBound(0.0),
+ upperBound(DBL_MAX),
+ lowerBound(DBL_MAX),
owner(size_t(-1)),
+ pruned(size_t(-1)),
centroid(),
trueLeft(NULL),
trueRight(NULL),
@@ -34,11 +33,10 @@ class DTNNStatistic : public
template<typename TreeType>
DTNNStatistic(TreeType& node) :
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
- pruned(false),
- iteration(0),
- maxClusterDistance(DBL_MAX),
- secondClusterBound(0.0),
+ upperBound(DBL_MAX),
+ lowerBound(DBL_MAX),
owner(size_t(-1)),
+ pruned(size_t(-1)),
trueLeft((void*) &node.Child(0)),
trueRight((void*) &node.Child(1)),
trueParent((void*) node.Parent())
@@ -55,24 +53,21 @@ class DTNNStatistic : public
centroid /= node.NumDescendants();
}
- bool Pruned() const { return pruned; }
- bool& Pruned() { return pruned; }
+ double UpperBound() const { return upperBound; }
+ double& UpperBound() { return upperBound; }
- size_t Iteration() const { return iteration; }
- size_t& Iteration() { return iteration; }
+ double LowerBound() const { return lowerBound; }
+ double& LowerBound() { return lowerBound; }
- double MaxClusterDistance() const { return maxClusterDistance; }
- double& MaxClusterDistance() { return maxClusterDistance; }
+ const arma::vec& Centroid() const { return centroid; }
+ arma::vec& Centroid() { return centroid; }
- double SecondClusterBound() const { return secondClusterBound; }
- double& SecondClusterBound() { return secondClusterBound; }
+ size_t Pruned() const { return pruned; }
+ size_t& Pruned() { return pruned; }
size_t Owner() const { return owner; }
size_t& Owner() { return owner; }
- const arma::vec& Centroid() const { return centroid; }
- arma::vec& Centroid() { return centroid; }
-
const void* TrueLeft() const { return trueLeft; }
void*& TrueLeft() { return trueLeft; }
@@ -86,20 +81,18 @@ class DTNNStatistic : public
{
std::ostringstream o;
o << "DTNNStatistic [" << this << "]:\n";
+ o << " Upper bound: " << upperBound << ".\n";
+ o << " Lower bound: " << lowerBound << ".\n";
o << " Pruned: " << pruned << ".\n";
- o << " Iteration: " << iteration << ".\n";
- o << " MaxClusterDistance: " << maxClusterDistance << ".\n";
- o << " SecondClusterBound: " << secondClusterBound << ".\n";
o << " Owner: " << owner << ".\n";
return o.str();
}
private:
- bool pruned;
- size_t iteration;
- double maxClusterDistance;
- double secondClusterBound;
+ double upperBound;
+ double lowerBound;
size_t owner;
+ size_t pruned;
arma::vec centroid;
void* trueLeft;
void* trueRight;
diff --git a/src/mlpack/tests/kmeans_test.cpp b/src/mlpack/tests/kmeans_test.cpp
index a296339..593ea6e 100644
--- a/src/mlpack/tests/kmeans_test.cpp
+++ b/src/mlpack/tests/kmeans_test.cpp
@@ -14,12 +14,16 @@
#include <mlpack/methods/kmeans/dual_tree_kmeans.hpp>
#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
using namespace mlpack;
using namespace mlpack::kmeans;
+using namespace mlpack::metric;
+using namespace mlpack::tree;
+using namespace mlpack::neighbor;
BOOST_AUTO_TEST_SUITE(KMeansTest);
@@ -690,4 +694,160 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansTest)
}
}
+BOOST_AUTO_TEST_CASE(DualTreeKMeansBaseCaseTest)
+{
+ // If we run BaseCase() on all the points, do we get valid results?
+ const size_t points = 1000;
+ const size_t clusters = 5;
+ arma::mat dataset(5, points);
+ dataset.randu();
+ arma::mat centroids(5, clusters);
+ centroids.randu();
+
+ // Create the Rules object.
+ arma::Col<size_t> assignments(points);
+ arma::vec upperBounds(points);
+ arma::vec lowerBounds(points);
+ upperBounds.fill(DBL_MAX);
+ lowerBounds.fill(DBL_MAX);
+ std::vector<bool> visited(points, false); // Fill with false.
+ std::vector<size_t> oldFromNewCentroids; // Not used.
+ std::vector<bool> prunedPoints(points, false); // Fill with false.
+
+ EuclideanDistance e;
+ DTNNKMeansRules<EuclideanDistance, BinarySpaceTree<HRectBound<2>,
+ EuclideanDistance, DTNNStatistic> > rules(centroids, dataset,
+ assignments, upperBounds, lowerBounds, e, prunedPoints,
+ oldFromNewCentroids, visited);
+
+ for (size_t i = 0; i < points; ++i)
+ {
+ for (size_t j = 0; j < clusters; ++j)
+ {
+ rules.BaseCase(i, j);
+ }
+ }
+
+ // Now, run nearest neighbors to establish true bounds.
+ neighbor::AllkNN allknn(centroids, dataset);
+ arma::Mat<size_t> trueAssignments;
+ arma::mat trueDistances;
+ allknn.Search(2, trueAssignments, trueDistances);
+
+ for (size_t i = 0; i < points; ++i)
+ {
+ BOOST_REQUIRE_GE(upperBounds[i], trueDistances(0, i));
+ BOOST_REQUIRE_LE(lowerBounds[i], trueDistances(1, i));
+ BOOST_REQUIRE_EQUAL(assignments[i], trueAssignments(0, i));
+ }
+}
+
+BOOST_AUTO_TEST_CASE(DualTreeKMeansScoreKDTreeOneLeafTest)
+{
+ // If we run a dual-tree algorithm, do we get valid results for each point
+ // and/or node when we use the kd-tree with a leaf size of one?
+ const size_t points = 5000;
+ const size_t clusters = 100;
+ arma::mat dataset(5, points);
+ dataset.randu();
+ arma::mat centroids(5, clusters);
+ centroids.randu();
+
+ arma::mat datasetCopy(dataset);
+ arma::mat centroidsCopy(centroids);
+
+ // Create the trees.
+ typedef BinarySpaceTree<HRectBound<2>, DTNNStatistic> TreeType;
+ TreeType pointTree(dataset, 1);
+ TreeType centroidTree(centroids, 1);
+
+ // Create the Rules object.
+ arma::Col<size_t> assignments(points);
+ arma::vec upperBounds(points);
+ arma::vec lowerBounds(points);
+ upperBounds.fill(DBL_MAX);
+ lowerBounds.fill(DBL_MAX);
+ std::vector<bool> visited(points, false); // Fill with false.
+ std::vector<size_t> oldFromNewCentroids; // Not used.
+ std::vector<bool> prunedPoints(points, false); // Fill with false.
+
+ EuclideanDistance e;
+ typedef DTNNKMeansRules<EuclideanDistance, TreeType> RuleType;
+ RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds, e,
+ prunedPoints, oldFromNewCentroids, visited);
+
+ // Now create the traverser.
+ typename TreeType::template BreadthFirstDualTreeTraverser<RuleType>
+ traverser(rules);
+
+ pointTree.Stat().Pruned() = 0;
+ traverser.Traverse(pointTree, centroidTree);
+
+ // Get true bounds.
+ AllkNN allknn(centroids, dataset);
+
+ arma::Mat<size_t> trueAssignments;
+ arma::mat trueDistances;
+ allknn.Search(2, trueAssignments, trueDistances);
+
+ // Check the points first. Lots of weird mappings have to go on in this stage
+ // because the tree building procedure changed all the points around.
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ {
+ if (visited[i])
+ {
+ BOOST_REQUIRE_GE(upperBounds[i], trueDistances(0, i));
+ BOOST_REQUIRE_EQUAL(assignments[i], trueAssignments(0, i));
+ }
+ }
+
+ // Now traverse the tree to see if it is correct.
+ std::queue<TreeType*> nodeQueue;
+ nodeQueue.push(&pointTree);
+ while (!nodeQueue.empty())
+ {
+ // This is an expensive operation. We must ensure that the upper bound is
+ // valid and that the lower bound is valid. Both can be needlessly loose,
+ // but that will simply affect the pruning of the method, not the
+ // correctness. Here we care about correctness.
+ TreeType* node = nodeQueue.front();
+ nodeQueue.pop();
+
+ // We must make sure the upper bound and lower bound are both valid for all
+ // descendant points. The lower bound only matters if the node was pruned.
+ // So, we must calculate the upper and lower bounds manually for the
+ // descendants.
+ double exactUpperBound = 0.0;
+ double exactLowerBound = DBL_MAX;
+ for (size_t i = 0; i < node->NumDescendants(); ++i)
+ {
+ if (trueDistances(0, node->Descendant(i)) > exactUpperBound)
+ exactUpperBound = trueDistances(0, node->Descendant(i));
+ if (trueDistances(1, node->Descendant(i)) < exactLowerBound)
+ exactLowerBound = trueDistances(1, node->Descendant(i));
+ }
+
+ // Multiplication is to add some tolerance for floating point discrepancies.
+ BOOST_REQUIRE_GE(node->Stat().UpperBound() * 1.000001, exactUpperBound);
+
+ if (node->Stat().Pruned() == centroids.n_cols)
+ {
+ BOOST_REQUIRE_LE(node->Stat().LowerBound() * 0.99999, exactLowerBound);
+ }
+ else
+ {
+ for (size_t i = 0; i < node->NumPoints(); ++i)
+ {
+ const double bestLower = std::min(node->Stat().LowerBound(),
+ lowerBounds[node->Point(i)]);
+ BOOST_REQUIRE_LE(bestLower * 0.99999, trueDistances(1, node->Point(i)));
+ }
+
+ // Recurse.
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ nodeQueue.push(&node->Child(i));
+ }
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list