[mlpack-git] master: Round four. Start over. This time, I proved that the algorithm is right before implementing it. This should help make debugging a lot easier. (540480e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:05:12 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 540480ea66afba266f8791082c33ed0edc3a3993
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Feb 14 14:05:03 2015 -0500

    Round four. Start over. This time, I proved that the algorithm is right before implementing it. This should help make debugging a lot easier.


>---------------------------------------------------------------

540480ea66afba266f8791082c33ed0edc3a3993
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp |  22 ++--
 src/mlpack/methods/kmeans/dtnn_rules.hpp       |  24 +++-
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  | 157 ++++++++++++------------
 src/mlpack/methods/kmeans/dtnn_statistic.hpp   |  45 +++----
 src/mlpack/tests/kmeans_test.cpp               | 160 +++++++++++++++++++++++++
 5 files changed, 283 insertions(+), 125 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index b451647..3fa25ab 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -148,24 +148,14 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 
   // We won't use the AllkNN class here because we have our own set of rules.
   // This is a lot of overhead.  We don't need the distances.
-  Timer::Start("knn");
-  typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
-  RuleType rules(centroids, dataset, assignments, distances, metric,
-      prunedPoints, oldFromNewCentroids, visited);
-
-  // Now construct the traverser ourselves.
-  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
-  traverser.Traverse(*tree, *centroidTree);
-  Timer::Stop("knn");
 
   Timer::Start("tree_mod");
   DecoalesceTree(*tree);
   Timer::Stop("tree_mod");
 
-  Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
-      rules.Scores() << " scores.\n";
-  distanceCalculations += rules.BaseCases() + rules.Scores();
+//  Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
+//      rules.Scores() << " scores.\n";
+//  distanceCalculations += rules.BaseCases() + rules.Scores();
 
   // From the assignments, calculate the new centroids and counts.
   for (size_t i = 0; i < dataset.n_cols; ++i)
@@ -225,6 +215,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     const arma::mat& interclusterDistances,
     const std::vector<size_t>& newFromOldCentroids)
 {
+/*
   // Update iteration.
 //  node.Stat().Iteration() = iteration;
 
@@ -510,6 +501,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     node.Stat().SecondBound() += clusterDistances[centroids.n_cols];
   if (node.Stat().Bound() != DBL_MAX)
     node.Stat().Bound() += clusterDistances[centroids.n_cols];
+*/
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
@@ -517,6 +509,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
     TreeType& node,
     const size_t child /* Which child are we? */)
 {
+/*
   // If one of the two children is pruned, we hide this node.
   // This assumes the BinarySpaceTree.  (bad Ryan! bad!)
   if (node.NumChildren() == 0)
@@ -554,11 +547,13 @@ void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
     CoalesceTree(node.Child(0), 0);
     CoalesceTree(node.Child(1), 1);
   }
+*/
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
 {
+/*
   node.Parent() = (TreeType*) node.Stat().TrueParent();
   node.ChildPtr(0) = (TreeType*) node.Stat().TrueLeft();
   node.ChildPtr(1) = (TreeType*) node.Stat().TrueRight();
@@ -568,6 +563,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
     DecoalesceTree(node.Child(0));
     DecoalesceTree(node.Child(1));
   }
+*/
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 0183e9e..1c3cda4 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -15,14 +15,14 @@ namespace mlpack {
 namespace kmeans {
 
 template<typename MetricType, typename TreeType>
-class DTNNKMeansRules : public neighbor::NeighborSearchRules<
-    neighbor::NearestNeighborSort, MetricType, TreeType>
+class DTNNKMeansRules
 {
  public:
   DTNNKMeansRules(const arma::mat& centroids,
                   const arma::mat& dataset,
-                  arma::Mat<size_t>& neighbors,
-                  arma::mat& distances,
+                  arma::Col<size_t>& assignments,
+                  arma::vec& upperBounds,
+                  arma::vec& lowerBounds,
                   MetricType& metric,
                   const std::vector<bool>& prunedPoints,
                   const std::vector<size_t>& oldFromNewCentroids,
@@ -39,13 +39,29 @@ class DTNNKMeansRules : public neighbor::NeighborSearchRules<
                  TreeType& referenceNode,
                  const double oldScore);
 
+  typedef int TraversalInfoType;
+
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+
  private:
+  const arma::mat& centroids;
+  const arma::mat& dataset;
+  arma::Col<size_t>& assignments;
+  arma::vec& upperBounds;
+  arma::vec& lowerBounds;
+  MetricType& metric;
 
   const std::vector<bool>& prunedPoints;
 
   const std::vector<size_t>& oldFromNewCentroids;
 
   std::vector<bool>& visited;
+
+  size_t baseCases;
+  size_t scores;
+
+  int traversalInfo;
 };
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index a526c94..7e1904f 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -16,17 +16,24 @@ template<typename MetricType, typename TreeType>
 DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
     const arma::mat& centroids,
     const arma::mat& dataset,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances,
+    arma::Col<size_t>& assignments,
+    arma::vec& upperBounds,
+    arma::vec& lowerBounds,
     MetricType& metric,
     const std::vector<bool>& prunedPoints,
     const std::vector<size_t>& oldFromNewCentroids,
     std::vector<bool>& visited) :
-    neighbor::NeighborSearchRules<neighbor::NearestNeighborSort, MetricType,
-        TreeType>(centroids, dataset, neighbors, distances, metric),
+    centroids(centroids),
+    dataset(dataset),
+    assignments(assignments),
+    upperBounds(upperBounds),
+    lowerBounds(lowerBounds),
+    metric(metric),
     prunedPoints(prunedPoints),
     oldFromNewCentroids(oldFromNewCentroids),
-    visited(visited)
+    visited(visited),
+    baseCases(0),
+    scores(0)
 {
   // Nothing to do.
 }
@@ -36,88 +43,42 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
     const size_t queryIndex,
     const size_t referenceIndex)
 {
-  // We'll check if the query point has been pruned.  If so, don't continue.
-//  Log::Debug << "Base case " << queryIndex << ", " << referenceIndex <<
-//".\n";
   if (prunedPoints[queryIndex])
     return 0.0; // Returning 0 shouldn't be a problem.
-//  Log::Debug << "(not pruned.)\n";
 
   // Any base cases imply that we will get a result.
   visited[queryIndex] = true;
 
-  // This is basically an inlined NeighborSearchRules::BaseCase(), but it
-  // differs in that it applies the mappings to the results automatically.
-  // We can also skip a check or two.
+  // Calculate the distance.
+  ++baseCases;
+  const double distance = metric.Evaluate(dataset.col(queryIndex),
+                                          centroids.col(referenceIndex));
 
-  // By the way, all of the this-> is necessary because the parent class is a
-  // dependent name, so all of the members of that parent aren't resolvable
-  // before type substitution.  The 'this->' turns that member into a dependent
-  // name too (since the type of 'this' is dependent), and thus the compiler
-  // resolves the name later and we get no error.  Hooray C++!
-  //
-  // See also:
-  // http://stackoverflow.com/questions/10639053/name-lookups-in-c-templates
-
-  // If we have already performed this base case, do not perform it again.
-  if ((this->lastQueryIndex == queryIndex) &&
-      (this->lastReferenceIndex == referenceIndex))
-    return this->lastBaseCase;
-
-  double distance = this->metric.Evaluate(this->querySet.col(queryIndex),
-      this->referenceSet.col(referenceIndex));
-  ++this->baseCases;
-
-  const size_t cluster = oldFromNewCentroids[referenceIndex];
-
-  // Is this better than either existing candidate?
-  if (distance < this->distances(0, queryIndex))
+  if (distance < upperBounds[queryIndex])
   {
-    // Do we need to replace the assignment, or is it an old assignment from a
-    // previous iteration?
-    if (this->neighbors(0, queryIndex) != cluster &&
-        this->neighbors(0, queryIndex) < this->referenceSet.n_cols)
-    {
-      // We must push the old closest assignment down the stack.
-      this->neighbors(1, queryIndex) = this->neighbors(0, queryIndex);
-      this->distances(1, queryIndex) = this->distances(0, queryIndex);
-      this->neighbors(0, queryIndex) = cluster;
-    }
-    else if (this->neighbors(0, queryIndex) >= this->referenceSet.n_cols)
-    {
-      this->neighbors(0, queryIndex) = cluster;
-    }
-
-    this->distances(0, queryIndex) = distance;
+    lowerBounds[queryIndex] = upperBounds[queryIndex];
+    upperBounds[queryIndex] = distance;
+    assignments[queryIndex] = referenceIndex;
   }
-  else if (distance < this->distances(1, queryIndex))
+  else if (distance < lowerBounds[queryIndex])
   {
-    // Here it doesn't actually matter if the assignment is the same.
-    this->neighbors(1, queryIndex) = cluster;
-    this->distances(1, queryIndex) = distance;
+    lowerBounds[queryIndex] = distance;
   }
 
-  this->lastQueryIndex = queryIndex;
-  this->lastReferenceIndex = referenceIndex;
-  this->lastBaseCase = distance;
-
   return distance;
 }
 
 template<typename MetricType, typename TreeType>
 inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     const size_t queryIndex,
-    TreeType& referenceNode)
+    TreeType& /* referenceNode */)
 {
   // If the query point has already been pruned, then don't recurse further.
-//  Log::Debug << "Score " << queryIndex << ", r" << referenceNode.Point(0) << "c"
-//      << referenceNode.NumDescendants() << ".\n";
   if (prunedPoints[queryIndex])
     return DBL_MAX;
-//  Log::Debug << "(not pruned)\n";
 
-  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
-      MetricType, TreeType>::Score(queryIndex, referenceNode);
+  // No pruning at this level (for now).
+  return 0;
 }
 
 template<typename MetricType, typename TreeType>
@@ -125,38 +86,70 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
-//  Log::Debug << "Score q" << queryNode.Point(0) << "c" <<
-//queryNode.NumDescendants() << ", r" << referenceNode.Point(0) << "c" <<
-//referenceNode.NumDescendants() << ".\n";
-  if (queryNode.Stat().Pruned())
+  // Pruned() for the root node must never be set to size_t(-1).
+  if (queryNode.Stat().Pruned() == size_t(-1))
+  {
+    queryNode.Stat().Pruned() = queryNode.Parent()->Stat().Pruned();
+    queryNode.Stat().LowerBound() = queryNode.Parent()->Stat().LowerBound();
+  }
+
+  if (queryNode.Stat().Pruned() == centroids.n_cols)
+  {
+    return DBL_MAX;
+  }
+
+  // Get minimum and maximum distances.
+  math::Range distances = queryNode.RangeDistance(&referenceNode);
+  double score = distances.Lo();
+  ++scores;
+  if (distances.Lo() > queryNode.Stat().UpperBound())
+  {
+    // The reference node can own no points in this query node.  We may improve
+    // the lower bound on pruned nodes, though.
+    if (distances.Lo() < queryNode.Stat().LowerBound())
+      queryNode.Stat().LowerBound() = distances.Lo();
+
+    // This assumes that reference clusters don't appear elsewhere in the tree.
+    queryNode.Stat().Pruned() += referenceNode.NumDescendants();
+    score = DBL_MAX;
+  }
+  else if (distances.Hi() < queryNode.Stat().UpperBound())
+  {
+    // We can improve the best estimate.
+    queryNode.Stat().UpperBound() = distances.Hi();
+    // If this node has only one descendant, then it may be the owner.
+    if (referenceNode.NumDescendants() == 1)
+      queryNode.Stat().Owner() = referenceNode.Descendant(0);
+  }
+
+  // Is everything pruned?
+  if (queryNode.Stat().Pruned() == centroids.n_cols - 1)
+  {
+    queryNode.Stat().Pruned() = centroids.n_cols; // Owner() is already set.
     return DBL_MAX;
-//  Log::Debug << "(not pruned.)\n";
+  }
 
-  // Check if the query node is Hamerly pruned, and if not, then don't continue.
-  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
-      MetricType, TreeType>::Score(queryNode, referenceNode);
+  return score;
 }
 
 template<typename MetricType, typename TreeType>
 inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
-    const size_t queryIndex,
-    TreeType& referenceNode,
+    const size_t /* queryIndex */,
+    TreeType& /* referenceNode */,
     const double oldScore)
 {
-  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
-      MetricType, TreeType>::Rescore(queryIndex, referenceNode, oldScore);
+  // No rescoring (for now).
+  return oldScore;
 }
 
 template<typename MetricType, typename TreeType>
 inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
-    TreeType& queryNode,
-    TreeType& referenceNode,
+    TreeType& /* queryNode */,
+    TreeType& /* referenceNode */,
     const double oldScore)
 {
-  // No need to check for a Hamerly prune.  Because we've already done that in
-  // Score().
-  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
-      MetricType, TreeType>::Rescore(queryNode, referenceNode, oldScore);
+  // No rescoring (for now).
+  return oldScore;
 }
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index 691a643..b8c04e6 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -18,11 +18,10 @@ class DTNNStatistic : public
  public:
   DTNNStatistic() :
       neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
-      pruned(false),
-      iteration(0),
-      maxClusterDistance(DBL_MAX),
-      secondClusterBound(0.0),
+      upperBound(DBL_MAX),
+      lowerBound(DBL_MAX),
       owner(size_t(-1)),
+      pruned(size_t(-1)),
       centroid(),
       trueLeft(NULL),
       trueRight(NULL),
@@ -34,11 +33,10 @@ class DTNNStatistic : public
   template<typename TreeType>
   DTNNStatistic(TreeType& node) :
       neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
-      pruned(false),
-      iteration(0),
-      maxClusterDistance(DBL_MAX),
-      secondClusterBound(0.0),
+      upperBound(DBL_MAX),
+      lowerBound(DBL_MAX),
       owner(size_t(-1)),
+      pruned(size_t(-1)),
       trueLeft((void*) &node.Child(0)),
       trueRight((void*) &node.Child(1)),
       trueParent((void*) node.Parent())
@@ -55,24 +53,21 @@ class DTNNStatistic : public
     centroid /= node.NumDescendants();
   }
 
-  bool Pruned() const { return pruned; }
-  bool& Pruned() { return pruned; }
+  double UpperBound() const { return upperBound; }
+  double& UpperBound() { return upperBound; }
 
-  size_t Iteration() const { return iteration; }
-  size_t& Iteration() { return iteration; }
+  double LowerBound() const { return lowerBound; }
+  double& LowerBound() { return lowerBound; }
 
-  double MaxClusterDistance() const { return maxClusterDistance; }
-  double& MaxClusterDistance() { return maxClusterDistance; }
+  const arma::vec& Centroid() const { return centroid; }
+  arma::vec& Centroid() { return centroid; }
 
-  double SecondClusterBound() const { return secondClusterBound; }
-  double& SecondClusterBound() { return secondClusterBound; }
+  size_t Pruned() const { return pruned; }
+  size_t& Pruned() { return pruned; }
 
   size_t Owner() const { return owner; }
   size_t& Owner() { return owner; }
 
-  const arma::vec& Centroid() const { return centroid; }
-  arma::vec& Centroid() { return centroid; }
-
   const void* TrueLeft() const { return trueLeft; }
   void*& TrueLeft() { return trueLeft; }
 
@@ -86,20 +81,18 @@ class DTNNStatistic : public
   {
     std::ostringstream o;
     o << "DTNNStatistic [" << this << "]:\n";
+    o << "  Upper bound: " << upperBound << ".\n";
+    o << "  Lower bound: " << lowerBound << ".\n";
     o << "  Pruned: " << pruned << ".\n";
-    o << "  Iteration: " << iteration << ".\n";
-    o << "  MaxClusterDistance: " << maxClusterDistance << ".\n";
-    o << "  SecondClusterBound: " << secondClusterBound << ".\n";
     o << "  Owner: " << owner << ".\n";
     return o.str();
   }
 
  private:
-  bool pruned;
-  size_t iteration;
-  double maxClusterDistance;
-  double secondClusterBound;
+  double upperBound;
+  double lowerBound;
   size_t owner;
+  size_t pruned;
   arma::vec centroid;
   void* trueLeft;
   void* trueRight;
diff --git a/src/mlpack/tests/kmeans_test.cpp b/src/mlpack/tests/kmeans_test.cpp
index a296339..593ea6e 100644
--- a/src/mlpack/tests/kmeans_test.cpp
+++ b/src/mlpack/tests/kmeans_test.cpp
@@ -14,12 +14,16 @@
 #include <mlpack/methods/kmeans/dual_tree_kmeans.hpp>
 
 #include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
 using namespace mlpack;
 using namespace mlpack::kmeans;
+using namespace mlpack::metric;
+using namespace mlpack::tree;
+using namespace mlpack::neighbor;
 
 BOOST_AUTO_TEST_SUITE(KMeansTest);
 
@@ -690,4 +694,160 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansTest)
   }
 }
 
+BOOST_AUTO_TEST_CASE(DualTreeKMeansBaseCaseTest)
+{
+  // If we run BaseCase() on all the points, do we get valid results?
+  const size_t points = 1000;
+  const size_t clusters = 5;
+  arma::mat dataset(5, points);
+  dataset.randu();
+  arma::mat centroids(5, clusters);
+  centroids.randu();
+
+  // Create the Rules object.
+  arma::Col<size_t> assignments(points);
+  arma::vec upperBounds(points);
+  arma::vec lowerBounds(points);
+  upperBounds.fill(DBL_MAX);
+  lowerBounds.fill(DBL_MAX);
+  std::vector<bool> visited(points, false); // Fill with false.
+  std::vector<size_t> oldFromNewCentroids; // Not used.
+  std::vector<bool> prunedPoints(points, false); // Fill with false.
+
+  EuclideanDistance e;
+  DTNNKMeansRules<EuclideanDistance, BinarySpaceTree<HRectBound<2>,
+      EuclideanDistance, DTNNStatistic> > rules(centroids, dataset,
+      assignments, upperBounds, lowerBounds, e, prunedPoints,
+      oldFromNewCentroids, visited);
+
+  for (size_t i = 0; i < points; ++i)
+  {
+    for (size_t j = 0; j < clusters; ++j)
+    {
+      rules.BaseCase(i, j);
+    }
+  }
+
+  // Now, run nearest neighbors to establish true bounds.
+  neighbor::AllkNN allknn(centroids, dataset);
+  arma::Mat<size_t> trueAssignments;
+  arma::mat trueDistances;
+  allknn.Search(2, trueAssignments, trueDistances);
+
+  for (size_t i = 0; i < points; ++i)
+  {
+    BOOST_REQUIRE_GE(upperBounds[i], trueDistances(0, i));
+    BOOST_REQUIRE_LE(lowerBounds[i], trueDistances(1, i));
+    BOOST_REQUIRE_EQUAL(assignments[i], trueAssignments(0, i));
+  }
+}
+
+BOOST_AUTO_TEST_CASE(DualTreeKMeansScoreKDTreeOneLeafTest)
+{
+  // If we run a dual-tree algorithm, do we get valid results for each point
+  // and/or node when we use the kd-tree with a leaf size of one?
+  const size_t points = 5000;
+  const size_t clusters = 100;
+  arma::mat dataset(5, points);
+  dataset.randu();
+  arma::mat centroids(5, clusters);
+  centroids.randu();
+
+  arma::mat datasetCopy(dataset);
+  arma::mat centroidsCopy(centroids);
+
+  // Create the trees.
+  typedef BinarySpaceTree<HRectBound<2>, DTNNStatistic> TreeType;
+  TreeType pointTree(dataset,  1);
+  TreeType centroidTree(centroids, 1);
+
+  // Create the Rules object.
+  arma::Col<size_t> assignments(points);
+  arma::vec upperBounds(points);
+  arma::vec lowerBounds(points);
+  upperBounds.fill(DBL_MAX);
+  lowerBounds.fill(DBL_MAX);
+  std::vector<bool> visited(points, false); // Fill with false.
+  std::vector<size_t> oldFromNewCentroids; // Not used.
+  std::vector<bool> prunedPoints(points, false); // Fill with false.
+
+  EuclideanDistance e;
+  typedef DTNNKMeansRules<EuclideanDistance, TreeType> RuleType;
+  RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds, e,
+      prunedPoints, oldFromNewCentroids, visited);
+
+  // Now create the traverser.
+  typename TreeType::template BreadthFirstDualTreeTraverser<RuleType>
+      traverser(rules);
+
+  pointTree.Stat().Pruned() = 0;
+  traverser.Traverse(pointTree, centroidTree);
+
+  // Get true bounds.
+  AllkNN allknn(centroids, dataset);
+
+  arma::Mat<size_t> trueAssignments;
+  arma::mat trueDistances;
+  allknn.Search(2, trueAssignments, trueDistances);
+
+  // Check the points first.  Lots of weird mappings have to go on in this stage
+  // because the tree building procedure changed all the points around.
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+  {
+    if (visited[i])
+    {
+      BOOST_REQUIRE_GE(upperBounds[i], trueDistances(0, i));
+      BOOST_REQUIRE_EQUAL(assignments[i], trueAssignments(0, i));
+    }
+  }
+
+  // Now traverse the tree to see if it is correct.
+  std::queue<TreeType*> nodeQueue;
+  nodeQueue.push(&pointTree);
+  while (!nodeQueue.empty())
+  {
+    // This is an expensive operation.  We must ensure that the upper bound is
+    // valid and that the lower bound is valid.  Both can be needlessly loose,
+    // but that will simply affect the pruning of the method, not the
+    // correctness.  Here we care about correctness.
+    TreeType* node = nodeQueue.front();
+    nodeQueue.pop();
+
+    // We must make sure the upper bound and lower bound are both valid for all
+    // descendant points.  The lower bound only matters if the node was pruned.
+    // So, we must calculate the upper and lower bounds manually for the
+    // descendants.
+    double exactUpperBound = 0.0;
+    double exactLowerBound = DBL_MAX;
+    for (size_t i = 0; i < node->NumDescendants(); ++i)
+    {
+      if (trueDistances(0, node->Descendant(i)) > exactUpperBound)
+        exactUpperBound = trueDistances(0, node->Descendant(i));
+      if (trueDistances(1, node->Descendant(i)) < exactLowerBound)
+        exactLowerBound = trueDistances(1, node->Descendant(i));
+    }
+
+    // Multiplication is to add some tolerance for floating point discrepancies.
+    BOOST_REQUIRE_GE(node->Stat().UpperBound() * 1.000001, exactUpperBound);
+
+    if (node->Stat().Pruned() == centroids.n_cols)
+    {
+      BOOST_REQUIRE_LE(node->Stat().LowerBound() * 0.99999, exactLowerBound);
+    }
+    else
+    {
+      for (size_t i = 0; i < node->NumPoints(); ++i)
+      {
+        const double bestLower = std::min(node->Stat().LowerBound(),
+            lowerBounds[node->Point(i)]);
+        BOOST_REQUIRE_LE(bestLower * 0.99999, trueDistances(1, node->Point(i)));
+      }
+
+      // Recurse.
+      for (size_t i = 0; i < node->NumChildren(); ++i)
+        nodeQueue.push(&node->Child(i));
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list