[mlpack-git] master: I don't think this is worth saving. It also doesn't work very well, but I learned a lot about the bookkeeping I need to do. (de01b8f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:51 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit de01b8f67ae3c54ab259069ac5c369723dcebfd0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jan 28 15:17:53 2015 -0500

    I don't think this is worth saving. It also doesn't work very well, but I learned a lot about the bookkeeping I need to do.


>---------------------------------------------------------------

de01b8f67ae3c54ab259069ac5c369723dcebfd0
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 322 +++++++--------------
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 253 ++++------------
 .../methods/kmeans/dual_tree_kmeans_statistic.hpp  |  39 ++-
 src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp  |   5 +
 4 files changed, 204 insertions(+), 415 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index f08b4e1..7a65d3e 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -39,7 +39,7 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
     datasetCopy = datasetOrig;
 
   // Now build the tree.  We don't need any mappings.
-  tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset));
+  tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset), 1);
 
   Timer::Stop("tree_building");
 }
@@ -186,9 +186,14 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
 {
   // This is basically IterationUpdate(), but pulled out to be separate from the
   // actual dual-tree algorithm.
-
+  if (node->Begin() == 26038)
+    Log::Warn << "r26038c" << node->Count() << " has owner " <<
+node->Stat().Owner() << ".\n";
   if (node->Parent() != NULL && node->Parent()->Stat().Owner() < clusters)
     node->Stat().Owner() = node->Parent()->Stat().Owner();
+  if (node->Begin() == 26038)
+    Log::Warn << "r26038c" << node->Count() << " has owner " <<
+node->Stat().Owner() << " after parent check.\n";
 
   const size_t cluster = assignments[node->Descendant(0)];
   bool allSame = true;
@@ -203,242 +208,102 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
 
   if (allSame)
     node->Stat().Owner() = cluster;
+  else
+    node->Stat().Owner() = centroids.n_cols;
+  if (node->Begin() == 26038)
+    Log::Warn << "r26038c" << node->Count() << " has manually set owner " <<
+node->Stat().Owner() << ".\n";
 
   const bool prunedLastIteration = node->Stat().HamerlyPruned();
   node->Stat().HamerlyPruned() = false;
 
-  if (node->Begin() == 23058)
-    Log::Warn << "r23058c" << node->Count() << " has owner " <<
+  if (node->Begin() == 26038)
+    Log::Warn << "r26038c" << node->Count() << " has owner " <<
 node->Stat().Owner() << ".\n";
 
   // The easy case: this node had an owner.
   if (node->Stat().Owner() < clusters)
   {
-    // During the last iteration, this node was pruned.
-    const size_t owner = node->Stat().Owner();
-    if (node->Stat().MaxQueryNodeDistance() != DBL_MAX)
-      node->Stat().MaxQueryNodeDistance() += clusterDistances[owner];
-    if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
-      node->Stat().MinQueryNodeDistance() += clusterDistances[owner];
-
-    // Check if we can perform a Hamerly prune: if the node has an owner, and
-    // the second closest cluster could not have moved close enough that any
-    // points could have changed assignment, then this node *must* belong to the
-    // same owner in the next iteration.  Note that MaxQueryNodeDistance() has
-    // already been adjusted for cluster movement.
-
-    // Re-set second closest bound if necessary.
-    if (node->Stat().SecondClosestBound() == DBL_MAX && node->Parent() == NULL)
-      node->Stat().SecondClosestBound() = 0.0; // Don't prune the root.
-
-    if (node->Begin() == 23058)
-      Log::Warn << "r23058c" << node->Count() << " scb " <<
-node->Stat().SecondClosestBound() << " and lscb " <<
-node->Stat().LastSecondClosestBound() << ".\n";
-
-    // If both the second closest bound and last second closest bound are valid,
-    // we have the option of taking the better of the two bounds.  But if only
-    // one is valid, take the minimum of the two (which will be the valid one).
-    // If neither is valid, then we end up with a second closest bound of
-    // DBL_MAX.
-    const double scb = node->Stat().SecondClosestBound();
-    const double lscb = node->Stat().LastSecondClosestBound();
-    if (scb != DBL_MAX && lscb != DBL_MAX)
-      node->Stat().SecondClosestBound() = std::max(scb, lscb);
-    else
-      node->Stat().SecondClosestBound() = std::min(scb, lscb);
-
-    // But if we were Hamerly pruned last time, we can't trust the second
-    // closest bound and thus have to take last iteration's.
-    if (prunedLastIteration)
-      node->Stat().SecondClosestBound() = lscb;
-    else
+    // Verify correctness...
+    for (size_t i = 0; i < node->NumDescendants(); ++i)
     {
-      // Now, we must ensure that we don't need to take the parent's second
-      // closest bound.  We surely do if the current bound is DBL_MAX.  We
-      // already took care of the root node earlier so we don't need to check if
-      // Parent() is NULL.
-      if (node->Stat().SecondClosestBound() == DBL_MAX)
-        node->Stat().SecondClosestBound() =
-            node->Parent()->Stat().SecondClosestBound();
-
-      // There may exist a case where the true second closest query node got
-      // pruned by the parent, and was thus never visited with this node.  This
-      // situation occurs if the second closest query node is not a descendant
-      // of the second closest query node of the parent.
-      if (node->Stat().SecondClosestQueryNode() != NULL)
+      size_t closest = clusters;
+      double closestDistance = DBL_MAX;
+      arma::vec distances(centroids.n_cols);
+      for (size_t j = 0; j < centroids.n_cols; ++j)
       {
-        if (node->Begin() == 23058)
+        const double distance = metric.Evaluate(centroids.col(j),
+            dataset.col(node->Descendant(i)));
+        if (distance < closestDistance)
         {
-          Log::Warn << "Second closest query node is q" << ((TreeType*)
-node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
-node->Stat().SecondClosestQueryNode())->Count() << ", with scb " <<
-node->Stat().SecondClosestBound() << ".\n";
-          Log::Warn << "True SCB to this node should be " <<
-node->MinDistance((TreeType*) node->Stat().SecondClosestQueryNode()) << ".\n";
+          closest = j;
+          closestDistance = distance;
         }
+        distances(j) = distance;
       }
 
-      if (node->Stat().ClosestQueryNode() != NULL)
-        if (node->Begin() == 23058)
-          Log::Warn << "Closest query node: q" << ((TreeType*)
-node->Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
-node->Stat().ClosestQueryNode())->Count() << ", with MQND " <<
-node->Stat().MaxQueryNodeDistance() << " and mQND " <<
-node->Stat().MinQueryNodeDistance() << ".\n";
-
-      // If the closest query node contains more than one descendant, we have to
-      // find the closest...
-      TreeType* cqn = (TreeType*) node->Stat().ClosestQueryNode();
-      if (cqn != NULL && cqn->NumDescendants() > 1)
+      if (closest != node->Stat().Owner())
       {
-        size_t closest = centroids.n_cols;
-        double closestDistance = DBL_MAX;
-        size_t secondClosest = centroids.n_cols;
-        double secondClosestDistance = DBL_MAX;
-        for (size_t i = 0; i < cqn->NumDescendants(); ++i)
-        {
-          const size_t index = cqn->Descendant(i);
-          const double distance =
-              node->MinDistance(centroids.col(oldFromNew[index]));
-//        Log::Info << "Index " << index << ", distance " << distance << " (i "
-//            << i + cqn->Begin() << ").\n";
-          ++distanceCalculations;
-          if (distance < closestDistance)
-          {
-            secondClosest = closest;
-            secondClosestDistance = closestDistance;
-            closest = index;
-            closestDistance = distance;
-          }
-          else if (distance < secondClosestDistance)
-          {
-            secondClosest = index;
-            secondClosestDistance = distance;
-          }
-        }
-  
-        // Recalculate maximum distance.
-        const double maxDistance = node->MaxDistance(centroids.col(closest));
-        ++distanceCalculations;
-
-        node->Stat().MinQueryNodeDistance() = closestDistance;
-        node->Stat().MaxQueryNodeDistance() = maxDistance;
-        if (secondClosestDistance < node->Stat().SecondClosestBound())
-          node->Stat().SecondClosestBound() = secondClosestDistance;
-
-      if (node->Begin() == 23058)
-        Log::Warn << "After recalculation, closest for r" << node->Begin() << "c" << node->Count()
-<< " is " << closest << ", with mQND " << node->Stat().MinQueryNodeDistance() <<
-", MQND" << node->Stat().MaxQueryNodeDistance() << ", and scb " <<
-node->Stat().SecondClosestBound() << ", " << secondClosest << ".\n";
-      }
-
-//      if (node->Parent() != NULL &&
-//node->Parent()->Stat().SecondClosestQueryNode() != NULL)
-//        if (node->Begin() == 23058)
-//          Log::Warn << "Parent's (r" << node->Parent()->Begin() << "c"
-//<< node->Parent()->Count() << ") second closest query node is q" << ((TreeType*)
-//node->Parent()->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
-//node->Parent()->Stat().SecondClosestQueryNode())->Count() << ", with scb " <<
-//node->Parent()->Stat().SecondClosestBound() << ".\n";
-
-      // Suppose that the true second closest query node was pruned by the
-      // parent, and thus was never seen by this node.  To ensure the
-      // correctness of the second bound in this situation, we'll take the
-      // parent's second closest bound only if the parent's second closest query
-      // node is on a separate subtree than the node's second closest query node
-      // _and_ the node's closest query node.
-      TreeType* parent = (TreeType*) node->Parent();
-      TreeType* scqn = (TreeType*) node->Stat().SecondClosestQueryNode();
-      TreeType* parentScqn = (parent == NULL) ? NULL :
-          (TreeType*) parent->Stat().SecondClosestQueryNode();
-      TreeType* parentCqn = (parent == NULL) ? NULL :
-          (TreeType*) parent->Stat().ClosestQueryNode();
-      if (parentScqn != NULL && node->Begin() == 23058)
-        Log::Warn << "Parent (" << parent->Begin() << "c" << parent->Count() <<
-") SCB is " << parent->Stat().SecondClosestBound() << ", "
-            << "with q" << parentScqn->Begin() << "c" << parentScqn->Count() <<
+        Log::Warn << distances.t();
+        Log::Fatal << "Point " << node->Descendant(i) << " mistakenly assigned "
+            << "to cluster " << node->Stat().Owner() << ", but should be " <<
+closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
 ".\n";
-      if (scqn != NULL && parentScqn != NULL &&
-          !IsDescendantOf(*parentScqn, *scqn) &&
-          !IsDescendantOf(*parentCqn, *scqn) &&
-          (parent->Stat().SecondClosestBound() <
-              node->Stat().SecondClosestBound()))
-      {
-        if (node->Begin() == 23058)
-          Log::Warn << "Take parent's SCB of " <<
-parent->Stat().SecondClosestBound() << "; parent SCQN is " <<
-parentScqn->Begin() << "c" << parentScqn->Count() << ", parent CQN is " <<
-parentCqn->Begin() << "c" << parentCqn->Count() << ".\n";
-        node->Stat().SecondClosestBound() = parent->Stat().SecondClosestBound();
-        node->Stat().SecondClosestQueryNode() = parentScqn;
       }
     }
 
-    if (node->Begin() == 23058)
-    {
-      Log::Warn << "Attempt Hamerly prune on r23058c" << node->Count() <<
-          " with MQND " << node->Stat().MaxQueryNodeDistance() << ", scb "
-          << node->Stat().SecondClosestBound() << ", owner " <<
-node->Stat().Owner() << ", and clusterDistances " << clusterDistances[clusters]
-<< ".\n";
-    }
+    // During the last iteration, this node was pruned.
+    const size_t owner = node->Stat().Owner();
+    if (node->Stat().MaxQueryNodeDistance() != DBL_MAX)
+      node->Stat().MaxQueryNodeDistance() += clusterDistances[owner];
+    if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
+      node->Stat().MinQueryNodeDistance() += clusterDistances[owner];
 
-    // Check the second bound.  (This is time-consuming...)
-    arma::vec minDistances(centroids.n_cols);
-    for (size_t j = 0; j < node->NumDescendants(); ++j)
+    if (prunedLastIteration)
     {
-      arma::vec distances(centroids.n_cols);
-      double secondClosestDist = DBL_MAX;
-      for (size_t i = 0; i < centroids.n_cols; ++i)
+      // Can we continue being Hamerly pruned?  If not, we'll have to update the
+      // bound next iteration.
+      if (node->Begin() == 26038)
+        Log::Warn << "r26038c" << node->Count() << ": check sustained Hamerly "
+            << "prune with MQND " << node->Stat().MaxQueryNodeDistance() << ", "
+            << "lscb " << node->Stat().LastSecondClosestBound() << ", cd "
+            << clusterDistances[clusters] << ".\n";
+      if (node->Stat().MaxQueryNodeDistance() <
+          node->Stat().LastSecondClosestBound() - clusterDistances[clusters])
       {
-        if (j == 0)
-          minDistances[i] = node->MinDistance(centroids.col(i));
-
-        const double distance = MetricType::Evaluate(centroids.col(i),
-            dataset.col(node->Descendant(j)));
-        if (distance < secondClosestDist && i != node->Stat().Owner())
-          secondClosestDist = distance;
-
-        distances(i) = distance;
+        node->Stat().HamerlyPruned() = true;
+        if (!node->Parent()->Stat().HamerlyPruned())
+          hamerlyPruned += node->NumDescendants();
       }
-
-      if (j == 0)
-        if (node->Begin() == 23058)
-          Log::Warn << "r23058c" << node->Count() << ": " << minDistances.t();
-      if (secondClosestDist < node->Stat().SecondClosestBound() - 1e-15)
+    }
+    else
+    {
+      if (node->Begin() == 26038)
       {
-        Log::Warn << "r" << node->Begin() << "c" << node->Count() << ":\n";
-        Log::Warn << "Owner " << node->Stat().Owner() << ", mqnd " <<
-node->Stat().MaxQueryNodeDistance() << ", mnqnd " <<
-node->Stat().MinQueryNodeDistance() << ".\n";
-        Log::Warn << distances.t();
-        Log::Fatal << "Second closest bound " <<
-node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
-            << "! (" << node->Stat().SecondClosestBound() - secondClosestDist
-<< ")\n";
+        if (node->Stat().ClosestQueryNode() != NULL)
+          Log::Warn << "r26038c" << node->Count() << " CQN: " << ((TreeType*)
+  node->Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
+  node->Stat().ClosestQueryNode())->Count() << ".\n";
+        if (node->Stat().SecondClosestQueryNode() != NULL)
+          Log::Warn << "r26038c" << node->Count() << " SCQN: " << ((TreeType*)
+  node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
+  node->Stat().SecondClosestQueryNode())->Count() << ".\n";
+        Log::Warn << "Attempt hamerly prune r26038c" << node->Count() << " with "
+            << "MQND " << node->Stat().MaxQueryNodeDistance() << " and smqnd "
+            << node->Stat().SecondMinQueryNodeDistance() << " and cluster d "
+            << clusterDistances[clusters] << ".\n";
       }
-    }
-
 
-    if (node->Stat().MaxQueryNodeDistance() < node->Stat().SecondClosestBound()
-        - clusterDistances[clusters])
-    {
-      node->Stat().HamerlyPruned() = true;
-      if (!node->Parent()->Stat().HamerlyPruned())
+      // Now we check for a Hamerly prune.  We know that we have an accurate
+      // second bound since nothing can be pruned.
+      if (node->Stat().MaxQueryNodeDistance() /* already adjusted */ <
+          node->Stat().SecondMinQueryNodeDistance() - clusterDistances[clusters])
       {
-        if (node->Begin() == 23058)
-          Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
-            << "Hamerly pruned.\n";
-        hamerlyPruned += node->NumDescendants();
+        node->Stat().HamerlyPruned() = true;
+        if (!node->Parent()->Stat().HamerlyPruned())
+          hamerlyPruned += node->NumDescendants();
       }
     }
-    else
-    {
-      // No Hamerly prune, so we don't have a known owner.
-      node->Stat().Owner() = clusters;
-    }
   }
   else
   {
@@ -456,9 +321,34 @@ node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
     node->Stat().Owner() = centroids.n_cols;
   }
 
+  bool allPruned = true;
+  size_t owner = clusters;
   for (size_t i = 0; i < node->NumChildren(); ++i)
+  {
     TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
         centroids, dataset, oldFromNew, hamerlyPruned);
+    if (!node->Child(i).Stat().HamerlyPruned())
+      allPruned = false;
+    else if (owner == clusters)
+      owner = node->Child(i).Stat().Owner();
+    else if (owner < clusters && owner != node->Child(i).Stat().Owner())
+      owner = clusters + 1;
+  }
+
+  if (node->NumChildren() == 0 && !node->Stat().HamerlyPruned())
+    allPruned = false;
+
+  if (allPruned && owner < clusters && !node->Stat().HamerlyPruned())
+  {
+    if (node->Begin() == 26038)
+      Log::Warn << "Set r" << node->Begin() << "c" << node->Count() << " to be "
+          << "Hamerly pruned.\n";
+    node->Stat().HamerlyPruned() = true;
+  }
+
+  if (node->Begin() == 26038 && node->Stat().HamerlyPruned())
+    Log::Warn << "r" << node->Begin() << "c" << node->Count() << " is Hamerly "
+        << "pruned.\n";
 
   node->Stat().Iteration() = iteration;
   node->Stat().ClustersPruned() = (node->Parent() == NULL) ? 0 : -1;
@@ -466,11 +356,19 @@ node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
   // be rebuilt.
   node->Stat().ClosestQueryNode() = NULL;
 
-  node->Stat().LastSecondClosestBound() = node->Stat().SecondClosestBound() -
-      clusterDistances[clusters];
+  if (prunedLastIteration)
+    node->Stat().LastSecondClosestBound() -= clusterDistances[clusters];
+  else
+    node->Stat().LastSecondClosestBound() =
+        node->Stat().SecondMinQueryNodeDistance() - clusterDistances[clusters];
+  node->Stat().MinQueryNodeDistance() = DBL_MAX;
+  if (prunedLastIteration && !node->Stat().HamerlyPruned())
+    node->Stat().MaxQueryNodeDistance() = DBL_MAX;
+  node->Stat().SecondMinQueryNodeDistance() = DBL_MAX;
+  node->Stat().SecondMaxQueryNodeDistance() = DBL_MAX;
   // This should change later, but I'm not yet sure how to do it.
-  node->Stat().SecondClosestBound() = DBL_MAX;
-  node->Stat().SecondClosestQueryNode() = NULL;
+//  node->Stat().SecondClosestBound() = DBL_MAX;
+//  node->Stat().SecondClosestQueryNode() = NULL;
 
   if (node->Parent() == NULL)
     Log::Info << "Total Hamerly pruned points: " << hamerlyPruned << ".\n";
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 28b6695..4615b80 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -50,6 +50,9 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
     const size_t queryIndex,
     const size_t referenceIndex)
 {
+  if (referenceIndex == 26038)
+    Log::Warn << "Visit 26038 with query " << queryIndex << ".\n";
+
   // Collect the number of clusters that have been pruned during the traversal.
   // The ternary operator may not be necessary.
   const size_t traversalPruned = (traversalInfo.LastReferenceNode() != NULL) ?
@@ -72,17 +75,26 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
     distanceIteration[referenceIndex] = iteration;
     distances[referenceIndex] = distance;
     assignments[referenceIndex] = mappings[queryIndex];
+    if (referenceIndex == 26038)
+      Log::Warn << "assignment for point " << referenceIndex << " set to " <<
+mappings[queryIndex] << ".\n";
   }
   else if (distance < distances[referenceIndex])
   {
     distances[referenceIndex] = distance;
     assignments[referenceIndex] = mappings[queryIndex];
+    if (referenceIndex == 26038)
+      Log::Warn << "assignment for point " << referenceIndex << " set to " <<
+mappings[queryIndex] << ".\n";
   }
 
   ++visited[referenceIndex];
 
   if (visited[referenceIndex] + traversalPruned == centroids.n_cols)
   {
+    if (referenceIndex == 26038)
+      Log::Warn << "assignment for point " << referenceIndex << " committed to " <<
+assignments[referenceIndex] << ".\n";
     newCentroids.col(assignments[referenceIndex]) +=
         dataset.col(referenceIndex);
     ++counts(assignments[referenceIndex]);
@@ -105,206 +117,69 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
+//  if (referenceNode.Begin() == 2432)
+//    Log::Warn << "Visit q" << queryNode.Begin() << "c" << queryNode.Count() <<
+//", r" << referenceNode.Begin() << "c" << referenceNode.Count() << ".\n";
+
   // This won't happen with the root since it is explicitly set to 0.
   const size_t origPruned = referenceNode.Stat().ClustersPruned();
   if (referenceNode.Stat().ClustersPruned() == size_t(-1))
     referenceNode.Stat().ClustersPruned() =
         referenceNode.Parent()->Stat().ClustersPruned();
 
-  traversalInfo.LastReferenceNode() = &referenceNode;
-
-  if (referenceNode.Begin() == 23058)
-    Log::Warn << "Visit r23058c" << referenceNode.Count() << ", q" <<
-queryNode.Begin() << "c" << queryNode.Count() << ".\n";
-
-  // If there's no closest query node assigned, but the parent has one, take
-  // that one.
-  if (referenceNode.Stat().ClosestQueryNode() == NULL &&
-      referenceNode.Parent() != NULL &&
-      referenceNode.Parent()->Stat().ClosestQueryNode() != NULL)
-  {
-    if (referenceNode.Begin() == 23058)
-      Log::Warn << "Update closest query node for r23058c" <<
-referenceNode.Count() << " to parent's, which is "
-          << ((TreeType*)
-referenceNode.Parent()->Stat().ClosestQueryNode())->Begin() << "c" <<
-((TreeType*) referenceNode.Parent()->Stat().ClosestQueryNode())->Count() <<
-".\n";
-
-    referenceNode.Stat().ClosestQueryNode() =
-        referenceNode.Parent()->Stat().ClosestQueryNode();
-    referenceNode.Stat().MaxQueryNodeDistance() = std::min(
-        referenceNode.Parent()->Stat().MaxQueryNodeDistance(),
-        referenceNode.Stat().MaxQueryNodeDistance());
-//    referenceNode.Stat().SecondClosestBound() = std::min(
-//        referenceNode.Parent()->Stat().SecondClosestBound(),
-//        referenceNode.Stat().SecondClosestBound());
-//    if (referenceNode.Begin() == 23058)
-//      Log::Warn << "Update second closest bound for r23058c" <<
-//referenceNode.Count() << " to parent's, which "
-//          << "is " << referenceNode.Stat().SecondClosestBound() << ".\n";
-  }
-
-  double score = HamerlyTypeScore(referenceNode);
-  if (score == DBL_MAX)
+  if (referenceNode.Stat().HamerlyPruned())
   {
-    if (referenceNode.Begin() == 23058)
-      Log::Warn << "Hamerly prune for r23058c" << referenceNode.Count() << ", q" << queryNode.Begin() << "c" <<
-queryNode.Count() << ".\n";
-    if (origPruned == size_t(-1))
+    // Add to centroids if necessary.
+    if (referenceNode.Stat().MinQueryNodeDistance() == DBL_MAX /* hack */)
     {
-      const size_t cluster = referenceNode.Stat().Owner();
-      newCentroids.col(cluster) += referenceNode.Stat().Centroid() *
-          referenceNode.NumDescendants();
-//      Log::Warn << "Hamerly prune: r" << referenceNode.Begin() << "c" <<
-//          referenceNode.Count() << ".\n";
-      counts(cluster) += referenceNode.NumDescendants();
-      referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+      if (referenceNode.Begin() == 26038)
+        Log::Warn << "Add centroid mass for r26038c" << referenceNode.Count() <<
+".\n";
+      newCentroids.col(referenceNode.Stat().Owner()) +=
+          referenceNode.NumDescendants() * referenceNode.Stat().Centroid();
+      counts(referenceNode.Stat().Owner()) += referenceNode.NumDescendants();
+      referenceNode.Stat().MinQueryNodeDistance() = 0.0;
     }
-    return DBL_MAX; // No other bookkeeping to do.
+    return DBL_MAX; // No need to go further.
   }
 
-  if (score != DBL_MAX)
-  {
-    score = ElkanTypeScore(queryNode, referenceNode);
+  traversalInfo.LastReferenceNode() = &referenceNode;
 
-    if (score != DBL_MAX)
-    {
-      // We also have to update things if the closest query node is null.  This
-      // can probably be improved.
-      const double minDistance = referenceNode.MinDistance(&queryNode);
-      ++distanceCalculations;
-      score = PellegMooreScore(queryNode, referenceNode, minDistance);
-      if (referenceNode.Begin() == 23058)
-        Log::Warn << "mQND for r23058c" << referenceNode.Count() << " is "
-            << referenceNode.Stat().MinQueryNodeDistance() << "; minDistance "
-            << minDistance << ", scb " <<
-referenceNode.Stat().SecondClosestBound() << ".\n";
-
-      if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
-      {
-        const double maxDistance = referenceNode.MaxDistance(&queryNode);
-        if (!IsDescendantOf(*((TreeType*)
-            referenceNode.Stat().ClosestQueryNode()), queryNode) &&
-            referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX &&
-            referenceNode.Stat().MinQueryNodeDistance() <
-            referenceNode.Stat().SecondClosestBound() &&
-            &queryNode != referenceNode.Stat().ClosestQueryNode())
-        {
-          referenceNode.Stat().SecondClosestBound() =
-              referenceNode.Stat().MinQueryNodeDistance();
-          referenceNode.Stat().SecondClosestQueryNode() =
-              referenceNode.Stat().ClosestQueryNode();
-          if (referenceNode.Begin() == 23058)
-            Log::Warn << "scb for r23058c" << referenceNode.Count() << " taken "
-                << "from minDistance, which is " <<
-referenceNode.Stat().MinQueryNodeDistance() << ".\n";
-        }
-
-        if (referenceNode.Stat().MinQueryNodeDistance() == DBL_MAX &&
-            score == DBL_MAX &&
-            minDistance < referenceNode.Stat().SecondClosestBound())
-        {
-          referenceNode.Stat().SecondClosestBound() = minDistance;
-          referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
-          if (referenceNode.Begin() == 23058)
-            Log::Warn << "scb for r23058c" << referenceNode.Count() << " taken "
-                << "from minDistance for pruned query node, which is " <<
-minDistance << ".\n";
-        }
-
-        if (score != DBL_MAX)
-        {
-          ++distanceCalculations;
-          referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-          referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-          referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-
-          if (referenceNode.Begin() == 23058)
-            Log::Warn << "mQND for r23058c" << referenceNode.Count() << " updated to " << minDistance << " and "
-              << "MQND to " << maxDistance << " with furthest query node " <<
-              queryNode.Begin() << "c" << queryNode.Count() << ".\n";
-        }
-      }
-      else if (IsDescendantOf(*((TreeType*)
-          referenceNode.Stat().ClosestQueryNode()), queryNode))
-      {
-        if (referenceNode.Begin() == 23058)
-          Log::Warn << "Old closest for r23058c" << referenceNode.Count() <<
-              " is q" << ((TreeType*)
-referenceNode.Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
-referenceNode.Stat().ClosestQueryNode())->Count() << " with mQND " <<
-referenceNode.Stat().MinQueryNodeDistance() << " and MQND " <<
-referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
-        const double maxDistance = referenceNode.MaxDistance(&queryNode);
-        ++distanceCalculations;
-        referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-        referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-        referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-
-        if (referenceNode.Begin() == 23058)
-          Log::Warn << "mQND for r23058c" << referenceNode.Count() << " updated to " << minDistance << " and "
-              << "MQND to " << maxDistance << " via descendant with fqn " <<
-              queryNode.Begin() << "c" << queryNode.Count() << ".\n";
-      }
-      else if (minDistance < referenceNode.Stat().SecondClosestBound())
-      {
-        referenceNode.Stat().SecondClosestBound() = minDistance;
-        referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
-        if (referenceNode.Begin() == 23058)
-          Log::Warn << "scb for r23058c" << referenceNode.Count() << " updated to " << minDistance << " via "
-              << queryNode.Begin() << "c" << queryNode.Count() << ".\n";
-      }
-    }
-    else
-    {
-      // There was an Elkan prune, but we still need to check the second closest
-      // bound.
-      const double minDistance = referenceNode.MinDistance(&queryNode);
-      ++distanceCalculations;
-      if (minDistance < referenceNode.Stat().SecondClosestBound())
-      {
-        if (referenceNode.Begin() == 23058)
-          Log::Warn << "After Elkan prune, update scb to " << minDistance <<
-".\n";
+  // Calculate distance to node.
+  // This costs about the same (in terms of runtime) as a single MinDistance()
+  // call, so there only need to add one distance computation.
+  math::Range distances = referenceNode.RangeDistance(&queryNode);
+  ++distanceCalculations;
 
-        referenceNode.Stat().SecondClosestBound() = minDistance;
-        referenceNode.Stat().SecondClosestQueryNode() = (void*) &queryNode;
-      }
-    }
+  // Is this closer than the current best query node?
+  if (distances.Lo() < referenceNode.Stat().MinQueryNodeDistance())
+  {
+    if (referenceNode.Begin() == 26038)
+      Log::Warn << "r26038c" << referenceNode.Count() << ": new CQN " <<
+queryNode.Begin() << "c" << queryNode.Count() << ".\n";
+    // This is the new closest node.
+    referenceNode.Stat().SecondClosestQueryNode() =
+        referenceNode.Stat().ClosestQueryNode();
+    referenceNode.Stat().SecondMinQueryNodeDistance() =
+        referenceNode.Stat().MinQueryNodeDistance();
+    referenceNode.Stat().SecondMaxQueryNodeDistance() =
+        referenceNode.Stat().MaxQueryNodeDistance();
+    referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+    referenceNode.Stat().MinQueryNodeDistance() = distances.Lo();
+    referenceNode.Stat().MaxQueryNodeDistance() = distances.Hi();
   }
-
-//  if (((TreeType*) referenceNode.Stat().ClosestQueryNode())->NumDescendants() > 1)
-//  {
-//    referenceNode.Stat().SecondClosestBound() =
-//        referenceNode.Stat().MinQueryNodeDistance();
-//    referenceNode.Stat().SecondClosestQueryNode() =
-//        referenceNode.Stat().ClosestQueryNode();
-//  }
-
-  if (score == DBL_MAX)
+  else if (distances.Lo() < referenceNode.Stat().SecondMinQueryNodeDistance())
   {
-    referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
-    if (referenceNode.Begin() == 23058)
-      Log::Warn << "For r23058c" << referenceNode.Count() << ", q" <<
-queryNode.Begin() << "c" << queryNode.Count() << " is pruned.  Min distance is"
-    << " " << queryNode.MinDistance(&referenceNode) << " and scb is " <<
-referenceNode.Stat().SecondClosestBound() << ".\n";
-
-    // Have we pruned everything?
-    if (referenceNode.Stat().ClustersPruned() +
-        visited[referenceNode.Descendant(0)] == centroids.n_cols)
-    {
-      for (size_t i = 0; i < referenceNode.NumDescendants(); ++i)
-      {
-        const size_t cluster = assignments[referenceNode.Descendant(i)];
-        newCentroids.col(cluster) += dataset.col(referenceNode.Descendant(i));
-        counts(cluster)++;
-      }
-    }
+    if (referenceNode.Begin() == 26038)
+      Log::Warn << "r26038c" << referenceNode.Count() << ": new SCQN " <<
+queryNode.Begin() << "c" << queryNode.Count() << ".\n";
+    // This is the new second closest node.
+    referenceNode.Stat().SecondClosestQueryNode() = (void*) &queryNode;
+    referenceNode.Stat().SecondMinQueryNodeDistance() = distances.Lo();
+    referenceNode.Stat().SecondMaxQueryNodeDistance() = distances.Hi();
   }
 
-  return score;
+  return 0.0; // No pruning allowed at this time.
 }
 
 template<typename MetricType, typename TreeType>
@@ -344,7 +219,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::HamerlyTypeScore(
 {
   if (referenceNode.Stat().HamerlyPruned())
   {
-//    if (referenceNode.Begin() == 23058)
+//    if (referenceNode.Begin() == 26038)
 //      Log::Warn << "Hamerly prune! r" << referenceNode.Begin() << "c" <<
 //referenceNode.Count() << ".\n";
     return DBL_MAX;
@@ -392,8 +267,8 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
           queryNode)) &&
       (&queryNode != (TreeType*) referenceNode.Stat().ClosestQueryNode()))
   {
-    if (referenceNode.Begin() == 23058)
-      Log::Warn << "Elkan prune r23058c" << referenceNode.Count() << ", q" <<
+    if (referenceNode.Begin() == 26038)
+      Log::Warn << "Elkan prune r26038c" << referenceNode.Count() << ", q" <<
 queryNode.Begin() << "c" << queryNode.Count() << "!\n";
     // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
     // means that N_q cannot possibly hold any clusters that own any points in
@@ -413,14 +288,14 @@ double DualTreeKMeansRules<MetricType, TreeType>::PellegMooreScore(
   // If the minimum distance to the node is greater than the bound, then every
   // cluster in the query node cannot possibly be the nearest neighbor of any of
   // the points in the reference node.
-//  if (referenceNode.Begin() == 23058)
-//      Log::Warn << "Pelleg-Moore prune attempt r23058c" << referenceNode.Count() << ", "
+//  if (referenceNode.Begin() == 26038)
+//      Log::Warn << "Pelleg-Moore prune attempt r26038c" << referenceNode.Count() << ", "
 //          << "q" << queryNode.Begin() << "c" << queryNode.Count() << "; "
 //          << "minDistance " << minDistance << ", MQND " <<
 //referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
   if (minDistance > referenceNode.Stat().MaxQueryNodeDistance())
   {
-//    if (referenceNode.Begin() == 23058)
+//    if (referenceNode.Begin() == 26038)
 //      Log::Warn << "Attempt successful!\n";
     return DBL_MAX;
   }
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 0b01fa6..c42eabe 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -18,10 +18,11 @@ class DualTreeKMeansStatistic
   template<typename TreeType>
   DualTreeKMeansStatistic(TreeType& node) :
       closestQueryNode(NULL),
+      secondClosestQueryNode(NULL),
       minQueryNodeDistance(DBL_MAX),
       maxQueryNodeDistance(DBL_MAX),
-      secondClosestBound(DBL_MAX),
-      secondClosestQueryNode(NULL),
+      secondMinQueryNodeDistance(DBL_MAX),
+      secondMaxQueryNodeDistance(DBL_MAX),
       lastSecondClosestBound(DBL_MAX),
       hamerlyPruned(false),
       clustersPruned(size_t(-1)),
@@ -55,6 +56,11 @@ class DualTreeKMeansStatistic
   //! Modify the current closest query node.
   void*& ClosestQueryNode() { return closestQueryNode; }
 
+  //! Get the second closest query node.
+  void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
+  //! Modify the second closest query node.
+  void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+
   //! Get the minimum distance to the closest query node.
   double MinQueryNodeDistance() const { return minQueryNodeDistance; }
   //! Modify the minimum distance to the closest query node.
@@ -65,15 +71,17 @@ class DualTreeKMeansStatistic
   //! Modify the maximum distance to the closest query node.
   double& MaxQueryNodeDistance() { return maxQueryNodeDistance; }
 
-  //! Get a lower bound on the second closest cluster distance.
-  double SecondClosestBound() const { return secondClosestBound; }
-  //! Modify the lower bound on the second closest cluster distance.
-  double& SecondClosestBound() { return secondClosestBound; }
+  //! Get the minimum distance to the second closest query node.
+  double SecondMinQueryNodeDistance() const
+  { return secondMinQueryNodeDistance; }
+  //! Modify the minimum distance to the second closest query node.
+  double& SecondMinQueryNodeDistance() { return secondMinQueryNodeDistance; }
 
-  //! Get the second closest query node.
-  void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
-  //! Modify the second closest query node.
-  void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+  //! Get the maximum distance to the second closest query node.
+  double SecondMaxQueryNodeDistance() const
+  { return secondMaxQueryNodeDistance; }
+  //! Modify the maximum distance to the second closest query node.
+  double& SecondMaxQueryNodeDistance() { return secondMaxQueryNodeDistance; }
 
   //! Get last iteration's second closest bound.
   double LastSecondClosestBound() const { return lastSecondClosestBound; }
@@ -129,14 +137,17 @@ class DualTreeKMeansStatistic
 
   //! The current closest query node to this reference node.
   void* closestQueryNode;
+  //! The second closest query node.
+  void* secondClosestQueryNode;
   //! The minimum distance to the closest query node.
   double minQueryNodeDistance;
   //! The maximum distance to the closest query node.
   double maxQueryNodeDistance;
-  //! A lower bound on the distance to the second closest cluster.
-  double secondClosestBound;
-  //! The second closest query node.
-  void* secondClosestQueryNode;
+  //! The minimum distance to the second closest query node.
+  double secondMinQueryNodeDistance;
+  //! The maximum distance to the second closest query node.
+  double secondMaxQueryNodeDistance;
+
   //! The second closest lower bound, on the previous iteration.
   double lastSecondClosestBound;
   //! Whether or not this node is pruned for the next iteration.
diff --git a/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp b/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
index 06edfb0..b33a0eb 100644
--- a/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
@@ -28,6 +28,8 @@ double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
                                                    arma::mat& newCentroids,
                                                    arma::Col<size_t>& counts)
 {
+  size_t hamerlyPruned = 0;
+
   // If this is the first iteration, we need to set all the bounds.
   if (minClusterDistances.n_elem != centroids.n_cols)
   {
@@ -68,6 +70,7 @@ double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
     // First bound test.
     if (upperBounds(i) <= m)
     {
+      ++hamerlyPruned;
       newCentroids.col(assignments[i]) += dataset.col(i);
       ++counts(assignments[i]);
       continue;
@@ -161,6 +164,8 @@ double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
       lowerBounds(i) -= furthestMovement;
   }
 
+  Log::Info << "Hamerly prunes: " << hamerlyPruned << ".\n";
+
   return std::sqrt(centroidMovement);
 }
 



More information about the mlpack-git mailing list